Refactor & move websocket code under pkg
This commit is contained in:
44
pkg/websocket/payload.go
Normal file
44
pkg/websocket/payload.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type (
|
||||
// Auth is JWT token provided by client as first message,
|
||||
// and will be passed whenever it changes
|
||||
payloadAuth struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
|
||||
payloadWrap struct {
|
||||
Type string `json:"@type"`
|
||||
Value json.RawMessage `json:"@value"`
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
payloadTypeCredentials = "credentials"
|
||||
)
|
||||
|
||||
var (
|
||||
closingUnidentifiedConn, _ = MarshalPayload("error", "closing unidentified connection")
|
||||
ok, _ = MarshalPayload("message", "authenticated")
|
||||
)
|
||||
|
||||
func (p payloadWrap) UnmarshalValue(m interface{}) error {
|
||||
return json.Unmarshal(p.Value, m)
|
||||
}
|
||||
|
||||
func MarshalPayload(t string, m interface{}) ([]byte, error) {
|
||||
var (
|
||||
err error
|
||||
w = payloadWrap{Type: t}
|
||||
)
|
||||
|
||||
if w.Value, err = json.Marshal(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(w)
|
||||
}
|
||||
14
pkg/websocket/router.go
Normal file
14
pkg/websocket/router.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi"
|
||||
)
|
||||
|
||||
// MountRoutes initialize route for websocket
|
||||
// No middleware used, since anyone can open connection and
|
||||
// send first message with valid JWT token,
|
||||
// If it's valid then we keep the connection open or close it
|
||||
func (ws *server) MountRoutes(r chi.Router) {
|
||||
// Initialize handlers & controllers.
|
||||
r.Get("/", ws.Open)
|
||||
}
|
||||
136
pkg/websocket/server.go
Normal file
136
pkg/websocket/server.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
"github.com/cortezaproject/corteza-server/pkg/errors"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/cortezaproject/corteza-server/pkg/slice"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
// upgrader handles websocket requests from peers
|
||||
upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
|
||||
// Allow connections from any Origin
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
server struct {
|
||||
config options.WebsocketOpt
|
||||
logger *zap.Logger
|
||||
|
||||
// user id => session id => session
|
||||
sessions map[uint64]map[uint64]io.Writer
|
||||
|
||||
accessToken interface {
|
||||
Authenticate(string) (jwt.MapClaims, error)
|
||||
}
|
||||
|
||||
// keep lock on session map changes
|
||||
l sync.RWMutex
|
||||
}
|
||||
)
|
||||
|
||||
func Server(logger *zap.Logger, config options.WebsocketOpt) *server {
|
||||
if !config.LogEnabled {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &server{
|
||||
config: config,
|
||||
logger: logger.WithOptions(zap.AddStacktrace(zap.PanicLevel)).Named("websocket"),
|
||||
accessToken: auth.DefaultJwtHandler,
|
||||
sessions: make(map[uint64]map[uint64]io.Writer),
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *server) Open(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if _, ok := err.(websocket.HandshakeError); ok {
|
||||
errors.ProperlyServeHTTP(w, r, errors.Internal("need a websocket handshake"), false)
|
||||
return
|
||||
} else if err != nil {
|
||||
errors.ProperlyServeHTTP(w, r, errors.Internal("failed to upgrade connection").Wrap(err), false)
|
||||
return
|
||||
}
|
||||
|
||||
// init new session
|
||||
//
|
||||
// session will add itself back to server's session map when
|
||||
// ready (if user authenticates itself)
|
||||
ses := Session(ctx, ws, conn)
|
||||
|
||||
if err = ses.Handle(); err != nil {
|
||||
ws.logger.Warn("websocket session handler error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Send delivers payload to one, more or all users
|
||||
//
|
||||
// Omit userIDs to deliver to ALL users
|
||||
func (ws *server) Send(t string, payload interface{}, userIDs ...uint64) error {
|
||||
pb, err := MarshalPayload(t, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
sendToAll = len(userIDs) == 0
|
||||
uMap = slice.ToUint64BoolMap(userIDs)
|
||||
)
|
||||
|
||||
ws.l.RLock()
|
||||
defer ws.l.RUnlock()
|
||||
|
||||
for uid := range ws.sessions {
|
||||
if sendToAll || (!sendToAll && uMap[uid]) {
|
||||
for _, s := range ws.sessions[uid] {
|
||||
_, err = s.Write(pb)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *server) StoreSession(s *session) {
|
||||
ws.l.Lock()
|
||||
defer ws.l.Unlock()
|
||||
if s.identity != nil {
|
||||
ws.storeSession(s, s.identity.Identity(), s.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *server) storeSession(w io.Writer, uid, sid uint64) {
|
||||
if ws.sessions[uid] == nil {
|
||||
ws.sessions[uid] = make(map[uint64]io.Writer)
|
||||
|
||||
}
|
||||
|
||||
ws.sessions[uid][sid] = w
|
||||
}
|
||||
|
||||
func (ws *server) RemoveSession(s *session) {
|
||||
ws.l.Lock()
|
||||
defer ws.l.Unlock()
|
||||
if s.identity != nil {
|
||||
uid := s.identity.Identity()
|
||||
delete(ws.sessions[uid], s.id)
|
||||
|
||||
if len(ws.sessions[uid]) == 0 {
|
||||
delete(ws.sessions, uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
63
pkg/websocket/server_test.go
Normal file
63
pkg/websocket/server_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWebsocketSend_NoSessions(t *testing.T) {
|
||||
var (
|
||||
req = require.New(t)
|
||||
ws = Server(zap.NewNop(), options.WebsocketOpt{})
|
||||
)
|
||||
|
||||
req.NoError(ws.Send("msg", "msg"))
|
||||
req.NoError(ws.Send("msg", "msg", 1))
|
||||
req.NoError(ws.Send("msg", "msg", 1, 2))
|
||||
req.NoError(ws.Send("msg", "msg", 1, 2, 3))
|
||||
}
|
||||
|
||||
func TestWebsocketSend_ExistingSessions(t *testing.T) {
|
||||
var (
|
||||
req = require.New(t)
|
||||
ws = Server(zap.NewNop(), options.WebsocketOpt{})
|
||||
|
||||
s1User uint64 = 100
|
||||
s1ID uint64 = 101
|
||||
|
||||
s2User uint64 = 200
|
||||
s2ID uint64 = 201
|
||||
|
||||
s1 = &bytes.Buffer{}
|
||||
s2 = &bytes.Buffer{}
|
||||
)
|
||||
|
||||
ws.storeSession(s1, s1User, s1ID)
|
||||
ws.storeSession(s2, s2User, s2ID)
|
||||
|
||||
req.Empty(s1)
|
||||
req.Empty(s2)
|
||||
|
||||
req.NoError(ws.Send("msg", "msg", 0))
|
||||
req.Empty(s1)
|
||||
req.Empty(s2)
|
||||
|
||||
req.NoError(ws.Send("msg", "msg1", s1User))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}`)
|
||||
req.Equal(s2.String(), "")
|
||||
|
||||
req.NoError(ws.Send("msg", "msg2", s2User))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}`)
|
||||
req.Equal(s2.String(), `{"msg":"msg2"}`)
|
||||
|
||||
req.NoError(ws.Send("both", "msg3", s1User, s2User))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}`)
|
||||
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}`)
|
||||
|
||||
req.NoError(ws.Send("all", "msg4"))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}{"all":"msg4"}`)
|
||||
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}{"all":"msg4"}`)
|
||||
}
|
||||
340
pkg/websocket/session.go
Normal file
340
pkg/websocket/session.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
"github.com/cortezaproject/corteza-server/pkg/errors"
|
||||
"github.com/cortezaproject/corteza-server/pkg/id"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// active sessions of users
|
||||
var (
|
||||
// wrapper around nextID that will aid service testing
|
||||
nextID = func() uint64 {
|
||||
return id.Next()
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
session struct {
|
||||
id uint64
|
||||
once sync.Once
|
||||
conn *websocket.Conn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
logger *zap.Logger
|
||||
|
||||
send chan []byte
|
||||
stop chan []byte
|
||||
|
||||
remoteAddr string
|
||||
|
||||
config options.WebsocketOpt
|
||||
|
||||
identity auth.Identifiable
|
||||
|
||||
server *server
|
||||
}
|
||||
)
|
||||
|
||||
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
|
||||
s := &session{
|
||||
id: nextID(),
|
||||
conn: conn,
|
||||
config: ws.config,
|
||||
send: make(chan []byte, 512),
|
||||
stop: make(chan []byte, 1),
|
||||
server: ws,
|
||||
}
|
||||
|
||||
s.ctx, s.ctxCancel = context.WithCancel(ctx)
|
||||
|
||||
s.logger = ws.logger.
|
||||
Named("session").
|
||||
With(
|
||||
zap.Uint64("id", s.id),
|
||||
)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) connected() (err error) {
|
||||
s.logger.Info("connected", zap.String("remoteAddr", s.conn.RemoteAddr().String()))
|
||||
|
||||
//// Tell everyone that user has connected
|
||||
//if err = s.sendPresence("connected"); err != nil {
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//
|
||||
//// Create a heartbeat every minute for this user
|
||||
//go func() {
|
||||
// defer sentry.Recover()
|
||||
//
|
||||
// t := time.NewTicker(time.Second * 60)
|
||||
// for {
|
||||
// select {
|
||||
// case <-s.ctx.Done():
|
||||
// return
|
||||
// case <-t.C:
|
||||
// _ = s.sendPresence("active")
|
||||
// }
|
||||
// }
|
||||
//}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) disconnected() {
|
||||
// Cancel context
|
||||
s.ctxCancel()
|
||||
|
||||
s.logger.Info("disconnected")
|
||||
|
||||
// Close connection
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
|
||||
//func (s *session) sendPresence(_ string) error {
|
||||
// return nil
|
||||
//}
|
||||
|
||||
func (s *session) Handle() (err error) {
|
||||
if err = s.connected(); err != nil {
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Close unidentified connections in 5sec
|
||||
<-time.NewTimer(time.Second * 5).C
|
||||
if s.identity == nil {
|
||||
s.Write([]byte(closingUnidentifiedConn))
|
||||
s.logger.Info("closing unidentified connection")
|
||||
s.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err = s.readLoop(); err != nil {
|
||||
s.logger.Error("read failure", zap.Error(err))
|
||||
}
|
||||
s.Close()
|
||||
}()
|
||||
|
||||
if err = s.writeLoop(); err != nil {
|
||||
s.logger.Error("write failure", zap.Error(err))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *session) Close() {
|
||||
s.once.Do(func() {
|
||||
s.disconnected()
|
||||
s.server.RemoveSession(s)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *session) readLoop() (err error) {
|
||||
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conn.SetPongHandler(func(string) error {
|
||||
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
||||
})
|
||||
|
||||
s.remoteAddr = s.conn.RemoteAddr().String()
|
||||
|
||||
var (
|
||||
raw []byte
|
||||
)
|
||||
|
||||
for {
|
||||
if s.conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, raw, err = s.conn.ReadMessage(); err != nil {
|
||||
return errHandler("read failed", err)
|
||||
}
|
||||
|
||||
if err = s.procRawMessage(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) procRawMessage(raw []byte) (err error) {
|
||||
pw := payloadWrap{}
|
||||
if err = json.Unmarshal(raw, &pw); err != nil {
|
||||
return fmt.Errorf("could not unmarshal session message: %w", err)
|
||||
}
|
||||
|
||||
if pw.Type == payloadTypeCredentials {
|
||||
authPayload := &payloadAuth{}
|
||||
if err = pw.UnmarshalValue(authPayload); err != nil {
|
||||
return fmt.Errorf("could not unmarshal session payload: %w", err)
|
||||
}
|
||||
|
||||
if err = s.authenticate(authPayload); err != nil {
|
||||
return fmt.Errorf("unauthorized: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug(
|
||||
"authenticated",
|
||||
zap.Uint64("userID", s.identity.Identity()),
|
||||
zap.Uint64s("roles", s.identity.Roles()),
|
||||
)
|
||||
|
||||
s.server.StoreSession(s)
|
||||
|
||||
// not expecting anything else
|
||||
return
|
||||
}
|
||||
|
||||
if s.identity == nil {
|
||||
return fmt.Errorf("unauthenticated session")
|
||||
}
|
||||
|
||||
// at the moment we do not support any other kinds of message types
|
||||
return fmt.Errorf("unknown message type '%s'", pw.Type)
|
||||
}
|
||||
|
||||
func (s *session) writeLoop() error {
|
||||
ticker := time.NewTicker(s.config.PingPeriod)
|
||||
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
s.Close() // break readLoop
|
||||
}()
|
||||
|
||||
write := func(msg []byte) (err error) {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nowhere to write
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return fmt.Errorf("deadline error: %w", err)
|
||||
}
|
||||
|
||||
if msg != nil && s.conn != nil {
|
||||
return s.conn.WriteMessage(websocket.TextMessage, msg)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ping := func() (err error) {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nothing to ping
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.conn != nil {
|
||||
return s.conn.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if s.conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case msg, ok := <-s.send:
|
||||
if !ok {
|
||||
// channel closed
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := errHandler("send failed", write(msg)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case msg := <-s.stop:
|
||||
// Shutdown requested, don't care if the message is delivered
|
||||
_ = write(msg)
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
if err := ping(); err != nil {
|
||||
return errHandler("ping failed", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) authenticate(p *payloadAuth) error {
|
||||
claims, err := s.server.accessToken.Authenticate(p.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !auth.CheckScope(claims["scope"], "api") {
|
||||
return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)")
|
||||
}
|
||||
|
||||
// Get identity using JWT claims
|
||||
identity := auth.ClaimsToIdentity(claims)
|
||||
|
||||
if s.identity != nil {
|
||||
if s.identity.Identity() != identity.Identity() {
|
||||
return fmt.Errorf("identity does not match")
|
||||
}
|
||||
}
|
||||
|
||||
if !identity.Valid() {
|
||||
return fmt.Errorf("invalid identity")
|
||||
}
|
||||
|
||||
s.identity = identity
|
||||
s.Write([]byte(ok))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendBytes sends byte to channel or timeout
|
||||
func (s *session) Write(p []byte) (int, error) {
|
||||
select {
|
||||
case s.send <- p:
|
||||
return len(p), nil
|
||||
case <-time.After(2 * time.Millisecond):
|
||||
return 0, fmt.Errorf("write timedout")
|
||||
}
|
||||
}
|
||||
|
||||
func errHandler(wrap string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
// normal closing
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
// suppress errors when reading/writing from/to a closed connection
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(wrap+": %w", err)
|
||||
}
|
||||
53
pkg/websocket/session_test.go
Normal file
53
pkg/websocket/session_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSession_procRawMessage(t *testing.T) {
|
||||
var (
|
||||
req = require.New(t)
|
||||
s = session{server: Server(nil, options.WebsocketOpt{})}
|
||||
jwtHandler, err = auth.JWT("secret", time.Minute)
|
||||
|
||||
userID uint64 = 123
|
||||
)
|
||||
|
||||
req.NoError(err)
|
||||
s.server.accessToken = jwtHandler
|
||||
|
||||
jwt := jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
|
||||
req.Nil(s.identity)
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte(`{"auth":{}}`)), "unauthorized: token contains an invalid number of segments")
|
||||
req.Nil(s.identity)
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": ""}}`)), "unauthorized: token contains an invalid number of segments")
|
||||
req.Nil(s.identity)
|
||||
|
||||
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
|
||||
req.NotNil(s.identity)
|
||||
req.Equal(userID, s.identity.Identity())
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
|
||||
req.Equal(userID, s.identity.Identity())
|
||||
|
||||
// Repeat with the same user
|
||||
jwt = jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
|
||||
|
||||
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
|
||||
req.NotNil(s.identity)
|
||||
req.Equal(userID, s.identity.Identity())
|
||||
|
||||
// Try to authenticate on an existing authenticated session as a different user
|
||||
jwt = jwtHandler.Encode(auth.NewIdentity(userID+1, 456, 789))
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": "`+jwt+`"}}`)), "unauthorized: identity does not match")
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user