3
0

Extracted payloads from websocket pkg to internal/payload

This commit is contained in:
Denis Arh 2018-09-27 15:37:11 +02:00
parent 4788e61c07
commit bf8f6f4213
24 changed files with 220 additions and 195 deletions

View File

@ -28,6 +28,8 @@ type (
Query string
MembersOfChannel uint64
}
UserSet []*User
)
func (u *User) Valid() bool {
@ -51,3 +53,13 @@ func (u *User) GeneratePassword(password string) error {
u.Password = pwd
return nil
}
func (uu UserSet) Walk(w func(*User) error) (err error) {
for i := range uu {
if err = w(uu[i]); err != nil {
return
}
}
return
}

View File

@ -0,0 +1,12 @@
package payload
import (
"encoding/json"
"github.com/crusttech/crust/internal/payload/incoming"
)
func Unmarshal(raw []byte) (*incoming.Payload, error) {
var p = &incoming.Payload{}
return p, json.Unmarshal(raw, p)
}

View File

@ -0,0 +1,93 @@
package payload
import (
auth "github.com/crusttech/crust/auth/types"
"github.com/crusttech/crust/internal/payload/outgoing"
sam "github.com/crusttech/crust/sam/types"
)
func Message(msg *sam.Message) *outgoing.Message {
return &outgoing.Message{
Message: msg.Message,
ID: Uint64toa(msg.ID),
ChannelID: Uint64toa(msg.ChannelID),
Type: string(msg.Type),
ReplyTo: Uint64toa(msg.ReplyTo),
User: User(msg.User),
Attachment: Attachment(msg.Attachment),
CreatedAt: msg.CreatedAt,
UpdatedAt: msg.UpdatedAt,
}
}
func Messages(msg sam.MessageSet) *outgoing.MessageSet {
msgs := make([]*outgoing.Message, len(msg))
for k, m := range msg {
msgs[k] = Message(m)
}
retval := outgoing.MessageSet(msgs)
return &retval
}
func Channel(ch *sam.Channel) *outgoing.Channel {
return &outgoing.Channel{
ID: Uint64toa(ch.ID),
Name: ch.Name,
LastMessageID: Uint64toa(ch.LastMessageID),
Topic: ch.Topic,
Type: string(ch.Type),
Members: Uint64stoa(ch.Members),
}
}
func Channels(channels sam.ChannelSet) *outgoing.ChannelSet {
cc := make([]*outgoing.Channel, len(channels))
for k, c := range channels {
cc[k] = Channel(c)
}
retval := outgoing.ChannelSet(cc)
return &retval
}
func User(user *auth.User) *outgoing.User {
if user == nil {
return nil
}
return &outgoing.User{
ID: Uint64toa(user.ID),
Username: user.Username,
}
}
func Users(users []*auth.User) *outgoing.UserSet {
uu := make([]*outgoing.User, len(users))
for k, u := range users {
uu[k] = User(u)
uu[k].Connections = 0
}
retval := outgoing.UserSet(uu)
return &retval
}
func Attachment(in *sam.Attachment) *outgoing.Attachment {
if in == nil {
return nil
}
return &outgoing.Attachment{
ID: Uint64toa(in.ID),
UserID: Uint64toa(in.UserID),
Url: in.Url,
PreviewUrl: in.PreviewUrl,
Size: in.Size,
Mimetype: in.Mimetype,
Name: in.Name,
CreatedAt: in.CreatedAt,
UpdatedAt: in.UpdatedAt,
}
}

55
internal/payload/util.go Normal file
View File

@ -0,0 +1,55 @@
package payload
import (
"regexp"
"strconv"
)
var truthy = regexp.MustCompile("^\\s*(t(rue)?|y(es)?|1)\\s*$")
func Uint64toa(i uint64) string {
return strconv.FormatUint(i, 10)
}
func Uint64stoa(uu []uint64) []string {
ss := make([]string, len(uu))
for i, u := range uu {
ss[i] = Uint64toa(u)
}
return ss
}
//// parseInt64 parses an string to int64
//func parseInt64(s string) int64 {
// if s == "" {
// return 0
// }
// i, _ := strconv.ParseInt(s, 10, 64)
//
// return i
//}
//
// parseUInt64 parses an string to uint64
func ParseUInt64(s string) uint64 {
if s == "" {
return 0
}
i, _ := strconv.ParseUint(s, 10, 64)
return i
}
//// parseUInt64 parses an string to uint64
//func parseBool(s string) bool {
// return truthy.MatchString(strings.ToLower(s))
//}
//
//// is checks if string s is contained in matches
//func is(s string, matches ...string) bool {
// for _, v := range matches {
// if s == v {
// return true
// }
// }
// return false
//}

7
sam/websocket/encoder.go Normal file
View File

@ -0,0 +1,7 @@
package websocket
type (
MessageEncoder interface {
EncodeMessage() ([]byte, error)
}
)

View File

@ -1,100 +0,0 @@
package websocket
import (
authTypes "github.com/crusttech/crust/auth/types"
samTypes "github.com/crusttech/crust/sam/types"
"github.com/crusttech/crust/sam/websocket/outgoing"
)
func payloadFromMessage(msg *samTypes.Message) *outgoing.Message {
return &outgoing.Message{
Message: msg.Message,
ID: uint64toa(msg.ID),
ChannelID: uint64toa(msg.ChannelID),
Type: string(msg.Type),
ReplyTo: uint64toa(msg.ReplyTo),
User: payloadFromUser(msg.User),
Attachment: payloadFromAttachment(msg.Attachment),
CreatedAt: msg.CreatedAt,
UpdatedAt: msg.UpdatedAt,
}
}
func payloadFromMessages(msg samTypes.MessageSet) *outgoing.MessageSet {
msgs := make([]*outgoing.Message, len(msg))
for k, m := range msg {
msgs[k] = payloadFromMessage(m)
}
retval := outgoing.MessageSet(msgs)
return &retval
}
func payloadFromChannel(ch *samTypes.Channel) *outgoing.Channel {
return &outgoing.Channel{
ID: uint64toa(ch.ID),
Name: ch.Name,
LastMessageID: uint64toa(ch.LastMessageID),
Topic: ch.Topic,
Type: string(ch.Type),
Members: uint64stoa(ch.Members),
}
}
func payloadFromChannels(channels []*samTypes.Channel) *outgoing.ChannelSet {
cc := make([]*outgoing.Channel, len(channels))
for k, c := range channels {
cc[k] = payloadFromChannel(c)
}
retval := outgoing.ChannelSet(cc)
return &retval
}
func payloadFromUser(user *authTypes.User) *outgoing.User {
if user == nil {
return nil
}
return &outgoing.User{
ID: uint64toa(user.ID),
Username: user.Username,
}
}
func payloadFromUsers(users []*authTypes.User) *outgoing.UserSet {
uu := make([]*outgoing.User, len(users))
for k, u := range users {
uu[k] = payloadFromUser(u)
uu[k].Connections = 0
// @todo this is current instance only, need to sync this across all instances
store.Walk(func(session *Session) {
if session.user.Identity() == u.ID {
uu[k].Connections++
}
})
}
retval := outgoing.UserSet(uu)
return &retval
}
func payloadFromAttachment(in *samTypes.Attachment) *outgoing.Attachment {
if in == nil {
return nil
}
return &outgoing.Attachment{
ID: uint64toa(in.ID),
UserID: uint64toa(in.UserID),
Url: in.Url,
PreviewUrl: in.PreviewUrl,
Size: in.Size,
Mimetype: in.Mimetype,
Name: in.Name,
CreatedAt: in.CreatedAt,
UpdatedAt: in.UpdatedAt,
}
}

View File

@ -6,6 +6,8 @@ import (
"time"
"github.com/crusttech/crust/internal/auth"
"github.com/crusttech/crust/internal/payload"
"github.com/crusttech/crust/internal/payload/outgoing"
"github.com/crusttech/crust/sam/types"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
@ -13,7 +15,6 @@ import (
authService "github.com/crusttech/crust/auth/service"
"github.com/crusttech/crust/sam/repository"
samService "github.com/crusttech/crust/sam/service"
"github.com/crusttech/crust/sam/websocket/outgoing"
)
type (
@ -68,31 +69,31 @@ func (sess *Session) connected() {
if users, err := sess.svc.user.With(sess.ctx).Find(nil); err != nil {
log.Printf("Error: %v", err)
} else {
sess.sendReply(payloadFromUsers(users))
sess.sendReply(payload.Users(users))
}
// Push user info about all channels he has access to...
if cc, err := sess.svc.ch.With(sess.ctx).Find(&types.ChannelFilter{IncludeMembers: true}); err != nil {
log.Printf("Error: %v", err)
} else {
sess.sendReply(payloadFromChannels(cc))
sess.sendReply(payload.Channels(cc))
log.Printf("Subscribing %d to %d channels", sess.user.Identity(), len(cc))
cc.Walk(func(c *types.Channel) error {
// Subscribe this user/session to all channels
sess.subs.Add(uint64toa(c.ID))
sess.subs.Add(payload.Uint64toa(c.ID))
return nil
})
}
// Tell everyone that user has connected
sess.sendToAll(&outgoing.Connected{UserID: uint64toa(sess.user.Identity())})
sess.sendToAll(&outgoing.Connected{UserID: payload.Uint64toa(sess.user.Identity())})
}
func (sess *Session) disconnected() {
// Tell everyone that user has disconnected
sess.sendToAll(&outgoing.Disconnected{UserID: uint64toa(sess.user.Identity())})
sess.sendToAll(&outgoing.Disconnected{UserID: payload.Uint64toa(sess.user.Identity())})
}
func (sess *Session) Handle() error {

View File

@ -1,19 +1,18 @@
package websocket
import (
"encoding/json"
"github.com/crusttech/crust/sam/websocket/incoming"
"github.com/crusttech/crust/internal/payload"
"github.com/pkg/errors"
)
func (s *Session) dispatch(raw []byte) (err error) {
var p = &incoming.Payload{}
if err = json.Unmarshal(raw, p); err != nil {
func (s *Session) dispatch(raw []byte) error {
var p, err = payload.Unmarshal(raw)
if err != nil {
return errors.Wrap(err, "Session.incoming: payload malformed")
}
ctx := s.Context()
switch {
// message actions

View File

@ -4,9 +4,10 @@ import (
"context"
"github.com/crusttech/crust/internal/auth"
"github.com/crusttech/crust/internal/payload"
"github.com/crusttech/crust/internal/payload/incoming"
"github.com/crusttech/crust/internal/payload/outgoing"
"github.com/crusttech/crust/sam/types"
"github.com/crusttech/crust/sam/websocket/incoming"
"github.com/crusttech/crust/sam/websocket/outgoing"
)
func (s *Session) channelJoin(ctx context.Context, p *incoming.ChannelJoin) error {
@ -17,7 +18,7 @@ func (s *Session) channelJoin(ctx context.Context, p *incoming.ChannelJoin) erro
// Telling all subscribers of the channel we're joining that we are joining.
var chJoin = &outgoing.ChannelJoin{
ID: p.ChannelID,
UserID: uint64toa(auth.GetIdentityFromContext(ctx).Identity()),
UserID: payload.Uint64toa(auth.GetIdentityFromContext(ctx).Identity()),
}
// Send to all channel subscribers
@ -35,7 +36,7 @@ func (s *Session) channelPart(ctx context.Context, p *incoming.ChannelPart) erro
// This payload will tell everyone that we're parting from ALL channels
var chPart = &outgoing.ChannelPart{
ID: p.ChannelID,
UserID: uint64toa(auth.GetIdentityFromContext(ctx).Identity()),
UserID: payload.Uint64toa(auth.GetIdentityFromContext(ctx).Identity()),
}
s.sendToAllSubscribers(chPart, p.ChannelID)
@ -51,7 +52,7 @@ func (s *Session) channelList(ctx context.Context, p *incoming.Channels) error {
// @todo count members for all channels
return s.sendReply(payloadFromChannels(channels))
return s.sendReply(payload.Channels(channels))
}
func (s *Session) channelCreate(ctx context.Context, p *incoming.ChannelCreate) (err error) {
@ -77,13 +78,13 @@ func (s *Session) channelCreate(ctx context.Context, p *incoming.ChannelCreate)
}
// Explicitly subscribe to newly created channel
s.subs.Add(uint64toa(ch.ID))
s.subs.Add(payload.Uint64toa(ch.ID))
// @todo this should go over all user's sessons and subscribe there as well
// @todo load channel member count
pl := payloadFromChannel(ch)
pl := payload.Channel(ch)
if ch.Type == types.ChannelTypePublic {
return s.sendToAll(pl)
@ -94,19 +95,19 @@ func (s *Session) channelCreate(ctx context.Context, p *incoming.ChannelCreate)
}
func (s *Session) channelDelete(ctx context.Context, p *incoming.ChannelDelete) (err error) {
err = s.svc.ch.With(ctx).Delete(parseUInt64(p.ChannelID))
err = s.svc.ch.With(ctx).Delete(payload.ParseUInt64(p.ChannelID))
if err != nil {
return err
}
return s.sendToAllSubscribers(&outgoing.ChannelDeleted{
ID: p.ChannelID,
UserID: uint64toa(auth.GetIdentityFromContext(ctx).Identity()),
UserID: payload.Uint64toa(auth.GetIdentityFromContext(ctx).Identity()),
}, p.ChannelID)
}
func (s *Session) channelUpdate(ctx context.Context, p *incoming.ChannelUpdate) error {
ch, err := s.svc.ch.With(ctx).FindByID(parseUInt64(p.ID))
ch, err := s.svc.ch.With(ctx).FindByID(payload.ParseUInt64(p.ID))
if err != nil {
return err
}
@ -130,5 +131,5 @@ func (s *Session) channelUpdate(ctx context.Context, p *incoming.ChannelUpdate)
// @todo load channel member count
return s.sendToAllSubscribers(payloadFromChannel(ch), p.ID)
return s.sendToAllSubscribers(payload.Channel(ch), p.ID)
}

View File

@ -3,15 +3,16 @@ package websocket
import (
"context"
"github.com/crusttech/crust/internal/payload"
"github.com/crusttech/crust/internal/payload/incoming"
"github.com/crusttech/crust/internal/payload/outgoing"
"github.com/crusttech/crust/sam/types"
"github.com/crusttech/crust/sam/websocket/incoming"
"github.com/crusttech/crust/sam/websocket/outgoing"
)
func (s *Session) messageCreate(ctx context.Context, p *incoming.MessageCreate) error {
var (
msg = &types.Message{
ChannelID: parseUInt64(p.ChannelID),
ChannelID: payload.ParseUInt64(p.ChannelID),
Message: p.Message,
}
)
@ -21,13 +22,13 @@ func (s *Session) messageCreate(ctx context.Context, p *incoming.MessageCreate)
return err
}
return s.sendToAllSubscribers(payloadFromMessage(msg), p.ChannelID)
return s.sendToAllSubscribers(payload.Message(msg), p.ChannelID)
}
func (s *Session) messageUpdate(ctx context.Context, p *incoming.MessageUpdate) error {
var (
msg = &types.Message{
ID: parseUInt64(p.ID),
ID: payload.ParseUInt64(p.ID),
Message: p.Message,
}
)
@ -47,7 +48,7 @@ func (s *Session) messageUpdate(ctx context.Context, p *incoming.MessageUpdate)
func (s *Session) messageDelete(ctx context.Context, p *incoming.MessageDelete) error {
var (
id = parseUInt64(p.ID)
id = payload.ParseUInt64(p.ID)
)
if err := s.svc.msg.With(ctx).Delete(id); err != nil {
@ -60,9 +61,9 @@ func (s *Session) messageDelete(ctx context.Context, p *incoming.MessageDelete)
func (s *Session) messageHistory(ctx context.Context, p *incoming.Messages) error {
var (
filter = &types.MessageFilter{
ChannelID: parseUInt64(p.ChannelID),
FromMessageID: parseUInt64(p.FromID),
UntilMessageID: parseUInt64(p.UntilID),
ChannelID: payload.ParseUInt64(p.ChannelID),
FromMessageID: payload.ParseUInt64(p.FromID),
UntilMessageID: payload.ParseUInt64(p.UntilID),
// Max no. of messages we will return
Limit: 50,
@ -74,5 +75,5 @@ func (s *Session) messageHistory(ctx context.Context, p *incoming.Messages) erro
return err
}
return s.sendReply(payloadFromMessages(messages))
return s.sendReply(payload.Messages(messages))
}

View File

@ -2,7 +2,8 @@ package websocket
import (
"context"
"github.com/crusttech/crust/sam/websocket/incoming"
"github.com/crusttech/crust/internal/payload"
"github.com/crusttech/crust/internal/payload/incoming"
)
func (s *Session) userList(ctx context.Context, p *incoming.Users) error {
@ -10,5 +11,5 @@ func (s *Session) userList(ctx context.Context, p *incoming.Users) error {
if err != nil {
return err
}
return s.sendReply(payloadFromUsers(users))
return s.sendReply(payload.Users(users))
}

View File

@ -6,11 +6,10 @@ import (
"github.com/crusttech/crust/sam/repository"
"github.com/crusttech/crust/sam/types"
"github.com/crusttech/crust/sam/websocket/outgoing"
)
// Sends message to subscribers
func (s *Session) sendToAllSubscribers(p outgoing.MessageEncoder, channelID string) error {
func (s *Session) sendToAllSubscribers(p MessageEncoder, channelID string) error {
pb, err := p.EncodeMessage()
if err != nil {
return err
@ -20,7 +19,7 @@ func (s *Session) sendToAllSubscribers(p outgoing.MessageEncoder, channelID stri
}
// Sends message to all connected clients
func (s *Session) sendToAll(p outgoing.MessageEncoder) error {
func (s *Session) sendToAll(p MessageEncoder) error {
pb, err := p.EncodeMessage()
if err != nil {
return err
@ -32,7 +31,7 @@ func (s *Session) sendToAll(p outgoing.MessageEncoder) error {
// @todo: this isn't going to be correct - a user may have open multiple clients,
// that will connect to different edge SAM servers. It should also go
// through a repository.Events().Push (EventQueueItem) path.
func (s *Session) sendReply(p outgoing.MessageEncoder) error {
func (s *Session) sendReply(p MessageEncoder) error {
pb, err := p.EncodeMessage()
if err != nil {
return err

View File

@ -1,56 +0,0 @@
package websocket
import (
"regexp"
"strconv"
"strings"
)
var truthy = regexp.MustCompile("^\\s*(t(rue)?|y(es)?|1)\\s*$")
func uint64toa(i uint64) string {
return strconv.FormatUint(i, 10)
}
func uint64stoa(uu []uint64) []string {
ss := make([]string, len(uu))
for i, u := range uu {
ss[i] = uint64toa(u)
}
return ss
}
// parseInt64 parses an string to int64
func parseInt64(s string) int64 {
if s == "" {
return 0
}
i, _ := strconv.ParseInt(s, 10, 64)
return i
}
// parseUInt64 parses an string to uint64
func parseUInt64(s string) uint64 {
if s == "" {
return 0
}
i, _ := strconv.ParseUint(s, 10, 64)
return i
}
// parseUInt64 parses an string to uint64
func parseBool(s string) bool {
return truthy.MatchString(strings.ToLower(s))
}
// is checks if string s is contained in matches
func is(s string, matches ...string) bool {
for _, v := range matches {
if s == v {
return true
}
}
return false
}