From 70dedcaaba45e6890ebf37006cc3c8ad27bd22e5 Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Tue, 4 May 2021 14:51:28 +0200 Subject: [PATCH] Refactor & move websocket code under pkg --- app/boot_levels.go | 4 +- automation/service/session.go | 13 +- automation/types/session.go | 2 +- pkg/websocket/payload.go | 44 ++++ {websocket => pkg/websocket}/router.go | 2 +- pkg/websocket/server.go | 136 ++++++++++ pkg/websocket/server_test.go | 63 +++++ pkg/websocket/session.go | 340 +++++++++++++++++++++++++ pkg/websocket/session_test.go | 53 ++++ pkg/wfexec/prompt.go | 12 + pkg/wfexec/session.go | 12 +- system/service/reminder.go | 18 +- websocket/session.go | 309 ---------------------- websocket/types.go | 55 ---- websocket/websocket.go | 87 ------- websocket/websocket_test.go | 127 --------- 16 files changed, 679 insertions(+), 598 deletions(-) create mode 100644 pkg/websocket/payload.go rename {websocket => pkg/websocket}/router.go (87%) create mode 100644 pkg/websocket/server.go create mode 100644 pkg/websocket/server_test.go create mode 100644 pkg/websocket/session.go create mode 100644 pkg/websocket/session_test.go delete mode 100644 websocket/session.go delete mode 100644 websocket/types.go delete mode 100644 websocket/websocket.go delete mode 100644 websocket/websocket_test.go diff --git a/app/boot_levels.go b/app/boot_levels.go index 435662355..66307a702 100644 --- a/app/boot_levels.go +++ b/app/boot_levels.go @@ -29,11 +29,11 @@ import ( "github.com/cortezaproject/corteza-server/pkg/rbac" "github.com/cortezaproject/corteza-server/pkg/scheduler" "github.com/cortezaproject/corteza-server/pkg/sentry" + "github.com/cortezaproject/corteza-server/pkg/websocket" "github.com/cortezaproject/corteza-server/store" sysService "github.com/cortezaproject/corteza-server/system/service" sysEvent "github.com/cortezaproject/corteza-server/system/service/event" "github.com/cortezaproject/corteza-server/system/types" - "github.com/cortezaproject/corteza-server/websocket" "go.uber.org/zap" gomail "gopkg.in/mail.v2" ) @@ -252,7 +252,7 @@ func (app *CortezaApp) InitServices(ctx context.Context) (err error) { return err } - app.WsServer = websocket.Websocket(app.Log, app.Opt.Websocket) + app.WsServer = websocket.Server(app.Log, app.Opt.Websocket) ctx = actionlog.RequestOriginToContext(ctx, actionlog.RequestOrigin_APP_Init) defer sentry.Recover() diff --git a/automation/service/session.go b/automation/service/session.go index 9f970e1fa..0896e666f 100644 --- a/automation/service/session.go +++ b/automation/service/session.go @@ -212,7 +212,16 @@ func (svc *session) Resume(sessionID, stateID uint64, i auth.Identifiable, input return errors.NotFound("session not found") } - return ses.Resume(ctx, stateID, input) + resPrompt, err := ses.Resume(ctx, stateID, input) + if err != nil { + return err + } + + if err = svc.promptSender.Send("workflowSessionResumed", resPrompt, resPrompt.OwnerId); err != nil { + svc.log.Error("failed to send prompt resume status to user", zap.Error(err)) + } + + return nil } // spawns a new session @@ -367,7 +376,7 @@ func (svc *session) stateChangeHandler(ctx context.Context) wfexec.StateChangeHa // Send the pending prompts to user if svc.promptSender != nil { for _, pp := range s.AllPendingPrompts() { - if err := svc.promptSender.Send("ok", pp, pp.OwnerId); err != nil { + if err := svc.promptSender.Send("workflowSessionPrompt", pp, pp.OwnerId); err != nil { svc.log.Error("failed to send prompt to user", zap.Error(err)) } } diff --git a/automation/types/session.go b/automation/types/session.go index 1dc34414b..d564185fc 100644 --- a/automation/types/session.go +++ b/automation/types/session.go @@ -103,7 +103,7 @@ func (s Session) Exec(ctx context.Context, step wfexec.Step, input *expr.Vars) e return s.session.Exec(ctx, step, input) } -func (s Session) Resume(ctx context.Context, stateID uint64, input *expr.Vars) error { +func (s Session) Resume(ctx context.Context, stateID uint64, input *expr.Vars) (*wfexec.ResumedPrompt, error) { return s.session.Resume(ctx, stateID, input) } diff --git a/pkg/websocket/payload.go b/pkg/websocket/payload.go new file mode 100644 index 000000000..0924743b9 --- /dev/null +++ b/pkg/websocket/payload.go @@ -0,0 +1,44 @@ +package websocket + +import ( + "encoding/json" +) + +type ( + // Auth is JWT token provided by client as first message, + // and will be passed whenever it changes + payloadAuth struct { + AccessToken string `json:"accessToken"` + } + + payloadWrap struct { + Type string `json:"@type"` + Value json.RawMessage `json:"@value"` + } +) + +const ( + payloadTypeCredentials = "credentials" +) + +var ( + closingUnidentifiedConn, _ = MarshalPayload("error", "closing unidentified connection") + ok, _ = MarshalPayload("message", "authenticated") +) + +func (p payloadWrap) UnmarshalValue(m interface{}) error { + return json.Unmarshal(p.Value, m) +} + +func MarshalPayload(t string, m interface{}) ([]byte, error) { + var ( + err error + w = payloadWrap{Type: t} + ) + + if w.Value, err = json.Marshal(m); err != nil { + return nil, err + } + + return json.Marshal(w) +} diff --git a/websocket/router.go b/pkg/websocket/router.go similarity index 87% rename from websocket/router.go rename to pkg/websocket/router.go index ae57b1a2e..25ce1a059 100644 --- a/websocket/router.go +++ b/pkg/websocket/router.go @@ -8,7 +8,7 @@ import ( // No middleware used, since anyone can open connection and // send first message with valid JWT token, // If it's valid then we keep the connection open or close it -func (ws *websocket) MountRoutes(r chi.Router) { +func (ws *server) MountRoutes(r chi.Router) { // Initialize handlers & controllers. r.Get("/", ws.Open) } diff --git a/pkg/websocket/server.go b/pkg/websocket/server.go new file mode 100644 index 000000000..956148579 --- /dev/null +++ b/pkg/websocket/server.go @@ -0,0 +1,136 @@ +package websocket + +import ( + "github.com/cortezaproject/corteza-server/pkg/auth" + "github.com/cortezaproject/corteza-server/pkg/errors" + "github.com/cortezaproject/corteza-server/pkg/options" + "github.com/cortezaproject/corteza-server/pkg/slice" + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/websocket" + "go.uber.org/zap" + "io" + "net/http" + "sync" +) + +var ( + // upgrader handles websocket requests from peers + upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + + // Allow connections from any Origin + CheckOrigin: func(r *http.Request) bool { return true }, + } +) + +type ( + server struct { + config options.WebsocketOpt + logger *zap.Logger + + // user id => session id => session + sessions map[uint64]map[uint64]io.Writer + + accessToken interface { + Authenticate(string) (jwt.MapClaims, error) + } + + // keep lock on session map changes + l sync.RWMutex + } +) + +func Server(logger *zap.Logger, config options.WebsocketOpt) *server { + if !config.LogEnabled { + logger = zap.NewNop() + } + + return &server{ + config: config, + logger: logger.WithOptions(zap.AddStacktrace(zap.PanicLevel)).Named("websocket"), + accessToken: auth.DefaultJwtHandler, + sessions: make(map[uint64]map[uint64]io.Writer), + } +} + +func (ws *server) Open(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + conn, err := upgrader.Upgrade(w, r, nil) + if _, ok := err.(websocket.HandshakeError); ok { + errors.ProperlyServeHTTP(w, r, errors.Internal("need a websocket handshake"), false) + return + } else if err != nil { + errors.ProperlyServeHTTP(w, r, errors.Internal("failed to upgrade connection").Wrap(err), false) + return + } + + // init new session + // + // session will add itself back to server's session map when + // ready (if user authenticates itself) + ses := Session(ctx, ws, conn) + + if err = ses.Handle(); err != nil { + ws.logger.Warn("websocket session handler error", zap.Error(err)) + } +} + +// Send delivers payload to one, more or all users +// +// Omit userIDs to deliver to ALL users +func (ws *server) Send(t string, payload interface{}, userIDs ...uint64) error { + pb, err := MarshalPayload(t, payload) + if err != nil { + return err + } + + var ( + sendToAll = len(userIDs) == 0 + uMap = slice.ToUint64BoolMap(userIDs) + ) + + ws.l.RLock() + defer ws.l.RUnlock() + + for uid := range ws.sessions { + if sendToAll || (!sendToAll && uMap[uid]) { + for _, s := range ws.sessions[uid] { + _, err = s.Write(pb) + } + } + } + + return nil +} + +func (ws *server) StoreSession(s *session) { + ws.l.Lock() + defer ws.l.Unlock() + if s.identity != nil { + ws.storeSession(s, s.identity.Identity(), s.id) + } +} + +func (ws *server) storeSession(w io.Writer, uid, sid uint64) { + if ws.sessions[uid] == nil { + ws.sessions[uid] = make(map[uint64]io.Writer) + + } + + ws.sessions[uid][sid] = w +} + +func (ws *server) RemoveSession(s *session) { + ws.l.Lock() + defer ws.l.Unlock() + if s.identity != nil { + uid := s.identity.Identity() + delete(ws.sessions[uid], s.id) + + if len(ws.sessions[uid]) == 0 { + delete(ws.sessions, uid) + } + } +} diff --git a/pkg/websocket/server_test.go b/pkg/websocket/server_test.go new file mode 100644 index 000000000..f50e73730 --- /dev/null +++ b/pkg/websocket/server_test.go @@ -0,0 +1,63 @@ +package websocket + +import ( + "bytes" + "github.com/cortezaproject/corteza-server/pkg/options" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "testing" +) + +func TestWebsocketSend_NoSessions(t *testing.T) { + var ( + req = require.New(t) + ws = Server(zap.NewNop(), options.WebsocketOpt{}) + ) + + req.NoError(ws.Send("msg", "msg")) + req.NoError(ws.Send("msg", "msg", 1)) + req.NoError(ws.Send("msg", "msg", 1, 2)) + req.NoError(ws.Send("msg", "msg", 1, 2, 3)) +} + +func TestWebsocketSend_ExistingSessions(t *testing.T) { + var ( + req = require.New(t) + ws = Server(zap.NewNop(), options.WebsocketOpt{}) + + s1User uint64 = 100 + s1ID uint64 = 101 + + s2User uint64 = 200 + s2ID uint64 = 201 + + s1 = &bytes.Buffer{} + s2 = &bytes.Buffer{} + ) + + ws.storeSession(s1, s1User, s1ID) + ws.storeSession(s2, s2User, s2ID) + + req.Empty(s1) + req.Empty(s2) + + req.NoError(ws.Send("msg", "msg", 0)) + req.Empty(s1) + req.Empty(s2) + + req.NoError(ws.Send("msg", "msg1", s1User)) + req.Equal(s1.String(), `{"msg":"msg1"}`) + req.Equal(s2.String(), "") + + req.NoError(ws.Send("msg", "msg2", s2User)) + req.Equal(s1.String(), `{"msg":"msg1"}`) + req.Equal(s2.String(), `{"msg":"msg2"}`) + + req.NoError(ws.Send("both", "msg3", s1User, s2User)) + req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}`) + req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}`) + + req.NoError(ws.Send("all", "msg4")) + req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}{"all":"msg4"}`) + req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}{"all":"msg4"}`) +} diff --git a/pkg/websocket/session.go b/pkg/websocket/session.go new file mode 100644 index 000000000..833f58c12 --- /dev/null +++ b/pkg/websocket/session.go @@ -0,0 +1,340 @@ +package websocket + +import ( + "context" + "encoding/json" + "fmt" + "github.com/cortezaproject/corteza-server/pkg/auth" + "github.com/cortezaproject/corteza-server/pkg/errors" + "github.com/cortezaproject/corteza-server/pkg/id" + "github.com/cortezaproject/corteza-server/pkg/options" + "github.com/gorilla/websocket" + "go.uber.org/zap" + "net" + "sync" + "time" +) + +// active sessions of users +var ( + // wrapper around nextID that will aid service testing + nextID = func() uint64 { + return id.Next() + } +) + +type ( + session struct { + id uint64 + once sync.Once + conn *websocket.Conn + + ctx context.Context + ctxCancel context.CancelFunc + + logger *zap.Logger + + send chan []byte + stop chan []byte + + remoteAddr string + + config options.WebsocketOpt + + identity auth.Identifiable + + server *server + } +) + +func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session { + s := &session{ + id: nextID(), + conn: conn, + config: ws.config, + send: make(chan []byte, 512), + stop: make(chan []byte, 1), + server: ws, + } + + s.ctx, s.ctxCancel = context.WithCancel(ctx) + + s.logger = ws.logger. + Named("session"). + With( + zap.Uint64("id", s.id), + ) + + return s +} + +func (s *session) connected() (err error) { + s.logger.Info("connected", zap.String("remoteAddr", s.conn.RemoteAddr().String())) + + //// Tell everyone that user has connected + //if err = s.sendPresence("connected"); err != nil { + // return + //} + // + // + //// Create a heartbeat every minute for this user + //go func() { + // defer sentry.Recover() + // + // t := time.NewTicker(time.Second * 60) + // for { + // select { + // case <-s.ctx.Done(): + // return + // case <-t.C: + // _ = s.sendPresence("active") + // } + // } + //}() + + return nil +} + +func (s *session) disconnected() { + // Cancel context + s.ctxCancel() + + s.logger.Info("disconnected") + + // Close connection + _ = s.conn.Close() + s.conn = nil +} + +//func (s *session) sendPresence(_ string) error { +// return nil +//} + +func (s *session) Handle() (err error) { + if err = s.connected(); err != nil { + s.Close() + return + } + + go func() { + // Close unidentified connections in 5sec + <-time.NewTimer(time.Second * 5).C + if s.identity == nil { + s.Write([]byte(closingUnidentifiedConn)) + s.logger.Info("closing unidentified connection") + s.Close() + } + }() + + go func() { + if err = s.readLoop(); err != nil { + s.logger.Error("read failure", zap.Error(err)) + } + s.Close() + }() + + if err = s.writeLoop(); err != nil { + s.logger.Error("write failure", zap.Error(err)) + } + + return +} + +func (s *session) Close() { + s.once.Do(func() { + s.disconnected() + s.server.RemoveSession(s) + }) +} + +func (s *session) readLoop() (err error) { + if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil { + return + } + + s.conn.SetPongHandler(func(string) error { + return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)) + }) + + s.remoteAddr = s.conn.RemoteAddr().String() + + var ( + raw []byte + ) + + for { + if s.conn == nil { + return nil + } + + if _, raw, err = s.conn.ReadMessage(); err != nil { + return errHandler("read failed", err) + } + + if err = s.procRawMessage(raw); err != nil { + return + } + } +} + +func (s *session) procRawMessage(raw []byte) (err error) { + pw := payloadWrap{} + if err = json.Unmarshal(raw, &pw); err != nil { + return fmt.Errorf("could not unmarshal session message: %w", err) + } + + if pw.Type == payloadTypeCredentials { + authPayload := &payloadAuth{} + if err = pw.UnmarshalValue(authPayload); err != nil { + return fmt.Errorf("could not unmarshal session payload: %w", err) + } + + if err = s.authenticate(authPayload); err != nil { + return fmt.Errorf("unauthorized: %w", err) + } + + s.logger.Debug( + "authenticated", + zap.Uint64("userID", s.identity.Identity()), + zap.Uint64s("roles", s.identity.Roles()), + ) + + s.server.StoreSession(s) + + // not expecting anything else + return + } + + if s.identity == nil { + return fmt.Errorf("unauthenticated session") + } + + // at the moment we do not support any other kinds of message types + return fmt.Errorf("unknown message type '%s'", pw.Type) +} + +func (s *session) writeLoop() error { + ticker := time.NewTicker(s.config.PingPeriod) + + defer func() { + ticker.Stop() + s.Close() // break readLoop + }() + + write := func(msg []byte) (err error) { + if s.conn == nil { + // Connection closed, nowhere to write + return + } + + if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil { + return fmt.Errorf("deadline error: %w", err) + } + + if msg != nil && s.conn != nil { + return s.conn.WriteMessage(websocket.TextMessage, msg) + } + + return + } + + ping := func() (err error) { + if s.conn == nil { + // Connection closed, nothing to ping + return + } + + if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil { + return + } + + if s.conn != nil { + return s.conn.WriteMessage(websocket.PingMessage, nil) + } + + return + } + + for { + if s.conn == nil { + return nil + } + + select { + case msg, ok := <-s.send: + if !ok { + // channel closed + return nil + } + + if err := errHandler("send failed", write(msg)); err != nil { + return err + } + + case msg := <-s.stop: + // Shutdown requested, don't care if the message is delivered + _ = write(msg) + return nil + + case <-ticker.C: + if err := ping(); err != nil { + return errHandler("ping failed", err) + } + } + } +} + +func (s *session) authenticate(p *payloadAuth) error { + claims, err := s.server.accessToken.Authenticate(p.AccessToken) + if err != nil { + return err + } + + if !auth.CheckScope(claims["scope"], "api") { + return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)") + } + + // Get identity using JWT claims + identity := auth.ClaimsToIdentity(claims) + + if s.identity != nil { + if s.identity.Identity() != identity.Identity() { + return fmt.Errorf("identity does not match") + } + } + + if !identity.Valid() { + return fmt.Errorf("invalid identity") + } + + s.identity = identity + s.Write([]byte(ok)) + return nil +} + +// sendBytes sends byte to channel or timeout +func (s *session) Write(p []byte) (int, error) { + select { + case s.send <- p: + return len(p), nil + case <-time.After(2 * time.Millisecond): + return 0, fmt.Errorf("write timedout") + } +} + +func errHandler(wrap string, err error) error { + if err == nil { + return nil + } + + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + // normal closing + return nil + } + + if errors.Is(err, net.ErrClosed) { + // suppress errors when reading/writing from/to a closed connection + return nil + } + return fmt.Errorf(wrap+": %w", err) +} diff --git a/pkg/websocket/session_test.go b/pkg/websocket/session_test.go new file mode 100644 index 000000000..f96a79e31 --- /dev/null +++ b/pkg/websocket/session_test.go @@ -0,0 +1,53 @@ +package websocket + +import ( + "github.com/cortezaproject/corteza-server/pkg/auth" + "github.com/cortezaproject/corteza-server/pkg/options" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestSession_procRawMessage(t *testing.T) { + var ( + req = require.New(t) + s = session{server: Server(nil, options.WebsocketOpt{})} + jwtHandler, err = auth.JWT("secret", time.Minute) + + userID uint64 = 123 + ) + + req.NoError(err) + s.server.accessToken = jwtHandler + + jwt := jwtHandler.Encode(auth.NewIdentity(userID, 456, 789)) + + req.EqualError(s.procRawMessage([]byte("{}")), "empty payload") + req.Nil(s.identity) + + req.EqualError(s.procRawMessage([]byte(`{"auth":{}}`)), "unauthorized: token contains an invalid number of segments") + req.Nil(s.identity) + + req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": ""}}`)), "unauthorized: token contains an invalid number of segments") + req.Nil(s.identity) + + req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`))) + req.NotNil(s.identity) + req.Equal(userID, s.identity.Identity()) + + req.EqualError(s.procRawMessage([]byte("{}")), "empty payload") + req.Equal(userID, s.identity.Identity()) + + // Repeat with the same user + jwt = jwtHandler.Encode(auth.NewIdentity(userID, 456, 789)) + + req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`))) + req.NotNil(s.identity) + req.Equal(userID, s.identity.Identity()) + + // Try to authenticate on an existing authenticated session as a different user + jwt = jwtHandler.Encode(auth.NewIdentity(userID+1, 456, 789)) + + req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": "`+jwt+`"}}`)), "unauthorized: identity does not match") + +} diff --git a/pkg/wfexec/prompt.go b/pkg/wfexec/prompt.go index 3689c4c90..d43f057f0 100644 --- a/pkg/wfexec/prompt.go +++ b/pkg/wfexec/prompt.go @@ -29,6 +29,11 @@ type ( Payload *expr.Vars `json:"payload"` OwnerId uint64 `json:"-"` } + + ResumedPrompt struct { + StateID uint64 `json:"stateID,string"` + OwnerId uint64 `json:"-"` + } ) func Prompt(ownerId uint64, ref string, payload *expr.Vars) *prompted { @@ -44,3 +49,10 @@ func (p *prompted) toPending() *PendingPrompt { OwnerId: p.ownerId, } } + +func (p *prompted) toResumed() *ResumedPrompt { + return &ResumedPrompt{ + StateID: p.state.stateId, + OwnerId: p.ownerId, + } +} diff --git a/pkg/wfexec/session.go b/pkg/wfexec/session.go index 93594de0f..0c670e264 100644 --- a/pkg/wfexec/session.go +++ b/pkg/wfexec/session.go @@ -288,7 +288,7 @@ func (s *Session) AllPendingPrompts() (out []*PendingPrompt) { return } -func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) error { +func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) (*ResumedPrompt, error) { defer s.mux.Unlock() s.mux.Lock() @@ -297,11 +297,11 @@ func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) p, has = s.prompted[stateId] ) if !has { - return fmt.Errorf("unexisting state") + return nil, fmt.Errorf("unexisting state") } if i == nil || p.ownerId != i.Identity() { - return fmt.Errorf("state access denied") + return nil, fmt.Errorf("state access denied") } delete(s.prompted, stateId) @@ -309,7 +309,11 @@ func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) // setting received input to state p.state.input = input - return s.enqueue(ctx, p.state) + if err := s.enqueue(ctx, p.state); err != nil { + return nil, err + } + + return p.toResumed(), nil } func (s *Session) enqueue(ctx context.Context, st *State) error { diff --git a/system/service/reminder.go b/system/service/reminder.go index f3a7c8fa7..5b975f382 100644 --- a/system/service/reminder.go +++ b/system/service/reminder.go @@ -278,7 +278,10 @@ func (svc reminder) Delete(ctx context.Context, ID uint64) (err error) { func (svc reminder) Watch(ctx context.Context) { if svc.reminderSender != nil { - rTicker := time.NewTicker(time.Second) + var ( + interval = time.Second + rTicker = time.NewTicker(interval) + ) go func() { defer sentry.Recover() @@ -289,26 +292,21 @@ func (svc reminder) Watch(ctx context.Context) { select { case <-ctx.Done(): return - case t := <-rTicker.C: + case <-rTicker.C: // Get scheduled reminders of users rr, _, err := svc.Find(ctx, types.ReminderFilter{ ExcludeDismissed: true, ScheduledOnly: true, }) + if err != nil { svc.log.Error("failed to get reminders of users", zap.Error(err)) } - // sendReminderNow checks time is equal to current time or not - sendReminderNow := func(tt time.Time) bool { - timeLayout := time.RFC3339 - return tt.Format(timeLayout) == t.Format(timeLayout) - } - // Send scheduled reminders to users _ = rr.Walk(func(r *types.Reminder) error { - if r.RemindAt != nil && sendReminderNow(*r.RemindAt) { - if err := svc.reminderSender.Send("ok", r, r.AssignedTo); err != nil { + if r.RemindAt != nil && now().Round(interval) == r.RemindAt.Round(interval) { + if err := svc.reminderSender.Send("reminder", r, r.AssignedTo); err != nil { svc.log.Error("failed to send reminder to user", zap.Error(err)) } } diff --git a/websocket/session.go b/websocket/session.go deleted file mode 100644 index 17b48840b..000000000 --- a/websocket/session.go +++ /dev/null @@ -1,309 +0,0 @@ -package websocket - -import ( - "context" - "github.com/cortezaproject/corteza-server/pkg/id" - "github.com/cortezaproject/corteza-server/pkg/options" - "github.com/getsentry/sentry-go" - "github.com/pkg/errors" - "go.uber.org/zap/zapcore" - "sync" - "time" - - gWebsocket "github.com/gorilla/websocket" - "go.uber.org/zap" - - "github.com/cortezaproject/corteza-server/pkg/auth" -) - -// active sessions of users -var sessions = make(map[uint64][]*session) - -type ( - session struct { - id uint64 - once sync.Once - conn *gWebsocket.Conn - - ctx context.Context - ctxCancel context.CancelFunc - - logger *zap.Logger - - send chan []byte - stop chan []byte - - remoteAddr string - - config options.WebsocketOpt - - user auth.Identifiable - } -) - -func Session(ctx context.Context, logger *zap.Logger, config options.WebsocketOpt, conn *gWebsocket.Conn) *session { - s := &session{ - conn: conn, - config: config, - send: make(chan []byte, 512), - stop: make(chan []byte, 1), - } - - s.ctx, s.ctxCancel = context.WithCancel(ctx) - - s.logger = logger - - return s -} - -func (s *session) log(fields ...zapcore.Field) *zap.Logger { - return s.logger.With(fields...) -} - -func (s *session) Context() context.Context { - return s.ctx -} - -func (s *session) User() auth.Identifiable { - return s.user -} - -func (s *session) connected() (err error) { - // Tell everyone that user has connected - if err = s.sendPresence("connected"); err != nil { - return - } - - // Create a heartbeat every minute for this user - go func() { - defer sentry.Recover() - - t := time.NewTicker(time.Second * 60) - for { - select { - case <-s.ctx.Done(): - return - case <-t.C: - _ = s.sendPresence("") - } - } - }() - - return nil -} - -func (s *session) disconnected() { - // Tell everyone that user has disconnected - _ = s.sendPresence("disconnected") - - // Cancel context - s.ctxCancel() - - // Close connection - _ = s.conn.Close() - s.conn = nil -} - -// sendPresence sends user presence: "connected", "disconnected" and "" activity kinds -func (s *session) sendPresence(kind string) error { - //connections := store.CountConnections(s.user.Identity()) - //if kind == "disconnected" { - // connections-- - //} - - return nil -} - -func (s *session) Handle() (err error) { - if err = s.connected(); err != nil { - s.Close() - return - } - - go func() { - _ = s.readLoop() - }() - return s.writeLoop() -} - -func (s *session) Close() { - s.once.Do(func() { - s.disconnected() - _ = s.Delete() - }) -} - -func (s *session) readLoop() (err error) { - defer func() { - s.Close() - }() - - if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil { - return - } - - s.conn.SetPongHandler(func(string) error { - return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)) - }) - - s.remoteAddr = s.conn.RemoteAddr().String() - - for { - _, raw, err := s.conn.ReadMessage() - if err != nil { - return errors.Wrap(err, "s.readLoop") - } - - if err = s.dispatch(raw); err != nil { - s.log(zap.Error(err)).Error("could not dispatch") - //_ = s.send(outgoing.NewError(err)) - } - } -} - -func (s *session) writeLoop() error { - ticker := time.NewTicker(s.config.PingPeriod) - - defer func() { - ticker.Stop() - s.Close() // break readLoop - }() - - write := func(msg []byte) (err error) { - if s.conn == nil { - // Connection closed, nowhere to write - return - } - - if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil { - return - } - - if msg != nil && s.conn != nil { - return s.conn.WriteMessage(gWebsocket.TextMessage, msg) - } - - return - } - - ping := func() (err error) { - if s.conn == nil { - // Connection closed, nothing to ping - return - } - - if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil { - return - } - - if s.conn != nil { - return s.conn.WriteMessage(gWebsocket.PingMessage, nil) - } - - return - } - - for { - select { - case msg, ok := <-s.send: - if !ok { - // channel closed - return nil - } - - if err := write(msg); err != nil { - return errors.Wrap(err, "writeLoop send") - } - - case msg := <-s.stop: - // Shutdown requested, don't care if the message is delivered - _ = write(msg) - return nil - - case <-ticker.C: - if err := ping(); err != nil { - return errors.Wrap(err, "writeLoop ping") - } - } - } -} - -func (s *session) dispatch(raw []byte) error { - var p, err = Unmarshal(raw) - if err != nil { - return errors.Wrap(err, "Session.incoming: payload malformed") - } - - if p.Auth != nil { - return s.authenticate(p.Auth) - } - - return nil -} - -func (s *session) authenticate(p *Auth) error { - // Get JWT claims - claims, err := p.ParseWithClaims() - if err != nil { - s.Close() - return err - } - - // Get identity using JWT claims - identity := auth.ClaimsToIdentity(claims) - s.Save(identity) - - return nil -} - -func (s *session) Save(identity *auth.Identity) *session { - if s.id == 0 { - s.id = id.Next() - } - - if identity != nil { - userID := identity.Identity() - existingSessions, ok := sessions[userID] - // Add sessions for user - if !ok { - s.user = identity - sessions[userID] = append(sessions[userID], s) - } - - // Update the identity in existing sessions - for _, sess := range existingSessions { - sess.user = identity - } - } - - return s -} - -func (s *session) Get(userID uint64) []*session { - if sess, ok := sessions[userID]; ok { - return sess - } - return nil -} - -func (s *session) Delete() error { - if s.id == 0 { - return nil - } - - if s.user != nil { - delete(sessions, s.user.Identity()) - } - - return nil -} - -// sendBytes sends byte to channel or timout -func (s *session) sendBytes(p []byte) error { - select { - case s.send <- p: - case <-time.After(2 * time.Millisecond): - s.logger.Warn("websocket.sendBytes send timeout") - } - return nil -} diff --git a/websocket/types.go b/websocket/types.go deleted file mode 100644 index d67d7dd9b..000000000 --- a/websocket/types.go +++ /dev/null @@ -1,55 +0,0 @@ -package websocket - -import ( - "encoding/json" - "github.com/dgrijalva/jwt-go" - "github.com/pkg/errors" -) - -type ( - // Auth is JWT token provided by client as first message, - // and will be passed whenever it changes - Auth struct { - AccessToken *string `json:"access_token"` - } - - // payload for incoming messages from user - payload struct { - *Auth `json:"auth"` - } - - // response for sending messages to user - response struct { - Status string `json:"status"` - Data interface{} `json:"data"` - } -) - -func (a *Auth) ParseWithClaims() (jwt.MapClaims, error) { - token, err := jwt.Parse(*a.AccessToken, nil) - if token == nil { - return nil, err - } - claims, ok := token.Claims.(jwt.MapClaims) - if ok { - return claims, nil - } else { - return nil, errors.New("Invalid token") - } -} - -func Unmarshal(raw []byte) (*payload, error) { - var p payload - return &p, json.Unmarshal(raw, &p) -} - -func Response(status string, data interface{}) *response { - return &response{ - Status: status, - Data: data, - } -} - -func (m response) Marshal() ([]byte, error) { - return json.Marshal(m) -} diff --git a/websocket/websocket.go b/websocket/websocket.go deleted file mode 100644 index 7c4df20a2..000000000 --- a/websocket/websocket.go +++ /dev/null @@ -1,87 +0,0 @@ -package websocket - -import ( - "github.com/cortezaproject/corteza-server/pkg/api" - "github.com/cortezaproject/corteza-server/pkg/options" - gWebsocket "github.com/gorilla/websocket" - "github.com/pkg/errors" - "go.uber.org/zap" - "net/http" -) - -var ( - // upgrader handles websocket requests from peers - upgrader = gWebsocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - - // Allow connections from any Origin - CheckOrigin: func(r *http.Request) bool { return true }, - } -) - -type ( - websocket struct { - config options.WebsocketOpt - logger *zap.Logger - } -) - -func Websocket(logger *zap.Logger, config options.WebsocketOpt) *websocket { - if !config.LogEnabled { - logger = zap.NewNop() - } - - return &websocket{ - config: config, - logger: logger, - } -} - -func (ws *websocket) Open(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - conn, err := upgrader.Upgrade(w, r, nil) - if _, ok := err.(gWebsocket.HandshakeError); ok { - ws.logger.Error("ws: need a websocket handshake") - api.Send(w, r, errors.Wrap(err, "ws: need a websocket handshake")) - return - } else if err != nil { - ws.logger.Error("ws: failed to upgrade connection") - api.Send(w, r, errors.Wrap(err, "ws: failed to upgrade connection")) - return - } - - session := Session(ctx, ws.logger, ws.config, conn) - - if err := session.Handle(); err != nil { - ws.logger. - WithOptions(zap.AddStacktrace(zap.PanicLevel)). - Warn("websocket session handler error", zap.Error(err)) - } -} - -// Send delivers message to user to ones we want to -// if len(userIDs) == 0 -- it delivers to everyone -func (ws *websocket) Send(kind string, payload interface{}, userIDs ...uint64) error { - pb, err := Response(kind, payload).Marshal() - if err != nil { - return err - } - - sendsToAll := len(userIDs) == 0 - userIDMap := make(map[uint64]bool) - for _, userID := range userIDs { - userIDMap[userID] = true - } - - for uid, uSessions := range sessions { - if sendsToAll || (!sendsToAll && userIDMap[uid]) { - for _, sess := range uSessions { - _ = sess.sendBytes(pb) - } - } - } - - return nil -} diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go deleted file mode 100644 index f91ecd4e8..000000000 --- a/websocket/websocket_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package websocket - -import ( - "context" - "fmt" - "github.com/cortezaproject/corteza-server/pkg/options" - gWebsocket "github.com/gorilla/websocket" - "go.uber.org/zap" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -// WebsocketServer provide websocket server for testing -func WebsocketTestServer(t *testing.T) (*httptest.Server, *gWebsocket.Conn) { - // Create test server with the websocket handler. - s := httptest.NewServer(http.HandlerFunc(wsOpen)) - - // Convert http://.. to ws://.. - u := "ws" + strings.TrimPrefix(s.URL, "http") - - // Connect to the server - conn, _, err := gWebsocket.DefaultDialer.Dial(u, nil) - if err != nil { - t.Fatalf("WebsocketServer() error while creating connection = %v", err) - } - - return s, conn -} - -// wsOpen opens websocket connection -func wsOpen(w http.ResponseWriter, r *http.Request) { - var gUpgrader = gWebsocket.Upgrader{} - c, err := gUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer func(c *gWebsocket.Conn) { - _ = c.Close() - }(c) - - for { - mt, message, err := c.ReadMessage() - if err != nil { - break - } - err = c.WriteMessage(mt, message) - if err != nil { - break - } - } -} - -func TestSendingMessageToUser(t *testing.T) { - tests := []struct { - name string - kind string - payload interface{} - expectedP string - }{ - { - name: "send json", - kind: "Json", - payload: struct { - Title string `json:"title"` - Description string `json:"description"` - }{Title: "Websocket", Description: "Testing connection.."}, - expectedP: `{"status":"ok","data":{"title":"Websocket","description":"Testing connection.."}}`, - }, - { - name: "send text", - kind: "Text", - payload: "testing connectivity", - expectedP: "testing connectivity", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var logger *zap.Logger - var config options.WebsocketOpt - ws := Websocket(logger, config) - s, conn := WebsocketTestServer(t) - defer s.Close() - defer func(ws *gWebsocket.Conn) { - err := ws.Close() - if err != nil { - t.Fatalf("TestSendingMessageToUser() error closing connection = %v", err) - } - }(conn) - - // Open a session using ws connection - wsSession := Session(context.Background(), ws.logger, ws.config, conn) - - var messageType int - var data []byte - switch tt.kind { - case "Text": - messageType = gWebsocket.TextMessage - data = []byte(fmt.Sprintf("%v", tt.payload)) - case "Json": - res := Response("ok", tt.payload) - messageType = gWebsocket.BinaryMessage - var err error - data, err = res.Marshal() - if err != nil { - t.Fatalf("TestSendingMessageToUser() error while marshaling payload = %v", err) - } - } - - // Send message to server, read response and check to see if it's what we expect. - if err := wsSession.conn.WriteMessage(messageType, data); err != nil { - t.Fatalf("TestSendingMessageToUser() error while sending message = %v", err) - } - - _, p, err := wsSession.conn.ReadMessage() - if err != nil { - t.Fatalf("TestSendingMessageToUser() error while reading message =%v", err) - } - - if string(p) != tt.expectedP { - t.Fatalf("TestSendingMessageToUser() gotP = %v, want = %v", string(p), tt.expectedP) - } - }) - } -}