3
0

Refactor & move websocket code under pkg

This commit is contained in:
Denis Arh
2021-05-04 14:51:28 +02:00
parent bcf83b25e8
commit 70dedcaaba
16 changed files with 679 additions and 598 deletions

44
pkg/websocket/payload.go Normal file
View File

@@ -0,0 +1,44 @@
package websocket
import (
"encoding/json"
)
type (
// Auth is JWT token provided by client as first message,
// and will be passed whenever it changes
payloadAuth struct {
AccessToken string `json:"accessToken"`
}
payloadWrap struct {
Type string `json:"@type"`
Value json.RawMessage `json:"@value"`
}
)
const (
payloadTypeCredentials = "credentials"
)
var (
closingUnidentifiedConn, _ = MarshalPayload("error", "closing unidentified connection")
ok, _ = MarshalPayload("message", "authenticated")
)
func (p payloadWrap) UnmarshalValue(m interface{}) error {
return json.Unmarshal(p.Value, m)
}
func MarshalPayload(t string, m interface{}) ([]byte, error) {
var (
err error
w = payloadWrap{Type: t}
)
if w.Value, err = json.Marshal(m); err != nil {
return nil, err
}
return json.Marshal(w)
}

14
pkg/websocket/router.go Normal file
View File

@@ -0,0 +1,14 @@
package websocket
import (
"github.com/go-chi/chi"
)
// MountRoutes initialize route for websocket
// No middleware used, since anyone can open connection and
// send first message with valid JWT token,
// If it's valid then we keep the connection open or close it
func (ws *server) MountRoutes(r chi.Router) {
// Initialize handlers & controllers.
r.Get("/", ws.Open)
}

136
pkg/websocket/server.go Normal file
View File

@@ -0,0 +1,136 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/errors"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/cortezaproject/corteza-server/pkg/slice"
"github.com/dgrijalva/jwt-go"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"io"
"net/http"
"sync"
)
var (
// upgrader handles websocket requests from peers
upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// Allow connections from any Origin
CheckOrigin: func(r *http.Request) bool { return true },
}
)
type (
server struct {
config options.WebsocketOpt
logger *zap.Logger
// user id => session id => session
sessions map[uint64]map[uint64]io.Writer
accessToken interface {
Authenticate(string) (jwt.MapClaims, error)
}
// keep lock on session map changes
l sync.RWMutex
}
)
func Server(logger *zap.Logger, config options.WebsocketOpt) *server {
if !config.LogEnabled {
logger = zap.NewNop()
}
return &server{
config: config,
logger: logger.WithOptions(zap.AddStacktrace(zap.PanicLevel)).Named("websocket"),
accessToken: auth.DefaultJwtHandler,
sessions: make(map[uint64]map[uint64]io.Writer),
}
}
func (ws *server) Open(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
conn, err := upgrader.Upgrade(w, r, nil)
if _, ok := err.(websocket.HandshakeError); ok {
errors.ProperlyServeHTTP(w, r, errors.Internal("need a websocket handshake"), false)
return
} else if err != nil {
errors.ProperlyServeHTTP(w, r, errors.Internal("failed to upgrade connection").Wrap(err), false)
return
}
// init new session
//
// session will add itself back to server's session map when
// ready (if user authenticates itself)
ses := Session(ctx, ws, conn)
if err = ses.Handle(); err != nil {
ws.logger.Warn("websocket session handler error", zap.Error(err))
}
}
// Send delivers payload to one, more or all users
//
// Omit userIDs to deliver to ALL users
func (ws *server) Send(t string, payload interface{}, userIDs ...uint64) error {
pb, err := MarshalPayload(t, payload)
if err != nil {
return err
}
var (
sendToAll = len(userIDs) == 0
uMap = slice.ToUint64BoolMap(userIDs)
)
ws.l.RLock()
defer ws.l.RUnlock()
for uid := range ws.sessions {
if sendToAll || (!sendToAll && uMap[uid]) {
for _, s := range ws.sessions[uid] {
_, err = s.Write(pb)
}
}
}
return nil
}
func (ws *server) StoreSession(s *session) {
ws.l.Lock()
defer ws.l.Unlock()
if s.identity != nil {
ws.storeSession(s, s.identity.Identity(), s.id)
}
}
func (ws *server) storeSession(w io.Writer, uid, sid uint64) {
if ws.sessions[uid] == nil {
ws.sessions[uid] = make(map[uint64]io.Writer)
}
ws.sessions[uid][sid] = w
}
func (ws *server) RemoveSession(s *session) {
ws.l.Lock()
defer ws.l.Unlock()
if s.identity != nil {
uid := s.identity.Identity()
delete(ws.sessions[uid], s.id)
if len(ws.sessions[uid]) == 0 {
delete(ws.sessions, uid)
}
}
}

View File

@@ -0,0 +1,63 @@
package websocket
import (
"bytes"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"testing"
)
func TestWebsocketSend_NoSessions(t *testing.T) {
var (
req = require.New(t)
ws = Server(zap.NewNop(), options.WebsocketOpt{})
)
req.NoError(ws.Send("msg", "msg"))
req.NoError(ws.Send("msg", "msg", 1))
req.NoError(ws.Send("msg", "msg", 1, 2))
req.NoError(ws.Send("msg", "msg", 1, 2, 3))
}
func TestWebsocketSend_ExistingSessions(t *testing.T) {
var (
req = require.New(t)
ws = Server(zap.NewNop(), options.WebsocketOpt{})
s1User uint64 = 100
s1ID uint64 = 101
s2User uint64 = 200
s2ID uint64 = 201
s1 = &bytes.Buffer{}
s2 = &bytes.Buffer{}
)
ws.storeSession(s1, s1User, s1ID)
ws.storeSession(s2, s2User, s2ID)
req.Empty(s1)
req.Empty(s2)
req.NoError(ws.Send("msg", "msg", 0))
req.Empty(s1)
req.Empty(s2)
req.NoError(ws.Send("msg", "msg1", s1User))
req.Equal(s1.String(), `{"msg":"msg1"}`)
req.Equal(s2.String(), "")
req.NoError(ws.Send("msg", "msg2", s2User))
req.Equal(s1.String(), `{"msg":"msg1"}`)
req.Equal(s2.String(), `{"msg":"msg2"}`)
req.NoError(ws.Send("both", "msg3", s1User, s2User))
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}`)
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}`)
req.NoError(ws.Send("all", "msg4"))
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}{"all":"msg4"}`)
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}{"all":"msg4"}`)
}

340
pkg/websocket/session.go Normal file
View File

@@ -0,0 +1,340 @@
package websocket
import (
"context"
"encoding/json"
"fmt"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/errors"
"github.com/cortezaproject/corteza-server/pkg/id"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"net"
"sync"
"time"
)
// active sessions of users
var (
// wrapper around nextID that will aid service testing
nextID = func() uint64 {
return id.Next()
}
)
type (
session struct {
id uint64
once sync.Once
conn *websocket.Conn
ctx context.Context
ctxCancel context.CancelFunc
logger *zap.Logger
send chan []byte
stop chan []byte
remoteAddr string
config options.WebsocketOpt
identity auth.Identifiable
server *server
}
)
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
s := &session{
id: nextID(),
conn: conn,
config: ws.config,
send: make(chan []byte, 512),
stop: make(chan []byte, 1),
server: ws,
}
s.ctx, s.ctxCancel = context.WithCancel(ctx)
s.logger = ws.logger.
Named("session").
With(
zap.Uint64("id", s.id),
)
return s
}
func (s *session) connected() (err error) {
s.logger.Info("connected", zap.String("remoteAddr", s.conn.RemoteAddr().String()))
//// Tell everyone that user has connected
//if err = s.sendPresence("connected"); err != nil {
// return
//}
//
//
//// Create a heartbeat every minute for this user
//go func() {
// defer sentry.Recover()
//
// t := time.NewTicker(time.Second * 60)
// for {
// select {
// case <-s.ctx.Done():
// return
// case <-t.C:
// _ = s.sendPresence("active")
// }
// }
//}()
return nil
}
func (s *session) disconnected() {
// Cancel context
s.ctxCancel()
s.logger.Info("disconnected")
// Close connection
_ = s.conn.Close()
s.conn = nil
}
//func (s *session) sendPresence(_ string) error {
// return nil
//}
func (s *session) Handle() (err error) {
if err = s.connected(); err != nil {
s.Close()
return
}
go func() {
// Close unidentified connections in 5sec
<-time.NewTimer(time.Second * 5).C
if s.identity == nil {
s.Write([]byte(closingUnidentifiedConn))
s.logger.Info("closing unidentified connection")
s.Close()
}
}()
go func() {
if err = s.readLoop(); err != nil {
s.logger.Error("read failure", zap.Error(err))
}
s.Close()
}()
if err = s.writeLoop(); err != nil {
s.logger.Error("write failure", zap.Error(err))
}
return
}
func (s *session) Close() {
s.once.Do(func() {
s.disconnected()
s.server.RemoveSession(s)
})
}
func (s *session) readLoop() (err error) {
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
return
}
s.conn.SetPongHandler(func(string) error {
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
})
s.remoteAddr = s.conn.RemoteAddr().String()
var (
raw []byte
)
for {
if s.conn == nil {
return nil
}
if _, raw, err = s.conn.ReadMessage(); err != nil {
return errHandler("read failed", err)
}
if err = s.procRawMessage(raw); err != nil {
return
}
}
}
func (s *session) procRawMessage(raw []byte) (err error) {
pw := payloadWrap{}
if err = json.Unmarshal(raw, &pw); err != nil {
return fmt.Errorf("could not unmarshal session message: %w", err)
}
if pw.Type == payloadTypeCredentials {
authPayload := &payloadAuth{}
if err = pw.UnmarshalValue(authPayload); err != nil {
return fmt.Errorf("could not unmarshal session payload: %w", err)
}
if err = s.authenticate(authPayload); err != nil {
return fmt.Errorf("unauthorized: %w", err)
}
s.logger.Debug(
"authenticated",
zap.Uint64("userID", s.identity.Identity()),
zap.Uint64s("roles", s.identity.Roles()),
)
s.server.StoreSession(s)
// not expecting anything else
return
}
if s.identity == nil {
return fmt.Errorf("unauthenticated session")
}
// at the moment we do not support any other kinds of message types
return fmt.Errorf("unknown message type '%s'", pw.Type)
}
func (s *session) writeLoop() error {
ticker := time.NewTicker(s.config.PingPeriod)
defer func() {
ticker.Stop()
s.Close() // break readLoop
}()
write := func(msg []byte) (err error) {
if s.conn == nil {
// Connection closed, nowhere to write
return
}
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return fmt.Errorf("deadline error: %w", err)
}
if msg != nil && s.conn != nil {
return s.conn.WriteMessage(websocket.TextMessage, msg)
}
return
}
ping := func() (err error) {
if s.conn == nil {
// Connection closed, nothing to ping
return
}
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return
}
if s.conn != nil {
return s.conn.WriteMessage(websocket.PingMessage, nil)
}
return
}
for {
if s.conn == nil {
return nil
}
select {
case msg, ok := <-s.send:
if !ok {
// channel closed
return nil
}
if err := errHandler("send failed", write(msg)); err != nil {
return err
}
case msg := <-s.stop:
// Shutdown requested, don't care if the message is delivered
_ = write(msg)
return nil
case <-ticker.C:
if err := ping(); err != nil {
return errHandler("ping failed", err)
}
}
}
}
func (s *session) authenticate(p *payloadAuth) error {
claims, err := s.server.accessToken.Authenticate(p.AccessToken)
if err != nil {
return err
}
if !auth.CheckScope(claims["scope"], "api") {
return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)")
}
// Get identity using JWT claims
identity := auth.ClaimsToIdentity(claims)
if s.identity != nil {
if s.identity.Identity() != identity.Identity() {
return fmt.Errorf("identity does not match")
}
}
if !identity.Valid() {
return fmt.Errorf("invalid identity")
}
s.identity = identity
s.Write([]byte(ok))
return nil
}
// sendBytes sends byte to channel or timeout
func (s *session) Write(p []byte) (int, error) {
select {
case s.send <- p:
return len(p), nil
case <-time.After(2 * time.Millisecond):
return 0, fmt.Errorf("write timedout")
}
}
func errHandler(wrap string, err error) error {
if err == nil {
return nil
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
// normal closing
return nil
}
if errors.Is(err, net.ErrClosed) {
// suppress errors when reading/writing from/to a closed connection
return nil
}
return fmt.Errorf(wrap+": %w", err)
}

View File

@@ -0,0 +1,53 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestSession_procRawMessage(t *testing.T) {
var (
req = require.New(t)
s = session{server: Server(nil, options.WebsocketOpt{})}
jwtHandler, err = auth.JWT("secret", time.Minute)
userID uint64 = 123
)
req.NoError(err)
s.server.accessToken = jwtHandler
jwt := jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
req.Nil(s.identity)
req.EqualError(s.procRawMessage([]byte(`{"auth":{}}`)), "unauthorized: token contains an invalid number of segments")
req.Nil(s.identity)
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": ""}}`)), "unauthorized: token contains an invalid number of segments")
req.Nil(s.identity)
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
req.NotNil(s.identity)
req.Equal(userID, s.identity.Identity())
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
req.Equal(userID, s.identity.Identity())
// Repeat with the same user
jwt = jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
req.NotNil(s.identity)
req.Equal(userID, s.identity.Identity())
// Try to authenticate on an existing authenticated session as a different user
jwt = jwtHandler.Encode(auth.NewIdentity(userID+1, 456, 789))
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": "`+jwt+`"}}`)), "unauthorized: identity does not match")
}