diff --git a/auth/service/service.go b/auth/service/service.go new file mode 100644 index 000000000..10fb34954 --- /dev/null +++ b/auth/service/service.go @@ -0,0 +1,16 @@ +package service + +import ( + "sync" +) + +var ( + o sync.Once + DefaultUser UserService +) + +func Init() { + o.Do(func() { + DefaultUser = User() + }) +} diff --git a/auth/service/user.go b/auth/service/user.go index f23aa3b08..37536bdd9 100644 --- a/auth/service/user.go +++ b/auth/service/user.go @@ -2,6 +2,7 @@ package service import ( "context" + "github.com/crusttech/crust/auth/repository" "github.com/crusttech/crust/auth/types" ) @@ -19,7 +20,15 @@ type ( UserService interface { With(ctx context.Context) UserService + FindByID(id uint64) (*types.User, error) + Find(filter *types.UserFilter) ([]*types.User, error) + Create(input *types.User) (*types.User, error) + Update(mod *types.User) (*types.User, error) + Delete(id uint64) error + Suspend(id uint64) error + Unsuspend(id uint64) error + ValidateCredentials(username, password string) (*types.User, error) } ) diff --git a/sam/service/user_test.go b/auth/service/user_test.go similarity index 100% rename from sam/service/user_test.go rename to auth/service/user_test.go diff --git a/sam/repository/channel_test.go b/sam/repository/channel_test.go index 0177aa059..2e6cd5098 100644 --- a/sam/repository/channel_test.go +++ b/sam/repository/channel_test.go @@ -62,31 +62,3 @@ func TestChannel(t *testing.T) { } } } - -func TestChannelMembers(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - rpo := New() - chn := &types.Channel{} - usr := &types.User{} - - { - chn, err = rpo.CreateChannel(chn) - assert(t, err == nil, "CreateChannel: %v", err) - - { - usr, err = rpo.CreateUser(usr) - assert(t, err == nil, "CreateUser error: %v", err) - - { - _, err = rpo.AddChannelMember(&types.ChannelMember{ChannelID: chn.ID, UserID: usr.ID}) - assert(t, err == nil, "AddChannelMember error: %v", err) - } - } - } -} diff --git a/sam/repository/repository.go b/sam/repository/repository.go index f1a144d22..49e871c8b 100644 --- a/sam/repository/repository.go +++ b/sam/repository/repository.go @@ -35,7 +35,6 @@ type ( Organisation Reaction Team - User EventQueue } diff --git a/sam/repository/user.go b/sam/repository/user.go deleted file mode 100644 index 91212fd7e..000000000 --- a/sam/repository/user.go +++ /dev/null @@ -1,89 +0,0 @@ -package repository - -import ( - "github.com/crusttech/crust/sam/types" - "github.com/titpetric/factory" - "time" -) - -type ( - User interface { - FindUserByUsername(username string) (*types.User, error) - FindUserByID(id uint64) (*types.User, error) - FindUsers(filter *types.UserFilter) ([]*types.User, error) - CreateUser(mod *types.User) (*types.User, error) - UpdateUser(mod *types.User) (*types.User, error) - SuspendUserByID(id uint64) error - UnsuspendUserByID(id uint64) error - DeleteUserByID(id uint64) error - } -) - -const ( - sqlUserScope = "suspended_at IS NULL AND deleted_at IS NULL" - sqlUserSelect = "SELECT * FROM users WHERE " + sqlUserScope - - ErrUserNotFound = repositoryError("UserNotFound") -) - -func (r *repository) FindUserByUsername(username string) (*types.User, error) { - sql := "SELECT * FROM users WHERE username = ? AND " + sqlUserScope - mod := &types.User{} - - return mod, isFound(r.db().Get(mod, sql, username), mod.ID > 0, ErrUserNotFound) -} - -func (r *repository) FindUserByID(id uint64) (*types.User, error) { - sql := "SELECT * FROM users WHERE id = ? AND " + sqlUserScope - mod := &types.User{} - - return mod, isFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrUserNotFound) -} - -func (r *repository) FindUsers(filter *types.UserFilter) ([]*types.User, error) { - rval := make([]*types.User, 0) - params := make([]interface{}, 0) - sql := "SELECT * FROM users WHERE " + sqlUserScope - - if filter != nil { - if filter.Query != "" { - sql += " AND username LIKE ?" - params = append(params, filter.Query+"%") - } - - if filter.MembersOfChannel > 0 { - sql += " AND id IN (SELECT rel_user FROM channel_members WHERE rel_channel = ?)" - params = append(params, filter.MembersOfChannel) - } - } - - sql += " ORDER BY username ASC" - - return rval, r.db().Select(&rval, sql, params...) -} - -func (r *repository) CreateUser(mod *types.User) (*types.User, error) { - mod.ID = factory.Sonyflake.NextID() - mod.CreatedAt = time.Now() - mod.Meta = coalesceJson(mod.Meta, []byte("{}")) - return mod, r.db().Insert("users", mod) -} - -func (r *repository) UpdateUser(mod *types.User) (*types.User, error) { - mod.UpdatedAt = timeNowPtr() - mod.Meta = coalesceJson(mod.Meta, []byte("{}")) - - return mod, r.db().Replace("users", mod) -} - -func (r *repository) SuspendUserByID(id uint64) error { - return r.updateColumnByID("users", "suspend_at", time.Now(), id) -} - -func (r *repository) UnsuspendUserByID(id uint64) error { - return r.updateColumnByID("users", "suspend_at", nil, id) -} - -func (r *repository) DeleteUserByID(id uint64) error { - return r.updateColumnByID("users", "deleted_at", nil, id) -} diff --git a/sam/repository/user_test.go b/sam/repository/user_test.go deleted file mode 100644 index aead7a87c..000000000 --- a/sam/repository/user_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package repository - -import ( - "github.com/crusttech/crust/sam/types" - "testing" -) - -func TestUser(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("skipping test in short mode.") - return - } - - rpo := New() - team := &types.User{} - - var name1, name2 = "Test user v1", "Test user v2" - - var aa []*types.User - - { - team.Username = name1 - team, err = rpo.CreateUser(team) - assert(t, err == nil, "CreateUser error: %v", err) - assert(t, team.Username == name1, "Changes were not stored") - - { - team.Username = name2 - team, err = rpo.UpdateUser(team) - assert(t, err == nil, "UpdateUser error: %v", err) - assert(t, team.Username == name2, "Changes were not stored") - } - - { - team, err = rpo.FindUserByID(team.ID) - assert(t, err == nil, "FindUserByID error: %v", err) - assert(t, team.Username == name2, "Changes were not stored") - } - - { - aa, err = rpo.FindUsers(&types.UserFilter{Query: name2}) - assert(t, err == nil, "FindUsers error: %v", err) - assert(t, len(aa) > 0, "No results found") - } - - { - err = rpo.DeleteUserByID(team.ID) - assert(t, err == nil, "DeleteUserByID error: %v", err) - } - } -} diff --git a/sam/rest/attachment.go b/sam/rest/attachment.go index 768de7caf..081f32131 100644 --- a/sam/rest/attachment.go +++ b/sam/rest/attachment.go @@ -15,7 +15,9 @@ var _ = errors.Wrap type ( Attachment struct { - svc service.AttachmentService + svc struct { + att service.AttachmentService + } } file struct { @@ -25,8 +27,10 @@ type ( } ) -func (Attachment) New(svc service.AttachmentService) *Attachment { - return &Attachment{svc: svc} +func (Attachment) New() *Attachment { + ctrl := &Attachment{} + ctrl.svc.att = service.DefaultAttachment + return ctrl } func (ctrl *Attachment) Original(ctx context.Context, r *request.AttachmentOriginal) (interface{}, error) { @@ -41,14 +45,14 @@ func (ctrl *Attachment) Preview(ctx context.Context, r *request.AttachmentPrevie func (ctrl Attachment) get(ID uint64, preview, download bool) (handlers.Downloadable, error) { rval := &file{download: download} - if att, err := ctrl.svc.FindByID(ID); err != nil { + if att, err := ctrl.svc.att.FindByID(ID); err != nil { return nil, err } else { rval.Attachment = att if preview { - rval.content, err = ctrl.svc.OpenPreview(att) + rval.content, err = ctrl.svc.att.OpenPreview(att) } else { - rval.content, err = ctrl.svc.OpenOriginal(att) + rval.content, err = ctrl.svc.att.OpenOriginal(att) } if err != nil { diff --git a/sam/rest/auth.go b/sam/rest/auth.go index ab05e6a9b..d89a8b6a3 100644 --- a/sam/rest/auth.go +++ b/sam/rest/auth.go @@ -2,9 +2,10 @@ package rest import ( "context" + "github.com/crusttech/crust/auth/service" + "github.com/crusttech/crust/auth/types" "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/sam/rest/request" - "github.com/crusttech/crust/sam/types" "github.com/pkg/errors" ) @@ -12,8 +13,10 @@ var _ = errors.Wrap type ( Auth struct { - user authUserBasics - token auth.TokenEncoder + svc struct { + user service.UserService + token auth.TokenEncoder + } } authPayload struct { @@ -27,22 +30,23 @@ type ( } ) -func (Auth) New(credValidator authUserBasics, tknEncoder auth.TokenEncoder) *Auth { - return &Auth{ - credValidator, - tknEncoder, - } +func (Auth) New(tknEncoder auth.TokenEncoder) *Auth { + ctrl := &Auth{} + ctrl.svc.user = service.DefaultUser + ctrl.svc.token = tknEncoder + + return ctrl } func (ctrl *Auth) Login(ctx context.Context, r *request.AuthLogin) (interface{}, error) { - return ctrl.tokenize(ctrl.user.ValidateCredentials(ctx, r.Username, r.Password)) + return ctrl.tokenize(ctrl.svc.user.ValidateCredentials(r.Username, r.Password)) } func (ctrl *Auth) Create(ctx context.Context, r *request.AuthCreate) (interface{}, error) { user := &types.User{Username: r.Username} user.GeneratePassword(r.Password) - return ctrl.tokenize(ctrl.user.Create(ctx, user)) + return ctrl.tokenize(ctrl.svc.user.With(ctx).Create(user)) } // Wraps user return value and appends JWT @@ -52,7 +56,7 @@ func (ctrl *Auth) tokenize(user *types.User, err error) (interface{}, error) { } return &authPayload{ - JWT: ctrl.token.Encode(user), + JWT: ctrl.svc.token.Encode(user), User: user, }, nil } diff --git a/sam/rest/channel.go b/sam/rest/channel.go index 8181d6f9a..165f4dcee 100644 --- a/sam/rest/channel.go +++ b/sam/rest/channel.go @@ -7,7 +7,6 @@ import ( "github.com/crusttech/crust/sam/service" "github.com/crusttech/crust/sam/types" "github.com/pkg/errors" - "io" ) var _ = errors.Wrap @@ -15,20 +14,16 @@ var _ = errors.Wrap type ( Channel struct { svc struct { - ch service.ChannelService - at channelAttachmentService + ch service.ChannelService + att service.AttachmentService } } - - channelAttachmentService interface { - Create(ctx context.Context, channelID uint64, name string, size int64, fh io.ReadSeeker) (*types.Attachment, error) - } ) -func (Channel) New(chSvc service.ChannelService, atSvc service.AttachmentService) *Channel { +func (Channel) New() *Channel { ctrl := &Channel{} - ctrl.svc.ch = chSvc - ctrl.svc.at = atSvc + ctrl.svc.ch = service.DefaultChannel + ctrl.svc.att = service.DefaultAttachment return ctrl } @@ -89,7 +84,7 @@ func (ctrl *Channel) Attach(ctx context.Context, r *request.ChannelAttach) (inte defer file.Close() - return ctrl.svc.at.Create( + return ctrl.svc.att.Create( ctx, r.ChannelID, r.Upload.Filename, diff --git a/sam/rest/message.go b/sam/rest/message.go index d377c084c..82683b275 100644 --- a/sam/rest/message.go +++ b/sam/rest/message.go @@ -12,30 +12,34 @@ var _ = errors.Wrap type ( Message struct { - svc service.MessageService + svc struct { + msg service.MessageService + } } ) -func (Message) New(message service.MessageService) *Message { - return &Message{message} +func (Message) New() *Message { + ctrl := &Message{} + ctrl.svc.msg = service.DefaultMessage + return ctrl } func (ctrl *Message) Create(ctx context.Context, r *request.MessageCreate) (interface{}, error) { - return ctrl.svc.Create(ctx, &types.Message{ + return ctrl.svc.msg.Create(ctx, &types.Message{ ChannelID: r.ChannelID, Message: r.Message, }) } func (ctrl *Message) History(ctx context.Context, r *request.MessageHistory) (interface{}, error) { - return ctrl.svc.Find(ctx, &types.MessageFilter{ + return ctrl.svc.msg.Find(ctx, &types.MessageFilter{ ChannelID: r.ChannelID, FromMessageID: r.LastMessageID, }) } func (ctrl *Message) Edit(ctx context.Context, r *request.MessageEdit) (interface{}, error) { - return ctrl.svc.Update(ctx, &types.Message{ + return ctrl.svc.msg.Update(ctx, &types.Message{ ID: r.MessageID, ChannelID: r.ChannelID, Message: r.Message, @@ -43,36 +47,36 @@ func (ctrl *Message) Edit(ctx context.Context, r *request.MessageEdit) (interfac } func (ctrl *Message) Delete(ctx context.Context, r *request.MessageDelete) (interface{}, error) { - return nil, ctrl.svc.Delete(ctx, r.MessageID) + return nil, ctrl.svc.msg.Delete(ctx, r.MessageID) } func (ctrl *Message) Search(ctx context.Context, r *request.MessageSearch) (interface{}, error) { - return ctrl.svc.Find(ctx, &types.MessageFilter{ + return ctrl.svc.msg.Find(ctx, &types.MessageFilter{ ChannelID: r.ChannelID, Query: r.Query, }) } func (ctrl *Message) Pin(ctx context.Context, r *request.MessagePin) (interface{}, error) { - return nil, ctrl.svc.Pin(ctx, r.MessageID) + return nil, ctrl.svc.msg.Pin(ctx, r.MessageID) } func (ctrl *Message) Unpin(ctx context.Context, r *request.MessageUnpin) (interface{}, error) { - return nil, ctrl.svc.Unpin(ctx, r.MessageID) + return nil, ctrl.svc.msg.Unpin(ctx, r.MessageID) } func (ctrl *Message) Flag(ctx context.Context, r *request.MessageFlag) (interface{}, error) { - return nil, ctrl.svc.Flag(ctx, r.MessageID) + return nil, ctrl.svc.msg.Flag(ctx, r.MessageID) } func (ctrl *Message) Unflag(ctx context.Context, r *request.MessageUnflag) (interface{}, error) { - return nil, ctrl.svc.Unflag(ctx, r.MessageID) + return nil, ctrl.svc.msg.Unflag(ctx, r.MessageID) } func (ctrl *Message) React(ctx context.Context, r *request.MessageReact) (interface{}, error) { - return nil, ctrl.svc.React(ctx, r.MessageID, r.Reaction) + return nil, ctrl.svc.msg.React(ctx, r.MessageID, r.Reaction) } func (ctrl *Message) Unreact(ctx context.Context, r *request.MessageUnreact) (interface{}, error) { - return nil, ctrl.svc.Unreact(ctx, r.MessageID, r.Reaction) + return nil, ctrl.svc.msg.Unreact(ctx, r.MessageID, r.Reaction) } diff --git a/sam/rest/organisation.go b/sam/rest/organisation.go index 94916e669..ad7ec530a 100644 --- a/sam/rest/organisation.go +++ b/sam/rest/organisation.go @@ -12,20 +12,24 @@ var _ = errors.Wrap type ( Organisation struct { - svc service.OrganisationService + svc struct { + org service.OrganisationService + } } ) -func (Organisation) New(organisation service.OrganisationService) *Organisation { - return &Organisation{organisation} +func (Organisation) New() *Organisation { + ctrl := &Organisation{} + ctrl.svc.org = service.DefaultOrganisation + return ctrl } func (ctrl *Organisation) Read(ctx context.Context, r *request.OrganisationRead) (interface{}, error) { - return ctrl.svc.FindByID(ctx, r.ID) + return ctrl.svc.org.FindByID(ctx, r.ID) } func (ctrl *Organisation) List(ctx context.Context, r *request.OrganisationList) (interface{}, error) { - return ctrl.svc.Find(ctx, &types.OrganisationFilter{Query: r.Query}) + return ctrl.svc.org.Find(ctx, &types.OrganisationFilter{Query: r.Query}) } func (ctrl *Organisation) Create(ctx context.Context, r *request.OrganisationCreate) (interface{}, error) { @@ -33,7 +37,7 @@ func (ctrl *Organisation) Create(ctx context.Context, r *request.OrganisationCre Name: r.Name, } - return ctrl.svc.Create(ctx, org) + return ctrl.svc.org.Create(ctx, org) } func (ctrl *Organisation) Edit(ctx context.Context, r *request.OrganisationEdit) (interface{}, error) { @@ -42,13 +46,13 @@ func (ctrl *Organisation) Edit(ctx context.Context, r *request.OrganisationEdit) Name: r.Name, } - return ctrl.svc.Update(ctx, org) + return ctrl.svc.org.Update(ctx, org) } func (ctrl *Organisation) Remove(ctx context.Context, r *request.OrganisationRemove) (interface{}, error) { - return nil, ctrl.svc.Delete(ctx, r.ID) + return nil, ctrl.svc.org.Delete(ctx, r.ID) } func (ctrl *Organisation) Archive(ctx context.Context, r *request.OrganisationArchive) (interface{}, error) { - return nil, ctrl.svc.Archive(ctx, r.ID) + return nil, ctrl.svc.org.Archive(ctx, r.ID) } diff --git a/sam/rest/router.go b/sam/rest/router.go index 4341c0633..6b14421fd 100644 --- a/sam/rest/router.go +++ b/sam/rest/router.go @@ -1,60 +1,32 @@ package rest import ( - "github.com/go-chi/chi" - "log" - "github.com/crusttech/crust/internal/auth" - "github.com/crusttech/crust/internal/store" "github.com/crusttech/crust/sam/rest/handlers" - "github.com/crusttech/crust/sam/service" + "github.com/go-chi/chi" ) func MountRoutes(jwtAuth auth.TokenEncoder) func(chi.Router) { - // Initialize services - fs, err := store.New("var/store") - if err != nil { - log.Fatalf("Failed to initialize stor: %v", err) - } - - var ( - channelSvc = service.Channel() - attachmentSvc = service.Attachment(fs) - messageSvc = service.Message(attachmentSvc) - organisationSvc = service.Organisation() - teamSvc = service.Team() - userSvc = service.User() - ) - - var ( - channel = Channel{}.New(channelSvc, attachmentSvc) - message = Message{}.New(messageSvc) - organisation = Organisation{}.New(organisationSvc) - team = Team{}.New(teamSvc) - user = User{}.New(userSvc, messageSvc) - attachment = Attachment{}.New(attachmentSvc) - ) - // Initialize handers & controllers. return func(r chi.Router) { // Cookie expiration in minutes // @todo pull this from auth/jwt config var cookieExp = 3600 - handlers.NewAuthCustom(Auth{}.New(userSvc, jwtAuth), cookieExp).MountRoutes(r) + handlers.NewAuthCustom(Auth{}.New(jwtAuth), cookieExp).MountRoutes(r) // @todo solve cookie issues ( - handlers.NewAttachmentDownloadable(attachment).MountRoutes(r) + handlers.NewAttachmentDownloadable(Attachment{}.New()).MountRoutes(r) // Protect all _private_ routes r.Group(func(r chi.Router) { r.Use(auth.AuthenticationMiddlewareValidOnly) - handlers.NewChannel(channel).MountRoutes(r) - handlers.NewMessage(message).MountRoutes(r) - handlers.NewOrganisation(organisation).MountRoutes(r) - handlers.NewTeam(team).MountRoutes(r) - handlers.NewUser(user).MountRoutes(r) + handlers.NewChannel(Channel{}.New()).MountRoutes(r) + handlers.NewMessage(Message{}.New()).MountRoutes(r) + handlers.NewOrganisation(Organisation{}.New()).MountRoutes(r) + handlers.NewTeam(Team{}.New()).MountRoutes(r) + handlers.NewUser(User{}.New()).MountRoutes(r) }) } } diff --git a/sam/rest/team.go b/sam/rest/team.go index 2c1264915..355e3d1dc 100644 --- a/sam/rest/team.go +++ b/sam/rest/team.go @@ -14,20 +14,24 @@ var _ = errors.Wrap type ( Team struct { - svc service.TeamService + svc struct { + team service.TeamService + } } ) -func (Team) New(team service.TeamService) *Team { - return &Team{team} +func (Team) New() *Team { + ctrl := &Team{} + ctrl.svc.team = service.DefaultTeam + return ctrl } func (ctrl *Team) Read(ctx context.Context, r *request.TeamRead) (interface{}, error) { - return ctrl.svc.FindByID(ctx, r.TeamID) + return ctrl.svc.team.FindByID(ctx, r.TeamID) } func (ctrl *Team) List(ctx context.Context, r *request.TeamList) (interface{}, error) { - return ctrl.svc.Find(ctx, &types.TeamFilter{Query: r.Query}) + return ctrl.svc.team.Find(ctx, &types.TeamFilter{Query: r.Query}) } func (ctrl *Team) Create(ctx context.Context, r *request.TeamCreate) (interface{}, error) { @@ -35,7 +39,7 @@ func (ctrl *Team) Create(ctx context.Context, r *request.TeamCreate) (interface{ Name: r.Name, } - return ctrl.svc.Create(ctx, org) + return ctrl.svc.team.Create(ctx, org) } func (ctrl *Team) Edit(ctx context.Context, r *request.TeamEdit) (interface{}, error) { @@ -44,21 +48,21 @@ func (ctrl *Team) Edit(ctx context.Context, r *request.TeamEdit) (interface{}, e Name: r.Name, } - return ctrl.svc.Update(ctx, org) + return ctrl.svc.team.Update(ctx, org) } func (ctrl *Team) Remove(ctx context.Context, r *request.TeamRemove) (interface{}, error) { - return nil, ctrl.svc.Delete(ctx, r.TeamID) + return nil, ctrl.svc.team.Delete(ctx, r.TeamID) } func (ctrl *Team) Archive(ctx context.Context, r *request.TeamArchive) (interface{}, error) { - return nil, ctrl.svc.Archive(ctx, r.TeamID) + return nil, ctrl.svc.team.Archive(ctx, r.TeamID) } func (ctrl *Team) Merge(ctx context.Context, r *request.TeamMerge) (interface{}, error) { - return nil, ctrl.svc.Merge(ctx, r.TeamID, r.Destination) + return nil, ctrl.svc.team.Merge(ctx, r.TeamID, r.Destination) } func (ctrl *Team) Move(ctx context.Context, r *request.TeamMove) (interface{}, error) { - return nil, ctrl.svc.Move(ctx, r.TeamID, r.Organisation_id) + return nil, ctrl.svc.team.Move(ctx, r.TeamID, r.Organisation_id) } diff --git a/sam/rest/user.go b/sam/rest/user.go index fc6b36f3c..090348f37 100644 --- a/sam/rest/user.go +++ b/sam/rest/user.go @@ -3,6 +3,8 @@ package rest import ( "context" + authService "github.com/crusttech/crust/auth/service" + authTypes "github.com/crusttech/crust/auth/types" "github.com/crusttech/crust/sam/rest/request" "github.com/crusttech/crust/sam/service" "github.com/crusttech/crust/sam/types" @@ -14,22 +16,22 @@ var _ = errors.Wrap type ( User struct { svc struct { - user service.UserService + user authService.UserService message service.MessageService } } ) -func (User) New(user service.UserService, message service.MessageService) *User { +func (User) New() *User { ctrl := &User{} - ctrl.svc.user = user - ctrl.svc.message = message + ctrl.svc.user = authService.DefaultUser + ctrl.svc.message = service.DefaultMessage return ctrl } // Searches the users table in the database to find users by matching (by-prefix) their.Username func (ctrl *User) Search(ctx context.Context, r *request.UserSearch) (interface{}, error) { - return ctrl.svc.user.Find(ctx, &types.UserFilter{Query: r.Query}) + return ctrl.svc.user.With(ctx).Find(&authTypes.UserFilter{Query: r.Query}) } func (ctrl *User) Message(ctx context.Context, r *request.UserMessage) (interface{}, error) { diff --git a/sam/service/channel.go b/sam/service/channel.go index cedfceed6..fcb17ffe4 100644 --- a/sam/service/channel.go +++ b/sam/service/channel.go @@ -12,12 +12,12 @@ import ( type ( channel struct { rpo channelRepository - usr UserService } ChannelService interface { FindByID(ctx context.Context, channelID uint64) (*types.Channel, error) Find(ctx context.Context, filter *types.ChannelFilter) ([]*types.Channel, error) + FindByMembership(ctx context.Context) (rval []*types.Channel, err error) Create(ctx context.Context, channel *types.Channel) (*types.Channel, error) Update(ctx context.Context, channel *types.Channel) (*types.Channel, error) @@ -40,7 +40,6 @@ func Channel() *channel { var svc = &channel{} svc.rpo = repository.New() - svc.usr = User() //svc.sec.ch = ChannelSecurity(svc.rpo) return svc @@ -64,10 +63,15 @@ func (svc channel) Find(ctx context.Context, filter *types.ChannelFilter) ([]*ty if cc, err := svc.rpo.FindChannels(filter); err != nil { return nil, err } else { - return cc, svc.usr.LoadFromChannels(ctx, cc) + return cc, svc.preloadMembers(ctx, cc) } } +func (svc channel) preloadMembers(ctx context.Context, set types.ChannelSet) error { + // @todo implement + return nil +} + // Returns all channels with membership info func (svc channel) FindByMembership(ctx context.Context) (rval []*types.Channel, err error) { return rval, svc.rpo.BeginWith(ctx, func(r repository.Interfaces) error { diff --git a/sam/service/service.go b/sam/service/service.go new file mode 100644 index 000000000..870843c6e --- /dev/null +++ b/sam/service/service.go @@ -0,0 +1,34 @@ +package service + +import ( + "log" + "sync" + + "github.com/crusttech/crust/internal/store" +) + +var ( + o sync.Once + DefaultAttachment AttachmentService + DefaultChannel ChannelService + DefaultMessage MessageService + DefaultOrganisation OrganisationService + DefaultPubSub *pubSub + DefaultTeam TeamService +) + +func Init() { + o.Do(func() { + fs, err := store.New("var/store") + if err != nil { + log.Fatalf("Failed to initialize stor: %v", err) + } + + DefaultAttachment = Attachment(fs) + DefaultChannel = Channel() + DefaultMessage = Message(DefaultAttachment) + DefaultOrganisation = Organisation() + DefaultPubSub = PubSub() + DefaultTeam = Team() + }) +} diff --git a/sam/service/service_mock_test.go b/sam/service/service_mock_test.go index 29364bf64..2b89daa10 100644 --- a/sam/service/service_mock_test.go +++ b/sam/service/service_mock_test.go @@ -616,107 +616,6 @@ func (mr *MockRepositoryMockRecorder) MoveTeamByID(id, targetOrganisationID inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveTeamByID", reflect.TypeOf((*MockRepository)(nil).MoveTeamByID), id, targetOrganisationID) } -// FindUserByUsername mocks base method -func (m *MockRepository) FindUserByUsername(username string) (*types.User, error) { - ret := m.ctrl.Call(m, "FindUserByUsername", username) - ret0, _ := ret[0].(*types.User) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// FindUserByUsername indicates an expected call of FindUserByUsername -func (mr *MockRepositoryMockRecorder) FindUserByUsername(username interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindUserByUsername", reflect.TypeOf((*MockRepository)(nil).FindUserByUsername), username) -} - -// FindUserByID mocks base method -func (m *MockRepository) FindUserByID(id uint64) (*types.User, error) { - ret := m.ctrl.Call(m, "FindUserByID", id) - ret0, _ := ret[0].(*types.User) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// FindUserByID indicates an expected call of FindUserByID -func (mr *MockRepositoryMockRecorder) FindUserByID(id interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindUserByID", reflect.TypeOf((*MockRepository)(nil).FindUserByID), id) -} - -// FindUsers mocks base method -func (m *MockRepository) FindUsers(filter *types.UserFilter) ([]*types.User, error) { - ret := m.ctrl.Call(m, "FindUsers", filter) - ret0, _ := ret[0].([]*types.User) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// FindUsers indicates an expected call of FindUsers -func (mr *MockRepositoryMockRecorder) FindUsers(filter interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindUsers", reflect.TypeOf((*MockRepository)(nil).FindUsers), filter) -} - -// CreateUser mocks base method -func (m *MockRepository) CreateUser(mod *types.User) (*types.User, error) { - ret := m.ctrl.Call(m, "CreateUser", mod) - ret0, _ := ret[0].(*types.User) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateUser indicates an expected call of CreateUser -func (mr *MockRepositoryMockRecorder) CreateUser(mod interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockRepository)(nil).CreateUser), mod) -} - -// UpdateUser mocks base method -func (m *MockRepository) UpdateUser(mod *types.User) (*types.User, error) { - ret := m.ctrl.Call(m, "UpdateUser", mod) - ret0, _ := ret[0].(*types.User) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateUser indicates an expected call of UpdateUser -func (mr *MockRepositoryMockRecorder) UpdateUser(mod interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockRepository)(nil).UpdateUser), mod) -} - -// SuspendUserByID mocks base method -func (m *MockRepository) SuspendUserByID(id uint64) error { - ret := m.ctrl.Call(m, "SuspendUserByID", id) - ret0, _ := ret[0].(error) - return ret0 -} - -// SuspendUserByID indicates an expected call of SuspendUserByID -func (mr *MockRepositoryMockRecorder) SuspendUserByID(id interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SuspendUserByID", reflect.TypeOf((*MockRepository)(nil).SuspendUserByID), id) -} - -// UnsuspendUserByID mocks base method -func (m *MockRepository) UnsuspendUserByID(id uint64) error { - ret := m.ctrl.Call(m, "UnsuspendUserByID", id) - ret0, _ := ret[0].(error) - return ret0 -} - -// UnsuspendUserByID indicates an expected call of UnsuspendUserByID -func (mr *MockRepositoryMockRecorder) UnsuspendUserByID(id interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsuspendUserByID", reflect.TypeOf((*MockRepository)(nil).UnsuspendUserByID), id) -} - -// DeleteUserByID mocks base method -func (m *MockRepository) DeleteUserByID(id uint64) error { - ret := m.ctrl.Call(m, "DeleteUserByID", id) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteUserByID indicates an expected call of DeleteUserByID -func (mr *MockRepositoryMockRecorder) DeleteUserByID(id interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserByID", reflect.TypeOf((*MockRepository)(nil).DeleteUserByID), id) -} - // EventQueuePull mocks base method func (m *MockRepository) EventQueuePull(origin uint64) ([]*types.EventQueueItem, error) { ret := m.ctrl.Call(m, "EventQueuePull", origin) diff --git a/sam/service/user.go b/sam/service/user.go deleted file mode 100644 index b997c7407..000000000 --- a/sam/service/user.go +++ /dev/null @@ -1,105 +0,0 @@ -package service - -import ( - "context" - "github.com/crusttech/crust/sam/repository" - "github.com/crusttech/crust/sam/types" -) - -const ( - ErrUserInvalidCredentials = serviceError("UserInvalidCredentials") - ErrUserLocked = serviceError("UserLocked") -) - -type ( - user struct { - rpo userRepository - } - - UserService interface { - Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) - LoadFromChannels(ctx context.Context, cc types.ChannelSet) (err error) - } - - userRepository interface { - repository.Transactionable - repository.Contextable - repository.User - } -) - -func User() *user { - return &user{rpo: repository.New()} -} - -func (svc user) ValidateCredentials(ctx context.Context, username, password string) (*types.User, error) { - user, err := svc.rpo.FindUserByUsername(username) - if err != nil { - return nil, err - } - - if !user.ValidatePassword(password) { - return nil, ErrUserInvalidCredentials - } - - if !svc.canLogin(user) { - return nil, ErrUserLocked - } - - return user, nil -} - -func (svc user) FindByID(ctx context.Context, id uint64) (*types.User, error) { - return svc.rpo.WithCtx(ctx).FindUserByID(id) -} - -func (svc user) LoadFromChannels(ctx context.Context, cc types.ChannelSet) (err error) { - return cc.Walk(func(c *types.Channel) error { - // @todo doing N selects (one per chan) for now, optimize! - c.Members, err = svc.rpo.FindUsers(&types.UserFilter{MembersOfChannel: c.ID}) - return err - }) -} - -func (svc user) Find(ctx context.Context, filter *types.UserFilter) ([]*types.User, error) { - return svc.rpo.FindUsers(filter) -} - -func (svc user) Create(ctx context.Context, input *types.User) (user *types.User, err error) { - return user, svc.rpo.BeginWith(ctx, func(r repository.Interfaces) error { - // Encrypt user password - if user, err = r.CreateUser(input); err != nil { - return err - } - - return nil - }) -} - -func (svc user) Update(ctx context.Context, mod *types.User) (*types.User, error) { - return svc.rpo.UpdateUser(mod) -} - -func (svc user) canLogin(u *types.User) bool { - return u != nil && u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil -} - -func (svc user) Delete(ctx context.Context, id uint64) error { - // @todo: permissions check if current user can delete this user - // @todo: notify users that user has been removed (remove from web UI) - return svc.rpo.DeleteUserByID(id) -} - -func (svc user) Suspend(ctx context.Context, id uint64) error { - // @todo: permissions check if current user can suspend this user - // @todo: notify users that user has been supsended (remove from web UI) - return svc.rpo.SuspendUserByID(id) -} - -func (svc user) Unsuspend(ctx context.Context, id uint64) error { - // @todo: permissions check if current user can unsuspend this user - // @todo: notify users that user has been unsuspended - return svc.rpo.UnsuspendUserByID(id) -} - -var _ UserService = &user{} diff --git a/sam/start.go b/sam/start.go index 1952ed776..2c43b01c1 100644 --- a/sam/start.go +++ b/sam/start.go @@ -8,6 +8,8 @@ import ( "github.com/SentimensRG/ctx" "github.com/SentimensRG/ctx/sigctx" + authService "github.com/crusttech/crust/auth/service" + samService "github.com/crusttech/crust/sam/service" "github.com/go-chi/chi" "github.com/go-chi/cors" "github.com/pkg/errors" @@ -48,6 +50,9 @@ func Init() error { }, }) + authService.Init() + samService.Init() + return nil } diff --git a/sam/types/channel.go b/sam/types/channel.go index 9a543b00b..5a64d59ae 100644 --- a/sam/types/channel.go +++ b/sam/types/channel.go @@ -24,7 +24,7 @@ type ( LastMessageID uint64 `json:",omitempty" db:"rel_last_message"` Member *ChannelMember `json:"-" db:"-"` - Members []*User `json:"-" db:"-"` + Members []*uint64 `json:"-" db:"-"` } ChannelMember struct { diff --git a/sam/types/user.go b/sam/types/user.go deleted file mode 100644 index a33ea81af..000000000 --- a/sam/types/user.go +++ /dev/null @@ -1,60 +0,0 @@ -package types - -import ( - "encoding/json" - "golang.org/x/crypto/bcrypt" - "time" -) - -type ( - User struct { - ID uint64 `json:"id" db:"id"` - Username string `json:"username" db:"username"` - Meta json.RawMessage `json:"-" db:"meta"` - OrganisationID uint64 `json:"organisationId" db:"rel_organisation"` - Password []byte `json:"-" db:"password"` - CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` - UpdatedAt *time.Time `json:"updatedAt,omitempty" db:"updated_at"` - SuspendedAt *time.Time `json:"suspendedAt,omitempty" db:"suspended_at"` - DeletedAt *time.Time `json:"deletedAt,omitempty" db:"deleted_at"` - } - - UserFilter struct { - Query string - MembersOfChannel uint64 - } - - UserSet []*User -) - -func (uu UserSet) Walk(w func(*User) error) (err error) { - for i := range uu { - if err = w(uu[i]); err != nil { - return - } - } - - return -} - -func (u *User) Valid() bool { - return u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil -} - -func (u *User) Identity() uint64 { - return u.ID -} - -func (u *User) ValidatePassword(password string) bool { - return bcrypt.CompareHashAndPassword(u.Password, []byte(password)) == nil -} - -func (u *User) GeneratePassword(password string) error { - pwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return err - } - - u.Password = pwd - return nil -} diff --git a/sam/websocket/payload.go b/sam/websocket/payload.go index 5d8df5272..d701ed359 100644 --- a/sam/websocket/payload.go +++ b/sam/websocket/payload.go @@ -1,11 +1,12 @@ package websocket import ( - "github.com/crusttech/crust/sam/types" + authTypes "github.com/crusttech/crust/auth/types" + samTypes "github.com/crusttech/crust/sam/types" "github.com/crusttech/crust/sam/websocket/outgoing" ) -func payloadFromMessage(msg *types.Message) *outgoing.Message { +func payloadFromMessage(msg *samTypes.Message) *outgoing.Message { return &outgoing.Message{ Message: msg.Message, ID: uint64toa(msg.ID), @@ -21,7 +22,7 @@ func payloadFromMessage(msg *types.Message) *outgoing.Message { } } -func payloadFromMessages(msg types.MessageSet) *outgoing.Messages { +func payloadFromMessages(msg samTypes.MessageSet) *outgoing.Messages { msgs := make([]*outgoing.Message, len(msg)) for k, m := range msg { msgs[k] = payloadFromMessage(m) @@ -30,18 +31,17 @@ func payloadFromMessages(msg types.MessageSet) *outgoing.Messages { return &retval } -func payloadFromChannel(ch *types.Channel) *outgoing.Channel { +func payloadFromChannel(ch *samTypes.Channel) *outgoing.Channel { return &outgoing.Channel{ ID: uint64toa(ch.ID), Name: ch.Name, LastMessageID: uint64toa(ch.LastMessageID), Topic: ch.Topic, Type: string(ch.Type), - Members: payloadFromUsers(ch.Members), } } -func payloadFromChannels(channels []*types.Channel) *outgoing.Channels { +func payloadFromChannels(channels []*samTypes.Channel) *outgoing.Channels { cc := make([]*outgoing.Channel, len(channels)) for k, c := range channels { cc[k] = payloadFromChannel(c) @@ -50,14 +50,14 @@ func payloadFromChannels(channels []*types.Channel) *outgoing.Channels { return &retval } -func payloadFromUser(user *types.User) *outgoing.User { +func payloadFromUser(user *authTypes.User) *outgoing.User { return &outgoing.User{ ID: uint64toa(user.ID), Username: user.Username, } } -func payloadFromUsers(users []*types.User) *outgoing.Users { +func payloadFromUsers(users []*authTypes.User) *outgoing.Users { uu := make([]*outgoing.User, len(users)) for k, u := range users { uu[k] = payloadFromUser(u) @@ -65,7 +65,7 @@ func payloadFromUsers(users []*types.User) *outgoing.Users { // @todo this is current instance only, need to sync this across all instances store.Walk(func(session *Session) { - if session.user.ID == u.ID { + if session.user.Identity() == u.ID { uu[k].Connections++ } }) @@ -76,7 +76,7 @@ func payloadFromUsers(users []*types.User) *outgoing.Users { return &retval } -func payloadFromAttachment(in *types.Attachment) *outgoing.Attachment { +func payloadFromAttachment(in *samTypes.Attachment) *outgoing.Attachment { if in == nil { return nil } diff --git a/sam/websocket/router.go b/sam/websocket/router.go index b026d110d..b0276bbce 100644 --- a/sam/websocket/router.go +++ b/sam/websocket/router.go @@ -7,16 +7,10 @@ import ( "github.com/go-chi/chi" "github.com/crusttech/crust/sam/repository" - "github.com/crusttech/crust/sam/service" ) func MountRoutes(ctx context.Context, config *repository.Flags) func(chi.Router) { return func(r chi.Router) { - var ( - // @todo move this 1 level up & join with rest init functions - svcUser = service.User() - ) - repo := repository.New() go func() { @@ -26,7 +20,7 @@ func MountRoutes(ctx context.Context, config *repository.Flags) func(chi.Router) }() eq.store(ctx, repo) - websocket := Websocket{}.New(svcUser, config) + websocket := Websocket{}.New(config) r.Group(func(r chi.Router) { r.Route("/websocket", func(r chi.Router) { r.Get("/", websocket.Open) diff --git a/sam/websocket/session.go b/sam/websocket/session.go index 65af0955b..06d5e9ebb 100644 --- a/sam/websocket/session.go +++ b/sam/websocket/session.go @@ -5,12 +5,13 @@ import ( "log" "time" + "github.com/crusttech/crust/internal/auth" "github.com/gorilla/websocket" "github.com/pkg/errors" + authService "github.com/crusttech/crust/auth/service" "github.com/crusttech/crust/sam/repository" - "github.com/crusttech/crust/sam/service" - "github.com/crusttech/crust/sam/types" + samService "github.com/crusttech/crust/sam/service" "github.com/crusttech/crust/sam/websocket/outgoing" ) @@ -30,12 +31,17 @@ type ( config *repository.Flags - user *types.User + user auth.Identifiable + + svc struct { + user authService.UserService + ch samService.ChannelService + } } ) func (Session) New(ctx context.Context, config *repository.Flags, conn *websocket.Conn) *Session { - return &Session{ + s := &Session{ conn: conn, ctx: ctx, config: config, @@ -43,6 +49,11 @@ func (Session) New(ctx context.Context, config *repository.Flags, conn *websocke send: make(chan []byte, 512), stop: make(chan []byte, 1), } + + s.svc.user = authService.DefaultUser + s.svc.ch = samService.DefaultChannel + + return s } func (sess *Session) Context() context.Context { @@ -51,22 +62,22 @@ func (sess *Session) Context() context.Context { func (sess *Session) connected() { // Tell everyone that user has connected - sess.sendToAll(&outgoing.Connected{UserID: uint64toa(sess.user.ID)}) + sess.sendToAll(&outgoing.Connected{UserID: uint64toa(sess.user.Identity())}) // Subscribe this user to all channels - if chs, err := service.Channel().FindByMembership(sess.ctx); err != nil { + if chs, err := sess.svc.ch.FindByMembership(sess.ctx); err != nil { log.Printf("Error: %v", err) } else { for _, ch := range chs { sess.subs.Add(uint64toa(ch.ID)) } - log.Printf("Subscribing %d to %d channels", sess.user.ID, len(chs)) + log.Printf("Subscribing %d to %d channels", sess.user.Identity(), len(chs)) } } func (sess *Session) disconnected() { // Tell everyone that user has disconnected - sess.sendToAll(&outgoing.Disconnected{UserID: uint64toa(sess.user.ID)}) + sess.sendToAll(&outgoing.Disconnected{UserID: uint64toa(sess.user.Identity())}) } func (sess *Session) Handle() error { diff --git a/sam/websocket/session_incoming_channel.go b/sam/websocket/session_incoming_channel.go index df1927d80..fd9245593 100644 --- a/sam/websocket/session_incoming_channel.go +++ b/sam/websocket/session_incoming_channel.go @@ -49,6 +49,8 @@ func (s *Session) channelList(ctx context.Context, p *incoming.Channels) error { return err } + // @todo count members for all channels + return s.sendReply(payloadFromChannels(channels)) } @@ -79,6 +81,8 @@ func (s *Session) channelCreate(ctx context.Context, p *incoming.ChannelCreate) // @todo this should go over all user's sessons and subscribe there as well + // @todo load channel member count + pl := payloadFromChannel(ch) if ch.Type == types.ChannelTypePublic { @@ -124,5 +128,7 @@ func (s *Session) channelUpdate(ctx context.Context, p *incoming.ChannelUpdate) return err } + // @todo load channel member count + return s.sendToAllSubscribers(payloadFromChannel(ch), p.ID) } diff --git a/sam/websocket/session_incoming_user.go b/sam/websocket/session_incoming_user.go index 9ccfc13b9..99a4e36b0 100644 --- a/sam/websocket/session_incoming_user.go +++ b/sam/websocket/session_incoming_user.go @@ -2,12 +2,11 @@ package websocket import ( "context" - "github.com/crusttech/crust/sam/service" "github.com/crusttech/crust/sam/websocket/incoming" ) func (s *Session) userList(ctx context.Context, p *incoming.Users) error { - users, err := service.User().Find(ctx, nil) + users, err := s.svc.user.With(ctx).Find(nil) if err != nil { return err } diff --git a/sam/websocket/websocket.go b/sam/websocket/websocket.go index 6fcaebb2c..0a01af3d7 100644 --- a/sam/websocket/websocket.go +++ b/sam/websocket/websocket.go @@ -1,7 +1,6 @@ package websocket import ( - "context" "log" "net/http" @@ -9,29 +8,25 @@ import ( "github.com/pkg/errors" "github.com/titpetric/factory/resputil" + authService "github.com/crusttech/crust/auth/service" "github.com/crusttech/crust/internal/auth" "github.com/crusttech/crust/sam/repository" - "github.com/crusttech/crust/sam/types" ) type ( Websocket struct { svc struct { - userFinder wsUserFinder + user authService.UserService } config *repository.Flags } - - wsUserFinder interface { - FindByID(ctx context.Context, userID uint64) (*types.User, error) - } ) -func (Websocket) New(svcUser wsUserFinder, config *repository.Flags) *Websocket { +func (Websocket) New(config *repository.Flags) *Websocket { ws := &Websocket{ config: config, } - ws.svc.userFinder = svcUser + ws.svc.user = authService.DefaultUser return ws } @@ -54,8 +49,7 @@ func (ws Websocket) Open(w http.ResponseWriter, r *http.Request) { return } - // @todo validate user (ws.svc.userFinder) here... - user, err := ws.svc.userFinder.FindByID(ctx, identity.Identity()) + user, err := ws.svc.user.With(ctx).FindByID(identity.Identity()) if err != nil { resputil.JSON(w, err) return