diff --git a/messaging/db/migrate_test.go b/messaging/db/migrate_test.go deleted file mode 100644 index c81e375ad..000000000 --- a/messaging/db/migrate_test.go +++ /dev/null @@ -1,22 +0,0 @@ -// +build migrations - -package db - -import ( - "os" - "testing" - - "github.com/cortezaproject/corteza-server/pkg/logger" - "github.com/titpetric/factory" - dbLogger "github.com/titpetric/factory/logger" -) - -func TestMigrations(t *testing.T) { - factory.Database.Add("messaging", os.Getenv("MESSAGING_DB_DSN")) - db := factory.Database.MustGet("messaging") - db.SetLogger(dbLogger.Default{}) - - if err := Migrate(db, logger.Default()); err != nil { - t.Fatalf("Unexpected error: %#v", err) - } -} diff --git a/messaging/repository/attachment_test.go b/messaging/repository/attachment_test.go deleted file mode 100644 index eebf243d1..000000000 --- a/messaging/repository/attachment_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// +build integration - -package repository - -import ( - "context" - "testing" - - "github.com/titpetric/factory" - - "github.com/cortezaproject/corteza-server/internal/test" - "github.com/cortezaproject/corteza-server/messaging/types" -) - -func TestAttachment(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - rpo := Attachment(context.Background(), factory.Database.MustGet("messaging")) - att := &types.Attachment{} - - att.UserID = 1 - - { - att, err = rpo.CreateAttachment(att) - test.Assert(t, err == nil, "CreateAttachment error: %+v", err) - test.Assert(t, att.UserID == 1, "Changes were not stored") - - { - att, err = rpo.FindAttachmentByID(att.ID) - test.Assert(t, err == nil, "FindAttachmentByID error: %+v", err) - test.Assert(t, att.UserID == 1, "Changes were not stored") - } - - { - att, err = rpo.FindAttachmentByID(att.ID) - test.Assert(t, err == nil, "FindAttachmentByMessageID error: %+v", err) - test.Assert(t, att != nil, "No results found") - } - - { - err = rpo.DeleteAttachmentByID(att.ID) - test.Assert(t, err == nil, "DeleteAttachmentByID error: %+v", err) - } - } -} diff --git a/messaging/repository/channel_test.go b/messaging/repository/channel_test.go deleted file mode 100644 index a079b77b0..000000000 --- a/messaging/repository/channel_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// +build integration - -package repository - -import ( - "context" - "testing" - - "github.com/titpetric/factory" - - "github.com/cortezaproject/corteza-server/internal/test" - "github.com/cortezaproject/corteza-server/messaging/types" -) - -func TestChannel(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - rpo := Channel(context.Background(), factory.Database.MustGet("messaging")) - chn := &types.Channel{} - - var name1, name2 = "Test channel v1", "Test channel v2" - - var cc []*types.Channel - - { - chn.Name = name1 - chn, err = rpo.Create(chn) - test.Assert(t, err == nil, "CreateChannel error: %+v", err) - test.Assert(t, chn.Name == name1, "Changes were not stored") - - { - chn.Name = name2 - - chn, err = rpo.Update(chn) - test.Assert(t, err == nil, "UpdateChannel error: %+v", err) - test.Assert(t, chn.Name == name2, "Changes were not stored") - } - - { - chn, err = rpo.FindByID(chn.ID) - test.Assert(t, err == nil, "FindByID error: %+v", err) - test.Assert(t, chn.Name == name2, "Changes were not stored") - } - - { - cc, err = rpo.Find(&types.ChannelFilter{Query: name2}) - test.Assert(t, err == nil, "FindChannels error: %+v", err) - test.Assert(t, len(cc) > 0, "No results found") - } - - { - err = rpo.ArchiveByID(chn.ID) - test.Assert(t, err == nil, "ArchiveByID error: %+v", err) - } - - { - err = rpo.UnarchiveByID(chn.ID) - test.Assert(t, err == nil, "UnarchiveByID error: %+v", err) - } - - { - err = rpo.DeleteByID(chn.ID) - test.Assert(t, err == nil, "DeleteByID error: %+v", err) - } - - { - err = rpo.UndeleteByID(chn.ID) - test.Assert(t, err == nil, "UndeleteByID error: %+v", err) - } - } -} diff --git a/messaging/repository/main_test.go b/messaging/repository/main_test.go deleted file mode 100644 index 371e9a5a5..000000000 --- a/messaging/repository/main_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// +build integration - -package repository - -import ( - "fmt" - "os" - "testing" - - "github.com/titpetric/factory" - - migrate "github.com/cortezaproject/corteza-server/messaging/db" - "github.com/cortezaproject/corteza-server/pkg/logger" - dbLogger "github.com/titpetric/factory/logger" -) - -func TestMain(m *testing.M) { - factory.Database.Add("messaging", os.Getenv("MESSAGING_DB_DSN")) - db := factory.Database.MustGet("messaging") - db.SetLogger(dbLogger.Default{}) - - // migrate database schema - if err := migrate.Migrate(db, logger.Default()); err != nil { - fmt.Printf("Error running migrations: %+v\n", err) - return - } - - os.Exit(m.Run()) -} diff --git a/messaging/repository/message_flag_test.go b/messaging/repository/message_flag_test.go deleted file mode 100644 index 06716057f..000000000 --- a/messaging/repository/message_flag_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// +build integration - -package repository - -import ( - "context" - "testing" - - "github.com/titpetric/factory" - - "github.com/cortezaproject/corteza-server/internal/test" - "github.com/cortezaproject/corteza-server/messaging/types" -) - -func TestReaction(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - rpo := MessageFlag(context.Background(), factory.Database.MustGet("messaging")) - - tx(t, func() error { - var chID = factory.Sonyflake.NextID() - var msgID = factory.Sonyflake.NextID() - var f *types.MessageFlag - var ff types.MessageFlagSet - f, err = rpo.Create(&types.MessageFlag{ - ChannelID: chID, - MessageID: msgID, - UserID: 3, - Flag: "success", - }) - - test.Assert(t, err == nil, "Should create message flag without an error, got: %+v", err) - - f, err = rpo.FindByID(f.ID) - test.Assert(t, err == nil, "Should fetch message flag without an error, got: %+v", err) - test.Assert(t, f != nil && f.ChannelID == chID, "fetch should return valid type struct") - - ff, err = rpo.FindByMessageIDs(msgID) - test.Assert(t, err == nil, "Should fetch message flag by range without an error, got: %+v", err) - test.Assert(t, len(ff) == 1, "fetch by range should return 1 message") - - err = rpo.DeleteByID(f.ID) - test.Assert(t, err == nil, "Should delete message flag without an error, got: %+v", err) - - return nil - }) -} diff --git a/messaging/repository/message_test.go b/messaging/repository/message_test.go deleted file mode 100644 index fe174794b..000000000 --- a/messaging/repository/message_test.go +++ /dev/null @@ -1,219 +0,0 @@ -// +build integration - -package repository - -import ( - "context" - "fmt" - "testing" - - "github.com/titpetric/factory" - dbLogger "github.com/titpetric/factory/logger" - - "github.com/cortezaproject/corteza-server/internal/test" - "github.com/cortezaproject/corteza-server/messaging/types" -) - -func TestMessage(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - msgRpo := Message(context.Background(), factory.Database.MustGet("messaging")) - chRpo := Channel(context.Background(), factory.Database.MustGet("messaging")) - - var msg1, msg2 = "Test message v1", "Test message v2" - - var mm types.MessageSet - - tx(t, func() error { - ch := &types.Channel{} - ch, err = chRpo.Create(ch) - ch.Type = types.ChannelTypePublic - - msg := &types.Message{ChannelID: ch.ID} - - msg.Message = msg1 - msg, err = msgRpo.Create(msg) - test.Assert(t, err == nil, "CreateMessage error: %+v", err) - test.Assert(t, msg.Message == msg1, "Changes were not stored") - - { - msg.Message = msg2 - msg, err = msgRpo.Update(msg) - test.Assert(t, err == nil, "UpdateMessage error: %+v", err) - test.Assert(t, msg.Message == msg2, "Changes were not stored") - } - - { - msg, err = msgRpo.FindByID(msg.ID) - test.Assert(t, err == nil, "FindMessageByID error: %+v", err) - test.Assert(t, msg.Message == msg2, "Changes were not stored") - } - - { - mm, err = msgRpo.Find(&types.MessageFilter{Query: msg2}) - test.Assert(t, err == nil, "FindMessages error: %+v", err) - test.Assert(t, len(mm) > 0, "No results found") - } - - { - err = msgRpo.DeleteByID(msg.ID) - test.Assert(t, err == nil, "DeleteMessageByID error: %+v", err) - } - - return nil - }) -} - -func TestBeforeMessageID(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - db := factory.Database.MustGet("messaging") - msgRpo := Message(context.Background(), db) - chRpo := Channel(context.Background(), db) - - tx(t, func() error { - // insert 1 channel - ch := &types.Channel{} - ch, err = chRpo.Create(ch) - ch.Type = types.ChannelTypePublic - - // insert 100 messages - db.SetLogger(dbLogger.Silent{}) - messages := make([]*types.Message, 100) - for k, _ := range messages { - messages[k], err = msgRpo.Create(&types.Message{ - ChannelID: ch.ID, - Message: fmt.Sprintf("#%d: Lorem ipsum dolor sit amet", k), - }) - - test.Assert(t, err == nil, "CreateMessage error: %+v", err) - } - db.SetLogger(dbLogger.Default{}) - - // request last 10 messages from channel - lastPageRequest := &types.MessageFilter{ - ChannelID: []uint64{ch.ID}, - Limit: 10, - } - - var lastPage types.MessageSet - lastPage, err = msgRpo.Find(lastPageRequest) - - test.Assert(t, err == nil, "lastPageRequest error: %+v", err) - test.Assert(t, len(lastPage) > 0, "No results found (last page)") - - // request previous 10 messages from channel - prevPageRequest := &types.MessageFilter{ - ChannelID: []uint64{ch.ID}, - Limit: 10, - BeforeID: lastPage[9].ID, - } - - var prevPage types.MessageSet - prevPage, err = msgRpo.Find(prevPageRequest) - - test.Assert(t, err == nil, "prevPageRequest error: %+v", err) - test.Assert(t, prevPage[0].ID != messages[0].ID, "We have 100 IDs, second page shouldn't start with first ID") - test.Assert(t, prevPage[0].ID == messages[89].ID, "ID should match index 89 (max index - 10), but %d != %d", prevPage[0].ID, messages[89].ID) - test.Assert(t, len(prevPage) > 0, "No results found (previous page)") - - return nil - }) -} - -func TestReplies(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - msgRpo := Message(context.Background(), factory.Database.MustGet("messaging")) - chRpo := Channel(context.Background(), factory.Database.MustGet("messaging")) - - var mm types.MessageSet - - tx(t, func() error { - ch := &types.Channel{} - ch, err = chRpo.Create(ch) - ch.Type = types.ChannelTypePublic - - msg := &types.Message{ChannelID: ch.ID} - rpl := &types.Message{ChannelID: ch.ID} - - msg, err = msgRpo.Create(msg) - test.Assert(t, err == nil, "CreateMessage error: %+v", err) - test.Assert(t, msg.ID > 0, "Message did not get its ID") - - rpl.ReplyTo = msg.ID - rpl, err = msgRpo.Create(rpl) - test.Assert(t, err == nil, "CreateMessage error: %+v", err) - test.Assert(t, rpl.ID > 0, "Reply did not get its ID") - - // Let's increase this so that FindThreads - // can include it into results - msgRpo.IncReplyCount(msg.ID) - - { - mm, err = msgRpo.Find(&types.MessageFilter{ - ThreadID: []uint64{msg.ID}, - ChannelID: []uint64{ch.ID}, - }) - - test.Assert(t, err == nil, "FindMessages error: %+v", err) - test.Assert(t, len(mm) == 1, "Failed to fetch only reply, got: %d", len(mm)) - test.Assert(t, mm[0].ID == rpl.ID, "Reply ID does not match") - } - - { - mm, err = msgRpo.FindThreads(&types.MessageFilter{ - ChannelID: []uint64{ch.ID}, - }) - - test.Assert(t, err == nil, "FindThreads error: %+v", err) - test.Assert(t, len(mm) == 2, "Failed to fetch messages in threads (2 messages), got: %d", len(mm)) - test.Assert(t, mm[0].ID == msg.ID, "Original message ID does not match") - test.Assert(t, mm[1].ID == rpl.ID, "Reply ID does not match") - } - - { - mm, err = msgRpo.Find(&types.MessageFilter{ - ChannelID: []uint64{ch.ID}, - }) - - test.Assert(t, err == nil, "FindMessages error: %+v", err) - test.Assert(t, len(mm) == 1, "Failed to fetch only original message") - test.Assert(t, mm[0].ID == msg.ID, "Reply ID does not match") - } - - { - - test.Assert(t, msgRpo.IncReplyCount(msg.ID) == nil, "IncReplyCount should not return an error") - test.Assert(t, msgRpo.IncReplyCount(msg.ID) == nil, "IncReplyCount should not return an error") - // +1 that we have from before - - msg, err = msgRpo.FindByID(msg.ID) - test.Assert(t, err == nil, "FindMessageByID error: %+v", err) - test.Assert(t, msg.Replies == 3, "Reply counter check failed, expecting 3, got %d", msg.Replies) - - test.Assert(t, msgRpo.DecReplyCount(msg.ID) == nil, "DecReplyCount should not return an error") - - msg, err = msgRpo.FindByID(msg.ID) - test.Assert(t, err == nil, "FindMessageByID error: %+v", err) - test.Assert(t, msg.Replies == 2, "Reply counter check failed, expecting 1, got %d", msg.Replies) - } - - return nil - }) -} diff --git a/messaging/repository/repository_test.go b/messaging/repository/repository_test.go deleted file mode 100644 index 2144fb992..000000000 --- a/messaging/repository/repository_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// +build integration - -package repository - -import ( - "context" - "testing" - - "github.com/cortezaproject/corteza-server/internal/test" -) - -func TestRepository(t *testing.T) { - repo := &repository{} - repo.With(context.Background(), nil) -} - -func tx(t *testing.T, f func() error) { - var err error - db := DB(context.Background()) - - err = db.Begin() - test.Assert(t, err == nil, "Could not begin transaction: %+v", err) - - err = f() - test.Assert(t, err == nil, "Test transaction resulted in an error: %+v", err) - - err = db.Quiet().Rollback() - test.Assert(t, err == nil, "Could not rollback transaction: %+v", err) -} diff --git a/messaging/service/main_test.go b/messaging/service/main_test.go deleted file mode 100644 index 866846eae..000000000 --- a/messaging/service/main_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// +build integration - -package service - -import ( - "context" - "fmt" - "os" - "testing" - - "github.com/titpetric/factory" - "go.uber.org/zap" - - messagingMigrate "github.com/cortezaproject/corteza-server/messaging/db" - "github.com/cortezaproject/corteza-server/pkg/cli/options" - "github.com/cortezaproject/corteza-server/pkg/logger" - dbLogger "github.com/titpetric/factory/logger" -) - -type mockDB struct{} - -func (mockDB) Transaction(callback func() error) error { return callback() } - -func TestMain(m *testing.M) { - logger.SetDefault(logger.MakeDebugLogger()) - - factory.Database.Add("messaging", os.Getenv("MESSAGING_DB_DSN")) - db := factory.Database.MustGet("messaging") - db.SetLogger(dbLogger.Default{}) - - // migrate database schema - if err := messagingMigrate.Migrate(db, logger.Default()); err != nil { - fmt.Printf("Error running migrations: %+v\n", err) - return - } - - Init(context.Background(), zap.NewNop(), Config{ - Storage: options.StorageOpt{ - Path: "/tmp/corteza-messaging-store", - }, - }) - - os.Exit(m.Run()) -} diff --git a/messaging/service/message.go b/messaging/service/message.go index 4dcb4f70a..dc9e5f740 100644 --- a/messaging/service/message.go +++ b/messaging/service/message.go @@ -390,10 +390,10 @@ func (svc message) Delete(messageID uint64) error { if ch, err = svc.channel.FindByID(original.ChannelID); err != nil { return } + if original.UserID == currentUserID && !svc.ac.CanUpdateOwnMessages(svc.ctx, ch) { return ErrNoPermissions.withStack() - } - if !svc.ac.CanUpdateMessages(svc.ctx, ch) { + } else if original.UserID != currentUserID && !svc.ac.CanUpdateMessages(svc.ctx, ch) { return ErrNoPermissions.withStack() } diff --git a/tests/messaging/message_flag_test.go b/tests/messaging/message_flag_test.go new file mode 100644 index 000000000..5270a4147 --- /dev/null +++ b/tests/messaging/message_flag_test.go @@ -0,0 +1,68 @@ +package messaging + +import ( + "fmt" + "net/http" + "testing" + + "github.com/steinfletcher/apitest" + + "github.com/cortezaproject/corteza-server/messaging/types" +) + +func (h helper) apiMessageSetFlag(msg *types.Message, method, flag string) *apitest.Response { + return h.apiInit(). + Method(method). + URL(fmt.Sprintf("/channels/%d/messages/%d/%s", msg.ChannelID, msg.ID, flag)). + Expect(h.t). + Status(http.StatusOK) +} + +func TestMessageFlag(t *testing.T) { + h := newHelper(t) + msg := h.repoMakeMessage("flag target", h.repoMakePublicCh(), h.cUser) + + initialState := h.repoMsgFlagLoad(msg.ID) + h.a.Len(initialState, 0) + h.a.False(initialState.IsPinned()) + h.a.False(initialState.IsBookmarked(h.cUser.ID)) + + // Pin flag (for everyone) + + h.apiMessageSetFlag(msg, "POST", "pin"). + End() + h.a.True(h.repoMsgFlagLoad(msg.ID).IsPinned()) + + h.apiMessageSetFlag(msg, "DELETE", "pin"). + End() + h.a.False(h.repoMsgFlagLoad(msg.ID).IsPinned()) + + // Bookmark flag (per user) + + h.apiMessageSetFlag(msg, "POST", "bookmark"). + End() + h.a.True(h.repoMsgFlagLoad(msg.ID).IsBookmarked(h.cUser.ID)) + + h.apiMessageSetFlag(msg, "DELETE", "bookmark"). + End() + h.a.False(h.repoMsgFlagLoad(msg.ID).IsBookmarked(h.cUser.ID)) + + // Custom flags (aka reactions) + hasReaction := func(flag string) bool { + ff, _ := h.repoMsgFlagLoad(msg.ID).Filter(func(f *types.MessageFlag) (b bool, e error) { + return f.Flag == flag && f.UserID == h.cUser.ID && f.DeletedAt == nil, nil + }) + + return len(ff) > 0 + } + + h.apiMessageSetFlag(msg, "POST", "reaction/foo"). + End() + + h.a.True(hasReaction("foo"), "expecting message to have reaction") + + h.apiMessageSetFlag(msg, "DELETE", "reaction/foo"). + End() + + h.a.False(hasReaction("foo"), "expecting message not to have reaction") +} diff --git a/tests/messaging/message_reply_test.go b/tests/messaging/message_reply_test.go index 58b65bfff..b327a9518 100644 --- a/tests/messaging/message_reply_test.go +++ b/tests/messaging/message_reply_test.go @@ -14,26 +14,58 @@ func TestMessagesReply(t *testing.T) { h := newHelper(t) msg := h.repoMakeMessage("old", h.repoMakePublicCh(), h.cUser) - rval := struct { - Response struct { - ID uint64 `json:"messageID,string"` - } - }{} + reply := func() uint64 { + rval := struct { + Response struct { + ID uint64 `json:"messageID,string"` + } + }{} + h.apiInit(). + Post(fmt.Sprintf("/channels/%d/messages/%d/replies", msg.ChannelID, msg.ID)). + JSON(`{"message":"new reply"}`). + Expect(t). + Status(http.StatusOK). + Assert(helpers.AssertNoErrors). + Assert(jsonpath.Present(`$.response.messageID`)). + Assert(jsonpath.Present(`$.response.replyTo`)). + Assert(jsonpath.Equal(`$.response.message`, `new reply`)). + End(). + JSON(&rval) + + r := h.repoMsgExistingLoad(rval.Response.ID) + h.a.Equal(`new reply`, r.Message) + h.a.Equal(msg.ID, r.ReplyTo) + return rval.Response.ID + } + + reply1ID := reply() + reply2ID := reply() + reply3ID := reply() + + _, _, _ = reply1ID, reply2ID, reply3ID + + msg = h.repoMsgExistingLoad(msg.ID) + h.a.Equal(msg.Replies, uint(3)) h.apiInit(). - Post(fmt.Sprintf("/channels/%d/messages/%d/replies", msg.ChannelID, msg.ID)). - JSON(`{"message":"new reply"}`). + Debug(). + Get("/search/threads"). + Query("channelID", fmt.Sprintf("%d", msg.ChannelID)). Expect(t). Status(http.StatusOK). Assert(helpers.AssertNoErrors). - Assert(jsonpath.Present(`$.response.messageID`)). - Assert(jsonpath.Present(`$.response.replyTo`)). - Assert(jsonpath.Equal(`$.response.message`, `new reply`)). - End(). - JSON(&rval) + Assert(jsonpath.Len(`$.response`, 4)). // 3 replies + original msg + End() - m := h.repoMsgExistingLoad(rval.Response.ID) - h.a.Equal(`new reply`, m.Message) - h.a.Equal(msg.ID, m.ReplyTo) + // Remove one of the replies + h.apiInit(). + Delete(fmt.Sprintf("/channels/%d/messages/%d", msg.ChannelID, reply2ID)). + Expect(t). + Status(http.StatusOK). + Assert(helpers.AssertNoErrors). + End() + + msg = h.repoMsgExistingLoad(msg.ID) + h.a.Equal(msg.Replies, uint(2)) } diff --git a/tests/messaging/message_test.go b/tests/messaging/message_test.go index 8ac799f9f..61440f8db 100644 --- a/tests/messaging/message_test.go +++ b/tests/messaging/message_test.go @@ -18,6 +18,14 @@ func (h helper) repoMessage() repository.MessageRepository { return repository.Message(ctx, db) } +func (h helper) repoMessageFlag() repository.MessageFlagRepository { + var ( + ctx = context.Background() + db = factory.Database.MustGet("messaging").With(ctx) + ) + + return repository.MessageFlag(ctx, db) +} func (h helper) repoMakeMessage(msg string, ch *types.Channel, u *sysTypes.User) *types.Message { m, err := h.repoMessage().Create(&types.Message{ @@ -36,3 +44,10 @@ func (h helper) repoMsgExistingLoad(ID uint64) *types.Message { h.a.NotNil(m) return m } + +func (h helper) repoMsgFlagLoad(ID uint64) types.MessageFlagSet { + ff, err := h.repoMessageFlag().FindByMessageIDs(ID) + h.a.NoError(err) + h.a.NotNil(ff) + return ff +}