From c697c86ee543d50177936e8ae006c6fb1889348b Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Sun, 3 Feb 2019 13:52:27 +0100 Subject: [PATCH] Add system-cli and an user-merging tool --- cmd/system-cli/flags.go | 13 +++ cmd/system-cli/main.go | 48 +++++++++ cmd/system-cli/users.go | 168 +++++++++++++++++++++++++++++++ cmd/system/main.go | 18 +++- crm/repository/record.go | 37 +++++++ sam/repository/attachment.go | 14 +++ sam/repository/channel.go | 14 +++ sam/repository/channel_member.go | 26 +++++ sam/repository/mention.go | 14 +++ sam/repository/message.go | 31 ++++++ sam/repository/message_flag.go | 14 +++ sam/repository/unread.go | 34 ++++++- sam/service/channel.go | 2 +- sam/service/message.go | 2 +- system/start.go | 26 +++-- 15 files changed, 444 insertions(+), 17 deletions(-) create mode 100644 cmd/system-cli/flags.go create mode 100644 cmd/system-cli/main.go create mode 100644 cmd/system-cli/users.go diff --git a/cmd/system-cli/flags.go b/cmd/system-cli/flags.go new file mode 100644 index 000000000..d074ab67c --- /dev/null +++ b/cmd/system-cli/flags.go @@ -0,0 +1,13 @@ +package main + +import ( + _ "github.com/joho/godotenv/autoload" + "github.com/namsral/flag" +) + +func flags(prefix string, mountFlags ...func(...string)) { + for _, mount := range mountFlags { + mount(prefix) + } + flag.Parse() +} diff --git a/cmd/system-cli/main.go b/cmd/system-cli/main.go new file mode 100644 index 000000000..93f58d965 --- /dev/null +++ b/cmd/system-cli/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "log" + "os" + + "github.com/crusttech/crust/system" + systemService "github.com/crusttech/crust/system/service" + + "github.com/crusttech/crust/internal/auth" + "github.com/crusttech/crust/internal/rbac" +) + +func main() { + flags("system", service.Flags, auth.Flags, rbac.Flags) + + // log to stdout not stderr + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + + service.InitDb() + systemService.Init() + + var commands []string + if len(os.Args) > 0 { + + // @todo migrate to a proper solution (eg: https://github.com/spf13/cobra) + commands = os.Args[1:] + for a, arg := range os.Args { + if arg == "--" && a+1 < len(os.Args) { + commands = os.Args[a+1:] + } + } + } + + cliRouter(commands...) +} + +func cliRouter(commands ...string) { + if len(commands) == 0 { + return + } + + switch commands[0] { + case "users": + cliExecUsers(commands[1:]...) + } +} diff --git a/cmd/system-cli/users.go b/cmd/system-cli/users.go new file mode 100644 index 000000000..c59a16bb9 --- /dev/null +++ b/cmd/system-cli/users.go @@ -0,0 +1,168 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + crmRepository "github.com/crusttech/crust/crm/repository" + "github.com/crusttech/crust/internal/payload" + messagingRepository "github.com/crusttech/crust/sam/repository" + "github.com/crusttech/crust/system/service" + "github.com/crusttech/crust/system/types" +) + +func cliExecUsers(commands ...string) { + if len(commands) == 0 { + return + } + + switch commands[0] { + case "list": + cliExecUsersList(commands[1:]...) + case "merge": + cliExecUsersMerge(commands[1:]...) + } +} + +func cliExecUsersList(params ...string) { + var ( + err error + uu types.UserSet + ctx = context.Background() + ) + + if uu, err = service.DefaultUser.With(ctx).Find(nil); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + for _, u := range uu { + fmt.Printf("%20d | %-40s | %-20s\n", u.ID, u.Email+" / "+u.Name+" / "+u.Username, u.UpdatedAt) + } +} + +func cliExecUsersMerge(params ...string) { + var ( + err error + uu = make([]*types.User, len(params)) + refs = make([]*userRefs, len(params)) + ids = payload.ParseUInt64s(params) + ctx = context.Background() + ) + + if len(ids) < 2 { + fmt.Printf("Expecting 2+ user IDs (2nd, 3rd ... user ID will be merged into first one\n") + os.Exit(1) + } + + for i, id := range ids { + if id == 0 { + fmt.Printf("Error: Invalid user ID %q\n", params[i]) + os.Exit(1) + } + + if uu[i], err = service.DefaultUser.With(ctx).FindByID(id); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + } + + db := messagingRepository.DB(ctx) + + mergers := []struct { + label string + count func(userID uint64) (c int, err error) + merge func(userID, target uint64) (err error) + }{ + {label: "MsgOw", + count: messagingRepository.Message(ctx, db).CountOwned, + merge: messagingRepository.Message(ctx, db).ChangeOwner}, + {label: "MTags", + count: messagingRepository.Message(ctx, db).CountUserTags, + merge: messagingRepository.Message(ctx, db).ChangeUserTag}, + {label: "ChCre", + count: messagingRepository.Channel(ctx, db).CountCreated, + merge: messagingRepository.Channel(ctx, db).ChangeCreator}, + {label: "Membr", + count: messagingRepository.ChannelMember(ctx, db).CountMemberships, + merge: messagingRepository.ChannelMember(ctx, db).ChangeMembership}, + {label: "AttOw", + count: messagingRepository.Attachment(ctx, db).CountOwned, + merge: messagingRepository.Attachment(ctx, db).ChangeOwnership}, + {label: "Menti", + count: messagingRepository.Mention(ctx, db).CountMentions, + merge: messagingRepository.Mention(ctx, db).ChangeMention}, + {label: "Unrd", + count: messagingRepository.Unread(ctx, db).CountOwned, + merge: messagingRepository.Unread(ctx, db).ChangeOwner}, + {label: "CRAut", + count: crmRepository.Record(ctx, db).CountAuthored, + merge: crmRepository.Record(ctx, db).ChangeAuthor}, + {label: "CRRef", + count: crmRepository.Record(ctx, db).CountReferenced, + merge: crmRepository.Record(ctx, db).ChangeReferences}, + } + + count := func(u *types.User, r *userRefs) (out string) { + out = fmt.Sprintf( + "%20d | %-40s", + u.ID, + u.Email+" / "+u.Name+" / "+u.Username, + ) + + for _, m := range mergers { + if count, err := m.count(u.ID); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } else { + out = out + fmt.Sprintf(" | %5d", count) + } + } + + return out + fmt.Sprintln() + } + + stats := fmt.Sprintf( + "%20s | %40s", + "ID", + "Email", + ) + + for _, m := range mergers { + stats = stats + fmt.Sprintf(" | %5s", m.label) + } + + stats = stats + fmt.Sprintln() + fmt.Sprintf("Merge %d users:\n", len(uu)-1) + for i := 1; i < len(uu); i++ { + stats = stats + count(uu[i], refs[i]) + } + stats = stats + fmt.Sprintln("Target") + count(uu[0], refs[0]) + + fmt.Println(stats) + + reader := bufio.NewReader(os.Stdin) + fmt.Print("Merge [y/N]? ") + text, _ := reader.ReadByte() + if "y" != strings.ToLower(string(text)) { + os.Exit(0) + } + + for i := 1; i < len(uu); i++ { + for _, m := range mergers { + if err := m.merge(uu[i].ID, uu[0].ID); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + } + } + + fmt.Println("Done.") +} + +type userRefs struct { + messagesCreated int + messagesTagged int +} diff --git a/cmd/system/main.go b/cmd/system/main.go index 4b9518b79..d9edf7cd8 100644 --- a/cmd/system/main.go +++ b/cmd/system/main.go @@ -4,7 +4,7 @@ import ( "log" "os" - service "github.com/crusttech/crust/system" + "github.com/crusttech/crust/system" "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/internal/rbac" @@ -20,7 +20,19 @@ func main() { if err := service.Init(); err != nil { log.Fatalf("Error initializing system: %+v", err) } - if err := service.Start(); err != nil { - log.Fatalf("Error starting/running system: %+v", err) + + var command string + if len(os.Args) > 1 { + command = os.Args[1] } + + switch command { + case "help": + case "merge-users": + default: + if err := service.Start(); err != nil { + log.Fatalf("Error starting/running system: %+v", err) + } + } + } diff --git a/crm/repository/record.go b/crm/repository/record.go index 686594d6c..86e089218 100644 --- a/crm/repository/record.go +++ b/crm/repository/record.go @@ -32,6 +32,11 @@ type ( LoadValues(IDs ...uint64) (rvs types.RecordValueSet, err error) DeleteValues(record *types.Record) error UpdateValues(recordID uint64, rvs types.RecordValueSet) (err error) + + CountAuthored(userID uint64) (c int, err error) + ChangeAuthor(userID, target uint64) error + CountReferenced(userID uint64) (c int, err error) + ChangeReferences(userID, target uint64) error } FindResponseMeta struct { @@ -362,3 +367,35 @@ func isRealRecordCol(name string) (string, bool) { return name, false } + +func (r *record) CountAuthored(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM crm_record WHERE created_by = ? OR updated_by = ? OR deleted_by = ?", + userID, userID, userID) +} + +func (r *record) ChangeAuthor(userID, target uint64) error { + if _, err := r.db().Exec("UPDATE crm_record SET created_by = ? WHERE created_by = ?", target, userID); err != nil { + return err + } + if _, err := r.db().Exec("UPDATE crm_record SET updated_by = ? WHERE updated_by = ?", target, userID); err != nil { + return err + } + if _, err := r.db().Exec("UPDATE crm_record SET deleted_by = ? WHERE deleted_by = ?", target, userID); err != nil { + return err + } + + return nil +} + +func (r *record) CountReferenced(userID uint64) (c int, err error) { + // @todo add field type (User) check + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM crm_record_value WHERE ref = ?", + userID) +} + +func (r *record) ChangeReferences(userID, target uint64) error { + _, err := r.db().Exec("UPDATE crm_record_value SET ref = ? WHERE ref = ?", target, userID) + return err +} diff --git a/sam/repository/attachment.go b/sam/repository/attachment.go index 8c1dba52a..c312229f2 100644 --- a/sam/repository/attachment.go +++ b/sam/repository/attachment.go @@ -19,6 +19,9 @@ type ( CreateAttachment(mod *types.Attachment) (*types.Attachment, error) DeleteAttachmentByID(id uint64) error BindAttachment(attachmentId, messageId uint64) error + + CountOwned(userID uint64) (c int, err error) + ChangeOwnership(userID, target uint64) error } attachment struct { @@ -98,3 +101,14 @@ func (r *attachment) BindAttachment(attachmentId, messageId uint64) error { return r.db().Insert("message_attachment", bond) } + +func (r *attachment) CountOwned(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM attachments WHERE rel_user = ?", + userID) +} + +func (r *attachment) ChangeOwnership(userID, target uint64) error { + _, err := r.db().Exec("UPDATE attachments SET rel_user = ? WHERE rel_user = ?", target, userID) + return err +} diff --git a/sam/repository/channel.go b/sam/repository/channel.go index d4afac834..cdc1fe58f 100644 --- a/sam/repository/channel.go +++ b/sam/repository/channel.go @@ -25,6 +25,9 @@ type ( UnarchiveChannelByID(id uint64) error DeleteChannelByID(id uint64) error UndeleteChannelByID(id uint64) error + + CountCreated(userID uint64) (c int, err error) + ChangeCreator(userID, target uint64) error } channel struct { @@ -174,3 +177,14 @@ func (r *channel) DeleteChannelByID(id uint64) error { func (r *channel) UndeleteChannelByID(id uint64) error { return r.updateColumnByID("channels", "deleted_at", nil, id) } + +func (r *channel) CountCreated(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM channels WHERE rel_creator = ?", + userID) +} + +func (r *channel) ChangeCreator(userID, target uint64) error { + _, err := r.db().Exec("UPDATE channels SET rel_creator = ? WHERE rel_creator = ?", target, userID) + return err +} diff --git a/sam/repository/channel_member.go b/sam/repository/channel_member.go index 757bf8f44..cb97935b1 100644 --- a/sam/repository/channel_member.go +++ b/sam/repository/channel_member.go @@ -19,6 +19,9 @@ type ( Create(mod *types.ChannelMember) (*types.ChannelMember, error) Update(mod *types.ChannelMember) (*types.ChannelMember, error) Delete(channelID, userID uint64) error + + CountMemberships(userID uint64) (c int, err error) + ChangeMembership(userID, target uint64) error } channelMember struct { @@ -103,3 +106,26 @@ func (r *channelMember) Delete(channelID, userID uint64) error { sql := `DELETE FROM channel_members WHERE rel_channel = ? AND rel_user = ?` return exec(r.db().Exec(sql, channelID, userID)) } + +func (r *channelMember) CountMemberships(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM channel_members WHERE rel_user = ?", + userID) +} + +func (r *channelMember) ChangeMembership(userID, target uint64) (err error) { + // Remove dups + // with an ugly mysql workaround + _, err = r.db().Exec( + "DELETE FROM channel_members WHERE rel_user = ? "+ + "AND rel_channel IN (SELECT rel_channel FROM (SELECT * FROM channel_members) AS workaround WHERE rel_user = ?)", + userID, + target) + + if err != nil { + return err + } + + _, err = r.db().Exec("UPDATE channel_members SET rel_user = ? WHERE rel_user = ?", target, userID) + return err +} diff --git a/sam/repository/mention.go b/sam/repository/mention.go index 636a561a5..30fcd534b 100644 --- a/sam/repository/mention.go +++ b/sam/repository/mention.go @@ -20,6 +20,9 @@ type ( Create(m *types.Mention) (*types.Mention, error) DeleteByMessageID(ID uint64) error DeleteByID(ID uint64) error + + CountMentions(userID uint64) (c int, err error) + ChangeMention(userID, target uint64) error } mention struct { @@ -78,3 +81,14 @@ func (r *mention) DeleteByMessageID(ID uint64) error { func (r *mention) DeleteByID(ID uint64) error { return exec(r.db().Exec("DELETE FROM mentions WHERE id = ?", ID)) } + +func (r *mention) CountMentions(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM mentions WHERE rel_user = ?", + userID) +} + +func (r *mention) ChangeMention(userID, target uint64) error { + _, err := r.db().Exec("UPDATE mentions SET rel_user = ? WHERE rel_user = ?", target, userID) + return err +} diff --git a/sam/repository/message.go b/sam/repository/message.go index 5dcfc38df..6914811fa 100644 --- a/sam/repository/message.go +++ b/sam/repository/message.go @@ -2,6 +2,7 @@ package repository import ( "context" + "fmt" "time" "github.com/jmoiron/sqlx" @@ -24,6 +25,11 @@ type ( DeleteMessageByID(ID uint64) error IncReplyCount(ID uint64) error DecReplyCount(ID uint64) error + + CountOwned(userID uint64) (c int, err error) + ChangeOwner(userID, target uint64) error + CountUserTags(userID uint64) (c int, err error) + ChangeUserTag(userID, target uint64) error } message struct { @@ -272,3 +278,28 @@ func (r *message) DecReplyCount(ID uint64) error { _, err := r.db().Exec(sqlMessageRepliesDecCount, ID) return err } + +func (r *message) CountOwned(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM messages WHERE rel_user = ?", + userID) +} + +func (r *message) CountUserTags(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM messages WHERE message LIKE ?", + fmt.Sprintf("%%@%d%%", userID)) +} + +func (r *message) ChangeOwner(userID, target uint64) error { + _, err := r.db().Exec("UPDATE messages SET rel_user = ? WHERE rel_user = ?", target, userID) + return err +} + +func (r *message) ChangeUserTag(userID, target uint64) error { + _, err := r.db().Exec("UPDATE messages SET message = replace(message, ?, ?) WHERE message LIKE ?", + fmt.Sprintf("@%d", userID), + fmt.Sprintf("@%d", target), + fmt.Sprintf("%%@%d%%", userID)) + return err +} diff --git a/sam/repository/message_flag.go b/sam/repository/message_flag.go index 70e648628..3552a2bd3 100644 --- a/sam/repository/message_flag.go +++ b/sam/repository/message_flag.go @@ -19,6 +19,9 @@ type ( FindByFlag(messageID, userID uint64, flag string) (*types.MessageFlag, error) Create(mod *types.MessageFlag) (*types.MessageFlag, error) DeleteByID(ID uint64) error + + CountOwned(userID uint64) (c int, err error) + ChangeOwner(userID, target uint64) error } messageFlag struct { @@ -85,3 +88,14 @@ func (r *messageFlag) Create(mod *types.MessageFlag) (*types.MessageFlag, error) func (r *messageFlag) DeleteByID(ID uint64) error { return exec(r.db().Exec("DELETE FROM message_flags WHERE id = ?", ID)) } + +func (r *messageFlag) CountOwned(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM message_flag WHERE rel_user = ?", + userID) +} + +func (r *messageFlag) ChangeOwner(userID, target uint64) error { + _, err := r.db().Exec("UPDATE message_flag SET rel_user = ? WHERE rel_user = ?", target, userID) + return err +} diff --git a/sam/repository/unread.go b/sam/repository/unread.go index ea61463dc..714119dbb 100644 --- a/sam/repository/unread.go +++ b/sam/repository/unread.go @@ -17,6 +17,9 @@ type ( Record(userID, channelID, threadID, lastReadMessageID uint64, count uint32) error Inc(channelID, replyTo, userID uint64) error Dec(channelID, replyTo, userID uint64) error + + CountOwned(userID uint64) (c int, err error) + ChangeOwner(userID, target uint64) error } unread struct { @@ -39,8 +42,8 @@ const ( WHERE rel_channel = ? AND rel_reply_to = ? AND rel_user <> ? AND count > 0` ) -// ChannelView creates new instance of channel member repository -func ChannelView(ctx context.Context, db *factory.DB) UnreadRepository { +// Unread creates new instance of channel member repository +func Unread(ctx context.Context, db *factory.DB) UnreadRepository { return (&unread{}).With(ctx, db) } @@ -95,3 +98,30 @@ func (r *unread) Dec(channelID, threadID, userID uint64) error { _, err := r.db().Exec(sqlUnreadDecCount, channelID, threadID, userID) return err } + +func (r *unread) CountOwned(userID uint64) (c int, err error) { + return c, r.db().Get(&c, + "SELECT COUNT(*) FROM unreads WHERE rel_user = ?", + userID) +} + +func (r *unread) ChangeOwner(userID, target uint64) (err error) { + // Remove dups + // with an ugly mysql workaround + _, err = r.db().Exec( + "DELETE FROM unreads WHERE rel_user = ? "+ + "AND rel_channel IN (SELECT rel_channel FROM (SELECT * FROM unreads) AS workaround WHERE rel_user = ?)", + userID, + target) + + if err != nil { + return err + } + + _, err = r.db().Exec( + "UPDATE unreads SET rel_user = ? WHERE rel_user = ?", + target, + userID) + + return err +} diff --git a/sam/service/channel.go b/sam/service/channel.go index cb24b70a3..09ea93674 100644 --- a/sam/service/channel.go +++ b/sam/service/channel.go @@ -81,7 +81,7 @@ func (svc *channel) With(ctx context.Context) ChannelService { channel: repository.Channel(ctx, db), cmember: repository.ChannelMember(ctx, db), - unread: repository.ChannelView(ctx, db), + unread: repository.Unread(ctx, db), message: repository.Message(ctx, db), // System messages should be flushed at the end of each session diff --git a/sam/service/message.go b/sam/service/message.go index 671737392..4bb9c8ab5 100644 --- a/sam/service/message.go +++ b/sam/service/message.go @@ -83,7 +83,7 @@ func (svc *message) With(ctx context.Context) MessageService { attachment: repository.Attachment(ctx, db), channel: repository.Channel(ctx, db), cmember: repository.ChannelMember(ctx, db), - unreads: repository.ChannelView(ctx, db), + unreads: repository.Unread(ctx, db), message: repository.Message(ctx, db), mflag: repository.MessageFlag(ctx, db), mentions: repository.Mention(ctx, db), diff --git a/system/start.go b/system/start.go index f69f47cc7..72f5428dc 100644 --- a/system/start.go +++ b/system/start.go @@ -44,16 +44,7 @@ func Init() error { mail.SetupDialer(flags.smtp) - // start/configure database connection - db, err := db.TryToConnect(flags.db.DSN, flags.db.Profiler) - if err != nil { - return errors.Wrap(err, "could not connect to database") - } - - // migrate database schema - if err := migrate.Migrate(db); err != nil { - return err - } + InitDb() // configure resputil options resputil.SetConfig(resputil.Options{ @@ -69,6 +60,21 @@ func Init() error { return nil } +func InitDb() error { + // start/configure database connection + db, err := db.TryToConnect(flags.db.DSN, flags.db.Profiler) + if err != nil { + return errors.Wrap(err, "could not connect to database") + } + + // migrate database schema + if err := migrate.Migrate(db); err != nil { + return err + } + + return nil +} + func Start() error { log.Printf("Starting "+os.Args[0]+", version: %v, built on: %v", version.Version, version.BuildTime) log.Println("Starting http server on address " + flags.http.Addr)