357 lines
7.0 KiB
Go
357 lines
7.0 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cortezaproject/corteza/server/pkg/auth"
|
|
"github.com/cortezaproject/corteza/server/pkg/errors"
|
|
"github.com/cortezaproject/corteza/server/pkg/id"
|
|
"github.com/cortezaproject/corteza/server/pkg/options"
|
|
"github.com/gorilla/websocket"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// active sessions of users
|
|
var (
|
|
// wrapper around nextID that will aid service testing
|
|
nextID = func() uint64 {
|
|
return id.Next()
|
|
}
|
|
)
|
|
|
|
type (
|
|
conection interface {
|
|
Close() error
|
|
RemoteAddr() net.Addr
|
|
WriteMessage(messageType int, data []byte) error
|
|
SetWriteDeadline(t time.Time) error
|
|
ReadMessage() (messageType int, p []byte, err error)
|
|
SetReadDeadline(t time.Time) error
|
|
SetPongHandler(h func(appData string) error)
|
|
}
|
|
|
|
session struct {
|
|
l sync.RWMutex
|
|
|
|
id uint64
|
|
once sync.Once
|
|
conn conection
|
|
|
|
ctx context.Context
|
|
ctxCancel context.CancelFunc
|
|
|
|
logger *zap.Logger
|
|
|
|
send chan []byte
|
|
stop chan []byte
|
|
|
|
remoteAddr string
|
|
|
|
config options.WebsocketOpt
|
|
|
|
identity auth.Identifiable
|
|
|
|
server *server
|
|
}
|
|
)
|
|
|
|
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
|
|
s := &session{
|
|
id: nextID(),
|
|
conn: conn,
|
|
config: ws.config,
|
|
send: make(chan []byte, 512),
|
|
stop: make(chan []byte, 1),
|
|
server: ws,
|
|
}
|
|
|
|
s.ctx, s.ctxCancel = context.WithCancel(ctx)
|
|
|
|
s.logger = ws.logger.
|
|
Named("session").
|
|
With(
|
|
zap.Uint64("id", s.id),
|
|
)
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *session) Identity() auth.Identifiable {
|
|
s.l.RLock()
|
|
defer s.l.RUnlock()
|
|
return s.identity
|
|
}
|
|
|
|
func (s *session) disconnect() {
|
|
s.l.Lock()
|
|
defer s.l.Unlock()
|
|
|
|
// Cancel context
|
|
s.ctxCancel()
|
|
|
|
s.logger.Info("disconnected")
|
|
|
|
// Close connection
|
|
_ = s.conn.Close()
|
|
|
|
close(s.send)
|
|
close(s.stop)
|
|
s.conn = nil
|
|
}
|
|
|
|
func (s *session) Handle() error {
|
|
go func() {
|
|
// Close unidentified connections in 5sec
|
|
<-time.NewTimer(time.Second * 5).C
|
|
if s.Identity() == nil {
|
|
_, _ = s.Write(closingUnidentifiedConn)
|
|
s.logger.Info("closing unidentified connection")
|
|
s.Close()
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
if err := s.readLoop(); err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
// read will return net.ErrClosed when
|
|
// recovering from panic
|
|
return
|
|
}
|
|
|
|
s.logger.Error("read failure", zap.Error(err))
|
|
}
|
|
s.Close()
|
|
}()
|
|
|
|
if err := s.writeLoop(); err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
// write will return net.ErrClosed when
|
|
// recovering from panic
|
|
return nil
|
|
}
|
|
|
|
s.logger.Error("write failure", zap.Error(err))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *session) Close() {
|
|
s.once.Do(func() {
|
|
s.disconnect()
|
|
s.server.RemoveSession(s)
|
|
})
|
|
}
|
|
|
|
func (s *session) readLoop() (err error) {
|
|
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
|
return
|
|
}
|
|
|
|
s.conn.SetPongHandler(func(string) error {
|
|
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
|
})
|
|
|
|
s.remoteAddr = s.conn.RemoteAddr().String()
|
|
|
|
var (
|
|
raw []byte
|
|
)
|
|
|
|
for {
|
|
if raw, err = s.read(); err != nil {
|
|
return
|
|
}
|
|
|
|
if raw == nil {
|
|
continue
|
|
}
|
|
|
|
if err = s.procRawMessage(raw); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *session) read() (raw []byte, err error) {
|
|
defer func() {
|
|
if recovered := recover(); recovered != nil {
|
|
s.logger.Debug("recovering from websocket read panic", zap.Any("recovered-error", recovered))
|
|
err = net.ErrClosed
|
|
}
|
|
}()
|
|
|
|
s.l.RLock()
|
|
defer s.l.RUnlock()
|
|
|
|
if _, raw, err = s.conn.ReadMessage(); err != nil {
|
|
return nil, errHandler("websocket read failed", err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (s *session) procRawMessage(raw []byte) (err error) {
|
|
pw := payloadWrap{}
|
|
if err = json.Unmarshal(raw, &pw); err != nil {
|
|
return fmt.Errorf("could not unmarshal session message: %w", err)
|
|
}
|
|
|
|
if pw.Type == payloadTypeCredentials {
|
|
authPayload := &payloadAuth{}
|
|
if err = pw.UnmarshalValue(authPayload); err != nil {
|
|
return fmt.Errorf("could not unmarshal session payload: %w", err)
|
|
}
|
|
|
|
if err = s.authenticate(authPayload); err != nil {
|
|
return fmt.Errorf("unauthorized: %w", err)
|
|
}
|
|
|
|
i := s.Identity()
|
|
s.logger.Debug(
|
|
"authenticated",
|
|
zap.Uint64("userID", i.Identity()),
|
|
zap.Uint64s("roles", i.Roles()),
|
|
)
|
|
|
|
s.server.StoreSession(s)
|
|
|
|
// not expecting anything else
|
|
return
|
|
}
|
|
|
|
if s.Identity() == nil {
|
|
return fmt.Errorf("unauthenticated session")
|
|
}
|
|
|
|
// at the moment we do not support any other kinds of message types
|
|
return fmt.Errorf("unknown message type '%s'", pw.Type)
|
|
}
|
|
|
|
// reads send & stop channels and sends received messages to websocket connection via write fn()
|
|
func (s *session) writeLoop() error {
|
|
ticker := time.NewTicker(s.config.PingPeriod)
|
|
|
|
defer ticker.Stop()
|
|
defer s.Close() // break readLoop
|
|
|
|
for {
|
|
select {
|
|
case msg, ok := <-s.send:
|
|
if !ok {
|
|
// channel closed
|
|
return nil
|
|
}
|
|
|
|
if err := s.write(websocket.TextMessage, msg); err != nil {
|
|
return err
|
|
}
|
|
|
|
// continue with wait & write/ping loop
|
|
|
|
case msg, ok := <-s.stop:
|
|
if !ok {
|
|
// channel closed
|
|
return nil
|
|
}
|
|
|
|
// Shutdown requested, don't care if the message is delivered
|
|
if err := s.write(websocket.TextMessage, msg); err != nil {
|
|
return err
|
|
}
|
|
|
|
// stopping, break the loop.
|
|
return nil
|
|
|
|
case <-ticker.C:
|
|
if err := s.write(websocket.PingMessage, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
// continue with wait & write/ping loop
|
|
}
|
|
}
|
|
}
|
|
|
|
// writes messages to websocket connection
|
|
func (s *session) write(t int, msg []byte) (err error) {
|
|
s.l.RLock()
|
|
defer s.l.RUnlock()
|
|
|
|
defer func() {
|
|
if recovered := recover(); recovered != nil {
|
|
s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered))
|
|
err = net.ErrClosed
|
|
}
|
|
}()
|
|
|
|
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
|
return fmt.Errorf("deadline error: %w", err)
|
|
}
|
|
|
|
return errHandler("websocket write failed", s.conn.WriteMessage(t, msg))
|
|
}
|
|
|
|
func (s *session) authenticate(p *payloadAuth) error {
|
|
identity, err := s.server.tokenValidator(s.ctx, p.AccessToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if i := s.Identity(); i != nil {
|
|
if i.Identity() != identity.Identity() {
|
|
return fmt.Errorf("identity does not match")
|
|
}
|
|
}
|
|
|
|
if !identity.Valid() {
|
|
return fmt.Errorf("invalid identity")
|
|
}
|
|
|
|
s.l.Lock()
|
|
defer s.l.Unlock()
|
|
|
|
s.identity = identity
|
|
_, _ = s.Write(ok)
|
|
return nil
|
|
}
|
|
|
|
// sendBytes sends byte to channel or timeout
|
|
func (s *session) Write(p []byte) (int, error) {
|
|
defer func() {
|
|
if recovered := recover(); recovered != nil {
|
|
s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered))
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case s.send <- p:
|
|
return len(p), nil
|
|
case <-time.After(2 * time.Millisecond):
|
|
return 0, fmt.Errorf("write timedout")
|
|
}
|
|
}
|
|
|
|
func errHandler(prefix string, err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
// normal closing
|
|
return nil
|
|
}
|
|
|
|
if errors.Is(err, net.ErrClosed) {
|
|
// suppress errors when reading/writing from/to a closed connection
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf(prefix+": %w", err)
|
|
}
|