3
0

Refactor & move websocket code under pkg

This commit is contained in:
Denis Arh 2021-05-04 14:51:28 +02:00
parent bcf83b25e8
commit 70dedcaaba
16 changed files with 679 additions and 598 deletions

View File

@ -29,11 +29,11 @@ import (
"github.com/cortezaproject/corteza-server/pkg/rbac"
"github.com/cortezaproject/corteza-server/pkg/scheduler"
"github.com/cortezaproject/corteza-server/pkg/sentry"
"github.com/cortezaproject/corteza-server/pkg/websocket"
"github.com/cortezaproject/corteza-server/store"
sysService "github.com/cortezaproject/corteza-server/system/service"
sysEvent "github.com/cortezaproject/corteza-server/system/service/event"
"github.com/cortezaproject/corteza-server/system/types"
"github.com/cortezaproject/corteza-server/websocket"
"go.uber.org/zap"
gomail "gopkg.in/mail.v2"
)
@ -252,7 +252,7 @@ func (app *CortezaApp) InitServices(ctx context.Context) (err error) {
return err
}
app.WsServer = websocket.Websocket(app.Log, app.Opt.Websocket)
app.WsServer = websocket.Server(app.Log, app.Opt.Websocket)
ctx = actionlog.RequestOriginToContext(ctx, actionlog.RequestOrigin_APP_Init)
defer sentry.Recover()

View File

@ -212,7 +212,16 @@ func (svc *session) Resume(sessionID, stateID uint64, i auth.Identifiable, input
return errors.NotFound("session not found")
}
return ses.Resume(ctx, stateID, input)
resPrompt, err := ses.Resume(ctx, stateID, input)
if err != nil {
return err
}
if err = svc.promptSender.Send("workflowSessionResumed", resPrompt, resPrompt.OwnerId); err != nil {
svc.log.Error("failed to send prompt resume status to user", zap.Error(err))
}
return nil
}
// spawns a new session
@ -367,7 +376,7 @@ func (svc *session) stateChangeHandler(ctx context.Context) wfexec.StateChangeHa
// Send the pending prompts to user
if svc.promptSender != nil {
for _, pp := range s.AllPendingPrompts() {
if err := svc.promptSender.Send("ok", pp, pp.OwnerId); err != nil {
if err := svc.promptSender.Send("workflowSessionPrompt", pp, pp.OwnerId); err != nil {
svc.log.Error("failed to send prompt to user", zap.Error(err))
}
}

View File

@ -103,7 +103,7 @@ func (s Session) Exec(ctx context.Context, step wfexec.Step, input *expr.Vars) e
return s.session.Exec(ctx, step, input)
}
func (s Session) Resume(ctx context.Context, stateID uint64, input *expr.Vars) error {
func (s Session) Resume(ctx context.Context, stateID uint64, input *expr.Vars) (*wfexec.ResumedPrompt, error) {
return s.session.Resume(ctx, stateID, input)
}

44
pkg/websocket/payload.go Normal file
View File

@ -0,0 +1,44 @@
package websocket
import (
"encoding/json"
)
type (
// Auth is JWT token provided by client as first message,
// and will be passed whenever it changes
payloadAuth struct {
AccessToken string `json:"accessToken"`
}
payloadWrap struct {
Type string `json:"@type"`
Value json.RawMessage `json:"@value"`
}
)
const (
payloadTypeCredentials = "credentials"
)
var (
closingUnidentifiedConn, _ = MarshalPayload("error", "closing unidentified connection")
ok, _ = MarshalPayload("message", "authenticated")
)
func (p payloadWrap) UnmarshalValue(m interface{}) error {
return json.Unmarshal(p.Value, m)
}
func MarshalPayload(t string, m interface{}) ([]byte, error) {
var (
err error
w = payloadWrap{Type: t}
)
if w.Value, err = json.Marshal(m); err != nil {
return nil, err
}
return json.Marshal(w)
}

View File

@ -8,7 +8,7 @@ import (
// No middleware used, since anyone can open connection and
// send first message with valid JWT token,
// If it's valid then we keep the connection open or close it
func (ws *websocket) MountRoutes(r chi.Router) {
func (ws *server) MountRoutes(r chi.Router) {
// Initialize handlers & controllers.
r.Get("/", ws.Open)
}

136
pkg/websocket/server.go Normal file
View File

@ -0,0 +1,136 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/errors"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/cortezaproject/corteza-server/pkg/slice"
"github.com/dgrijalva/jwt-go"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"io"
"net/http"
"sync"
)
var (
// upgrader handles websocket requests from peers
upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// Allow connections from any Origin
CheckOrigin: func(r *http.Request) bool { return true },
}
)
type (
server struct {
config options.WebsocketOpt
logger *zap.Logger
// user id => session id => session
sessions map[uint64]map[uint64]io.Writer
accessToken interface {
Authenticate(string) (jwt.MapClaims, error)
}
// keep lock on session map changes
l sync.RWMutex
}
)
func Server(logger *zap.Logger, config options.WebsocketOpt) *server {
if !config.LogEnabled {
logger = zap.NewNop()
}
return &server{
config: config,
logger: logger.WithOptions(zap.AddStacktrace(zap.PanicLevel)).Named("websocket"),
accessToken: auth.DefaultJwtHandler,
sessions: make(map[uint64]map[uint64]io.Writer),
}
}
func (ws *server) Open(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
conn, err := upgrader.Upgrade(w, r, nil)
if _, ok := err.(websocket.HandshakeError); ok {
errors.ProperlyServeHTTP(w, r, errors.Internal("need a websocket handshake"), false)
return
} else if err != nil {
errors.ProperlyServeHTTP(w, r, errors.Internal("failed to upgrade connection").Wrap(err), false)
return
}
// init new session
//
// session will add itself back to server's session map when
// ready (if user authenticates itself)
ses := Session(ctx, ws, conn)
if err = ses.Handle(); err != nil {
ws.logger.Warn("websocket session handler error", zap.Error(err))
}
}
// Send delivers payload to one, more or all users
//
// Omit userIDs to deliver to ALL users
func (ws *server) Send(t string, payload interface{}, userIDs ...uint64) error {
pb, err := MarshalPayload(t, payload)
if err != nil {
return err
}
var (
sendToAll = len(userIDs) == 0
uMap = slice.ToUint64BoolMap(userIDs)
)
ws.l.RLock()
defer ws.l.RUnlock()
for uid := range ws.sessions {
if sendToAll || (!sendToAll && uMap[uid]) {
for _, s := range ws.sessions[uid] {
_, err = s.Write(pb)
}
}
}
return nil
}
func (ws *server) StoreSession(s *session) {
ws.l.Lock()
defer ws.l.Unlock()
if s.identity != nil {
ws.storeSession(s, s.identity.Identity(), s.id)
}
}
func (ws *server) storeSession(w io.Writer, uid, sid uint64) {
if ws.sessions[uid] == nil {
ws.sessions[uid] = make(map[uint64]io.Writer)
}
ws.sessions[uid][sid] = w
}
func (ws *server) RemoveSession(s *session) {
ws.l.Lock()
defer ws.l.Unlock()
if s.identity != nil {
uid := s.identity.Identity()
delete(ws.sessions[uid], s.id)
if len(ws.sessions[uid]) == 0 {
delete(ws.sessions, uid)
}
}
}

View File

@ -0,0 +1,63 @@
package websocket
import (
"bytes"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"testing"
)
func TestWebsocketSend_NoSessions(t *testing.T) {
var (
req = require.New(t)
ws = Server(zap.NewNop(), options.WebsocketOpt{})
)
req.NoError(ws.Send("msg", "msg"))
req.NoError(ws.Send("msg", "msg", 1))
req.NoError(ws.Send("msg", "msg", 1, 2))
req.NoError(ws.Send("msg", "msg", 1, 2, 3))
}
func TestWebsocketSend_ExistingSessions(t *testing.T) {
var (
req = require.New(t)
ws = Server(zap.NewNop(), options.WebsocketOpt{})
s1User uint64 = 100
s1ID uint64 = 101
s2User uint64 = 200
s2ID uint64 = 201
s1 = &bytes.Buffer{}
s2 = &bytes.Buffer{}
)
ws.storeSession(s1, s1User, s1ID)
ws.storeSession(s2, s2User, s2ID)
req.Empty(s1)
req.Empty(s2)
req.NoError(ws.Send("msg", "msg", 0))
req.Empty(s1)
req.Empty(s2)
req.NoError(ws.Send("msg", "msg1", s1User))
req.Equal(s1.String(), `{"msg":"msg1"}`)
req.Equal(s2.String(), "")
req.NoError(ws.Send("msg", "msg2", s2User))
req.Equal(s1.String(), `{"msg":"msg1"}`)
req.Equal(s2.String(), `{"msg":"msg2"}`)
req.NoError(ws.Send("both", "msg3", s1User, s2User))
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}`)
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}`)
req.NoError(ws.Send("all", "msg4"))
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}{"all":"msg4"}`)
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}{"all":"msg4"}`)
}

340
pkg/websocket/session.go Normal file
View File

@ -0,0 +1,340 @@
package websocket
import (
"context"
"encoding/json"
"fmt"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/errors"
"github.com/cortezaproject/corteza-server/pkg/id"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"net"
"sync"
"time"
)
// active sessions of users
var (
// wrapper around nextID that will aid service testing
nextID = func() uint64 {
return id.Next()
}
)
type (
session struct {
id uint64
once sync.Once
conn *websocket.Conn
ctx context.Context
ctxCancel context.CancelFunc
logger *zap.Logger
send chan []byte
stop chan []byte
remoteAddr string
config options.WebsocketOpt
identity auth.Identifiable
server *server
}
)
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
s := &session{
id: nextID(),
conn: conn,
config: ws.config,
send: make(chan []byte, 512),
stop: make(chan []byte, 1),
server: ws,
}
s.ctx, s.ctxCancel = context.WithCancel(ctx)
s.logger = ws.logger.
Named("session").
With(
zap.Uint64("id", s.id),
)
return s
}
func (s *session) connected() (err error) {
s.logger.Info("connected", zap.String("remoteAddr", s.conn.RemoteAddr().String()))
//// Tell everyone that user has connected
//if err = s.sendPresence("connected"); err != nil {
// return
//}
//
//
//// Create a heartbeat every minute for this user
//go func() {
// defer sentry.Recover()
//
// t := time.NewTicker(time.Second * 60)
// for {
// select {
// case <-s.ctx.Done():
// return
// case <-t.C:
// _ = s.sendPresence("active")
// }
// }
//}()
return nil
}
func (s *session) disconnected() {
// Cancel context
s.ctxCancel()
s.logger.Info("disconnected")
// Close connection
_ = s.conn.Close()
s.conn = nil
}
//func (s *session) sendPresence(_ string) error {
// return nil
//}
func (s *session) Handle() (err error) {
if err = s.connected(); err != nil {
s.Close()
return
}
go func() {
// Close unidentified connections in 5sec
<-time.NewTimer(time.Second * 5).C
if s.identity == nil {
s.Write([]byte(closingUnidentifiedConn))
s.logger.Info("closing unidentified connection")
s.Close()
}
}()
go func() {
if err = s.readLoop(); err != nil {
s.logger.Error("read failure", zap.Error(err))
}
s.Close()
}()
if err = s.writeLoop(); err != nil {
s.logger.Error("write failure", zap.Error(err))
}
return
}
func (s *session) Close() {
s.once.Do(func() {
s.disconnected()
s.server.RemoveSession(s)
})
}
func (s *session) readLoop() (err error) {
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
return
}
s.conn.SetPongHandler(func(string) error {
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
})
s.remoteAddr = s.conn.RemoteAddr().String()
var (
raw []byte
)
for {
if s.conn == nil {
return nil
}
if _, raw, err = s.conn.ReadMessage(); err != nil {
return errHandler("read failed", err)
}
if err = s.procRawMessage(raw); err != nil {
return
}
}
}
func (s *session) procRawMessage(raw []byte) (err error) {
pw := payloadWrap{}
if err = json.Unmarshal(raw, &pw); err != nil {
return fmt.Errorf("could not unmarshal session message: %w", err)
}
if pw.Type == payloadTypeCredentials {
authPayload := &payloadAuth{}
if err = pw.UnmarshalValue(authPayload); err != nil {
return fmt.Errorf("could not unmarshal session payload: %w", err)
}
if err = s.authenticate(authPayload); err != nil {
return fmt.Errorf("unauthorized: %w", err)
}
s.logger.Debug(
"authenticated",
zap.Uint64("userID", s.identity.Identity()),
zap.Uint64s("roles", s.identity.Roles()),
)
s.server.StoreSession(s)
// not expecting anything else
return
}
if s.identity == nil {
return fmt.Errorf("unauthenticated session")
}
// at the moment we do not support any other kinds of message types
return fmt.Errorf("unknown message type '%s'", pw.Type)
}
func (s *session) writeLoop() error {
ticker := time.NewTicker(s.config.PingPeriod)
defer func() {
ticker.Stop()
s.Close() // break readLoop
}()
write := func(msg []byte) (err error) {
if s.conn == nil {
// Connection closed, nowhere to write
return
}
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return fmt.Errorf("deadline error: %w", err)
}
if msg != nil && s.conn != nil {
return s.conn.WriteMessage(websocket.TextMessage, msg)
}
return
}
ping := func() (err error) {
if s.conn == nil {
// Connection closed, nothing to ping
return
}
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return
}
if s.conn != nil {
return s.conn.WriteMessage(websocket.PingMessage, nil)
}
return
}
for {
if s.conn == nil {
return nil
}
select {
case msg, ok := <-s.send:
if !ok {
// channel closed
return nil
}
if err := errHandler("send failed", write(msg)); err != nil {
return err
}
case msg := <-s.stop:
// Shutdown requested, don't care if the message is delivered
_ = write(msg)
return nil
case <-ticker.C:
if err := ping(); err != nil {
return errHandler("ping failed", err)
}
}
}
}
func (s *session) authenticate(p *payloadAuth) error {
claims, err := s.server.accessToken.Authenticate(p.AccessToken)
if err != nil {
return err
}
if !auth.CheckScope(claims["scope"], "api") {
return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)")
}
// Get identity using JWT claims
identity := auth.ClaimsToIdentity(claims)
if s.identity != nil {
if s.identity.Identity() != identity.Identity() {
return fmt.Errorf("identity does not match")
}
}
if !identity.Valid() {
return fmt.Errorf("invalid identity")
}
s.identity = identity
s.Write([]byte(ok))
return nil
}
// sendBytes sends byte to channel or timeout
func (s *session) Write(p []byte) (int, error) {
select {
case s.send <- p:
return len(p), nil
case <-time.After(2 * time.Millisecond):
return 0, fmt.Errorf("write timedout")
}
}
func errHandler(wrap string, err error) error {
if err == nil {
return nil
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
// normal closing
return nil
}
if errors.Is(err, net.ErrClosed) {
// suppress errors when reading/writing from/to a closed connection
return nil
}
return fmt.Errorf(wrap+": %w", err)
}

View File

@ -0,0 +1,53 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestSession_procRawMessage(t *testing.T) {
var (
req = require.New(t)
s = session{server: Server(nil, options.WebsocketOpt{})}
jwtHandler, err = auth.JWT("secret", time.Minute)
userID uint64 = 123
)
req.NoError(err)
s.server.accessToken = jwtHandler
jwt := jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
req.Nil(s.identity)
req.EqualError(s.procRawMessage([]byte(`{"auth":{}}`)), "unauthorized: token contains an invalid number of segments")
req.Nil(s.identity)
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": ""}}`)), "unauthorized: token contains an invalid number of segments")
req.Nil(s.identity)
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
req.NotNil(s.identity)
req.Equal(userID, s.identity.Identity())
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
req.Equal(userID, s.identity.Identity())
// Repeat with the same user
jwt = jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
req.NotNil(s.identity)
req.Equal(userID, s.identity.Identity())
// Try to authenticate on an existing authenticated session as a different user
jwt = jwtHandler.Encode(auth.NewIdentity(userID+1, 456, 789))
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": "`+jwt+`"}}`)), "unauthorized: identity does not match")
}

View File

@ -29,6 +29,11 @@ type (
Payload *expr.Vars `json:"payload"`
OwnerId uint64 `json:"-"`
}
ResumedPrompt struct {
StateID uint64 `json:"stateID,string"`
OwnerId uint64 `json:"-"`
}
)
func Prompt(ownerId uint64, ref string, payload *expr.Vars) *prompted {
@ -44,3 +49,10 @@ func (p *prompted) toPending() *PendingPrompt {
OwnerId: p.ownerId,
}
}
func (p *prompted) toResumed() *ResumedPrompt {
return &ResumedPrompt{
StateID: p.state.stateId,
OwnerId: p.ownerId,
}
}

View File

@ -288,7 +288,7 @@ func (s *Session) AllPendingPrompts() (out []*PendingPrompt) {
return
}
func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) error {
func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) (*ResumedPrompt, error) {
defer s.mux.Unlock()
s.mux.Lock()
@ -297,11 +297,11 @@ func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars)
p, has = s.prompted[stateId]
)
if !has {
return fmt.Errorf("unexisting state")
return nil, fmt.Errorf("unexisting state")
}
if i == nil || p.ownerId != i.Identity() {
return fmt.Errorf("state access denied")
return nil, fmt.Errorf("state access denied")
}
delete(s.prompted, stateId)
@ -309,7 +309,11 @@ func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars)
// setting received input to state
p.state.input = input
return s.enqueue(ctx, p.state)
if err := s.enqueue(ctx, p.state); err != nil {
return nil, err
}
return p.toResumed(), nil
}
func (s *Session) enqueue(ctx context.Context, st *State) error {

View File

@ -278,7 +278,10 @@ func (svc reminder) Delete(ctx context.Context, ID uint64) (err error) {
func (svc reminder) Watch(ctx context.Context) {
if svc.reminderSender != nil {
rTicker := time.NewTicker(time.Second)
var (
interval = time.Second
rTicker = time.NewTicker(interval)
)
go func() {
defer sentry.Recover()
@ -289,26 +292,21 @@ func (svc reminder) Watch(ctx context.Context) {
select {
case <-ctx.Done():
return
case t := <-rTicker.C:
case <-rTicker.C:
// Get scheduled reminders of users
rr, _, err := svc.Find(ctx, types.ReminderFilter{
ExcludeDismissed: true,
ScheduledOnly: true,
})
if err != nil {
svc.log.Error("failed to get reminders of users", zap.Error(err))
}
// sendReminderNow checks time is equal to current time or not
sendReminderNow := func(tt time.Time) bool {
timeLayout := time.RFC3339
return tt.Format(timeLayout) == t.Format(timeLayout)
}
// Send scheduled reminders to users
_ = rr.Walk(func(r *types.Reminder) error {
if r.RemindAt != nil && sendReminderNow(*r.RemindAt) {
if err := svc.reminderSender.Send("ok", r, r.AssignedTo); err != nil {
if r.RemindAt != nil && now().Round(interval) == r.RemindAt.Round(interval) {
if err := svc.reminderSender.Send("reminder", r, r.AssignedTo); err != nil {
svc.log.Error("failed to send reminder to user", zap.Error(err))
}
}

View File

@ -1,309 +0,0 @@
package websocket
import (
"context"
"github.com/cortezaproject/corteza-server/pkg/id"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/getsentry/sentry-go"
"github.com/pkg/errors"
"go.uber.org/zap/zapcore"
"sync"
"time"
gWebsocket "github.com/gorilla/websocket"
"go.uber.org/zap"
"github.com/cortezaproject/corteza-server/pkg/auth"
)
// active sessions of users
var sessions = make(map[uint64][]*session)
type (
session struct {
id uint64
once sync.Once
conn *gWebsocket.Conn
ctx context.Context
ctxCancel context.CancelFunc
logger *zap.Logger
send chan []byte
stop chan []byte
remoteAddr string
config options.WebsocketOpt
user auth.Identifiable
}
)
func Session(ctx context.Context, logger *zap.Logger, config options.WebsocketOpt, conn *gWebsocket.Conn) *session {
s := &session{
conn: conn,
config: config,
send: make(chan []byte, 512),
stop: make(chan []byte, 1),
}
s.ctx, s.ctxCancel = context.WithCancel(ctx)
s.logger = logger
return s
}
func (s *session) log(fields ...zapcore.Field) *zap.Logger {
return s.logger.With(fields...)
}
func (s *session) Context() context.Context {
return s.ctx
}
func (s *session) User() auth.Identifiable {
return s.user
}
func (s *session) connected() (err error) {
// Tell everyone that user has connected
if err = s.sendPresence("connected"); err != nil {
return
}
// Create a heartbeat every minute for this user
go func() {
defer sentry.Recover()
t := time.NewTicker(time.Second * 60)
for {
select {
case <-s.ctx.Done():
return
case <-t.C:
_ = s.sendPresence("")
}
}
}()
return nil
}
func (s *session) disconnected() {
// Tell everyone that user has disconnected
_ = s.sendPresence("disconnected")
// Cancel context
s.ctxCancel()
// Close connection
_ = s.conn.Close()
s.conn = nil
}
// sendPresence sends user presence: "connected", "disconnected" and "" activity kinds
func (s *session) sendPresence(kind string) error {
//connections := store.CountConnections(s.user.Identity())
//if kind == "disconnected" {
// connections--
//}
return nil
}
func (s *session) Handle() (err error) {
if err = s.connected(); err != nil {
s.Close()
return
}
go func() {
_ = s.readLoop()
}()
return s.writeLoop()
}
func (s *session) Close() {
s.once.Do(func() {
s.disconnected()
_ = s.Delete()
})
}
func (s *session) readLoop() (err error) {
defer func() {
s.Close()
}()
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
return
}
s.conn.SetPongHandler(func(string) error {
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
})
s.remoteAddr = s.conn.RemoteAddr().String()
for {
_, raw, err := s.conn.ReadMessage()
if err != nil {
return errors.Wrap(err, "s.readLoop")
}
if err = s.dispatch(raw); err != nil {
s.log(zap.Error(err)).Error("could not dispatch")
//_ = s.send(outgoing.NewError(err))
}
}
}
func (s *session) writeLoop() error {
ticker := time.NewTicker(s.config.PingPeriod)
defer func() {
ticker.Stop()
s.Close() // break readLoop
}()
write := func(msg []byte) (err error) {
if s.conn == nil {
// Connection closed, nowhere to write
return
}
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return
}
if msg != nil && s.conn != nil {
return s.conn.WriteMessage(gWebsocket.TextMessage, msg)
}
return
}
ping := func() (err error) {
if s.conn == nil {
// Connection closed, nothing to ping
return
}
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return
}
if s.conn != nil {
return s.conn.WriteMessage(gWebsocket.PingMessage, nil)
}
return
}
for {
select {
case msg, ok := <-s.send:
if !ok {
// channel closed
return nil
}
if err := write(msg); err != nil {
return errors.Wrap(err, "writeLoop send")
}
case msg := <-s.stop:
// Shutdown requested, don't care if the message is delivered
_ = write(msg)
return nil
case <-ticker.C:
if err := ping(); err != nil {
return errors.Wrap(err, "writeLoop ping")
}
}
}
}
func (s *session) dispatch(raw []byte) error {
var p, err = Unmarshal(raw)
if err != nil {
return errors.Wrap(err, "Session.incoming: payload malformed")
}
if p.Auth != nil {
return s.authenticate(p.Auth)
}
return nil
}
func (s *session) authenticate(p *Auth) error {
// Get JWT claims
claims, err := p.ParseWithClaims()
if err != nil {
s.Close()
return err
}
// Get identity using JWT claims
identity := auth.ClaimsToIdentity(claims)
s.Save(identity)
return nil
}
func (s *session) Save(identity *auth.Identity) *session {
if s.id == 0 {
s.id = id.Next()
}
if identity != nil {
userID := identity.Identity()
existingSessions, ok := sessions[userID]
// Add sessions for user
if !ok {
s.user = identity
sessions[userID] = append(sessions[userID], s)
}
// Update the identity in existing sessions
for _, sess := range existingSessions {
sess.user = identity
}
}
return s
}
func (s *session) Get(userID uint64) []*session {
if sess, ok := sessions[userID]; ok {
return sess
}
return nil
}
func (s *session) Delete() error {
if s.id == 0 {
return nil
}
if s.user != nil {
delete(sessions, s.user.Identity())
}
return nil
}
// sendBytes sends byte to channel or timout
func (s *session) sendBytes(p []byte) error {
select {
case s.send <- p:
case <-time.After(2 * time.Millisecond):
s.logger.Warn("websocket.sendBytes send timeout")
}
return nil
}

View File

@ -1,55 +0,0 @@
package websocket
import (
"encoding/json"
"github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
)
type (
// Auth is JWT token provided by client as first message,
// and will be passed whenever it changes
Auth struct {
AccessToken *string `json:"access_token"`
}
// payload for incoming messages from user
payload struct {
*Auth `json:"auth"`
}
// response for sending messages to user
response struct {
Status string `json:"status"`
Data interface{} `json:"data"`
}
)
func (a *Auth) ParseWithClaims() (jwt.MapClaims, error) {
token, err := jwt.Parse(*a.AccessToken, nil)
if token == nil {
return nil, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if ok {
return claims, nil
} else {
return nil, errors.New("Invalid token")
}
}
func Unmarshal(raw []byte) (*payload, error) {
var p payload
return &p, json.Unmarshal(raw, &p)
}
func Response(status string, data interface{}) *response {
return &response{
Status: status,
Data: data,
}
}
func (m response) Marshal() ([]byte, error) {
return json.Marshal(m)
}

View File

@ -1,87 +0,0 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/api"
"github.com/cortezaproject/corteza-server/pkg/options"
gWebsocket "github.com/gorilla/websocket"
"github.com/pkg/errors"
"go.uber.org/zap"
"net/http"
)
var (
// upgrader handles websocket requests from peers
upgrader = gWebsocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// Allow connections from any Origin
CheckOrigin: func(r *http.Request) bool { return true },
}
)
type (
websocket struct {
config options.WebsocketOpt
logger *zap.Logger
}
)
func Websocket(logger *zap.Logger, config options.WebsocketOpt) *websocket {
if !config.LogEnabled {
logger = zap.NewNop()
}
return &websocket{
config: config,
logger: logger,
}
}
func (ws *websocket) Open(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
conn, err := upgrader.Upgrade(w, r, nil)
if _, ok := err.(gWebsocket.HandshakeError); ok {
ws.logger.Error("ws: need a websocket handshake")
api.Send(w, r, errors.Wrap(err, "ws: need a websocket handshake"))
return
} else if err != nil {
ws.logger.Error("ws: failed to upgrade connection")
api.Send(w, r, errors.Wrap(err, "ws: failed to upgrade connection"))
return
}
session := Session(ctx, ws.logger, ws.config, conn)
if err := session.Handle(); err != nil {
ws.logger.
WithOptions(zap.AddStacktrace(zap.PanicLevel)).
Warn("websocket session handler error", zap.Error(err))
}
}
// Send delivers message to user to ones we want to
// if len(userIDs) == 0 -- it delivers to everyone
func (ws *websocket) Send(kind string, payload interface{}, userIDs ...uint64) error {
pb, err := Response(kind, payload).Marshal()
if err != nil {
return err
}
sendsToAll := len(userIDs) == 0
userIDMap := make(map[uint64]bool)
for _, userID := range userIDs {
userIDMap[userID] = true
}
for uid, uSessions := range sessions {
if sendsToAll || (!sendsToAll && userIDMap[uid]) {
for _, sess := range uSessions {
_ = sess.sendBytes(pb)
}
}
}
return nil
}

View File

@ -1,127 +0,0 @@
package websocket
import (
"context"
"fmt"
"github.com/cortezaproject/corteza-server/pkg/options"
gWebsocket "github.com/gorilla/websocket"
"go.uber.org/zap"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// WebsocketServer provide websocket server for testing
func WebsocketTestServer(t *testing.T) (*httptest.Server, *gWebsocket.Conn) {
// Create test server with the websocket handler.
s := httptest.NewServer(http.HandlerFunc(wsOpen))
// Convert http://.. to ws://..
u := "ws" + strings.TrimPrefix(s.URL, "http")
// Connect to the server
conn, _, err := gWebsocket.DefaultDialer.Dial(u, nil)
if err != nil {
t.Fatalf("WebsocketServer() error while creating connection = %v", err)
}
return s, conn
}
// wsOpen opens websocket connection
func wsOpen(w http.ResponseWriter, r *http.Request) {
var gUpgrader = gWebsocket.Upgrader{}
c, err := gUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer func(c *gWebsocket.Conn) {
_ = c.Close()
}(c)
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}
func TestSendingMessageToUser(t *testing.T) {
tests := []struct {
name string
kind string
payload interface{}
expectedP string
}{
{
name: "send json",
kind: "Json",
payload: struct {
Title string `json:"title"`
Description string `json:"description"`
}{Title: "Websocket", Description: "Testing connection.."},
expectedP: `{"status":"ok","data":{"title":"Websocket","description":"Testing connection.."}}`,
},
{
name: "send text",
kind: "Text",
payload: "testing connectivity",
expectedP: "testing connectivity",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var logger *zap.Logger
var config options.WebsocketOpt
ws := Websocket(logger, config)
s, conn := WebsocketTestServer(t)
defer s.Close()
defer func(ws *gWebsocket.Conn) {
err := ws.Close()
if err != nil {
t.Fatalf("TestSendingMessageToUser() error closing connection = %v", err)
}
}(conn)
// Open a session using ws connection
wsSession := Session(context.Background(), ws.logger, ws.config, conn)
var messageType int
var data []byte
switch tt.kind {
case "Text":
messageType = gWebsocket.TextMessage
data = []byte(fmt.Sprintf("%v", tt.payload))
case "Json":
res := Response("ok", tt.payload)
messageType = gWebsocket.BinaryMessage
var err error
data, err = res.Marshal()
if err != nil {
t.Fatalf("TestSendingMessageToUser() error while marshaling payload = %v", err)
}
}
// Send message to server, read response and check to see if it's what we expect.
if err := wsSession.conn.WriteMessage(messageType, data); err != nil {
t.Fatalf("TestSendingMessageToUser() error while sending message = %v", err)
}
_, p, err := wsSession.conn.ReadMessage()
if err != nil {
t.Fatalf("TestSendingMessageToUser() error while reading message =%v", err)
}
if string(p) != tt.expectedP {
t.Fatalf("TestSendingMessageToUser() gotP = %v, want = %v", string(p), tt.expectedP)
}
})
}
}