3
0
Tomaž Jerman 462619f2b9 Change logs to encode uint64 values as strings
This is due to us introducing the web console and the uints needing
to be string encoded (because of JavaScript).
2023-05-24 12:26:01 +02:00

358 lines
7.1 KiB
Go

package websocket
import (
"context"
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/cortezaproject/corteza/server/pkg/auth"
"github.com/cortezaproject/corteza/server/pkg/errors"
"github.com/cortezaproject/corteza/server/pkg/id"
"github.com/cortezaproject/corteza/server/pkg/logger"
"github.com/cortezaproject/corteza/server/pkg/options"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
var (
	// nextID generates session IDs. It wraps id.Next in a package-level
	// variable so that tests can swap in a deterministic generator.
	nextID = func() uint64 {
		return id.Next()
	}
)
type (
	// conection is the subset of *websocket.Conn methods the session uses.
	// Session stores a *websocket.Conn in it; keeping it an interface
	// presumably lets tests substitute a fake connection — confirm.
	conection interface {
		Close() error
		RemoteAddr() net.Addr
		WriteMessage(messageType int, data []byte) error
		SetWriteDeadline(t time.Time) error
		ReadMessage() (messageType int, p []byte, err error)
		SetReadDeadline(t time.Time) error
		SetPongHandler(h func(appData string) error)
	}

	// session is one client websocket connection plus the identity the
	// client authenticated as (if any).
	session struct {
		// l guards identity and conn: read() and write() hold the read lock
		// while using conn, authenticate and disconnect take the write lock.
		l sync.RWMutex

		// id is the session identifier (from nextID), attached to log entries.
		id uint64

		// once ensures Close tears the session down only once.
		once sync.Once

		// conn is the underlying connection; disconnect sets it to nil.
		conn conection

		// ctx is derived from the context given to Session; ctxCancel
		// cancels it on disconnect. ctx is passed to the token validator.
		ctx       context.Context
		ctxCancel context.CancelFunc

		logger *zap.Logger

		// send queues outgoing text messages: Write fills it, writeLoop drains it.
		send chan []byte

		// stop carries a final message; writeLoop writes it and then exits.
		// NOTE(review): nothing in this chunk sends on stop — confirm its producer.
		stop chan []byte

		// remoteAddr is the client address, recorded when readLoop starts.
		remoteAddr string

		// config supplies PingPeriod, PingTimeout and the write Timeout.
		config options.WebsocketOpt

		// identity is set by authenticate; nil until the client authenticates.
		identity auth.Identifiable

		// server is the owning server: validates tokens and stores/removes sessions.
		server *server
	}
)
// Session builds a new, not yet running, session for conn owned by ws.
//
// The session receives a fresh ID, its own cancellable context derived from
// ctx, a send buffer of 512 messages, a stop buffer of 1, and a logger named
// "session" tagged with the session ID. Call Handle to start it.
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
	sessionID := nextID()
	sessionCtx, cancel := context.WithCancel(ctx)

	return &session{
		id:        sessionID,
		conn:      conn,
		config:    ws.config,
		send:      make(chan []byte, 512),
		stop:      make(chan []byte, 1),
		server:    ws,
		ctx:       sessionCtx,
		ctxCancel: cancel,
		logger: ws.logger.
			Named("session").
			With(logger.Uint64("id", sessionID)),
	}
}
// Identity returns the identity the client authenticated as, or nil when the
// session has not been authenticated yet. Safe for concurrent use.
func (s *session) Identity() auth.Identifiable {
	s.l.RLock()
	defer s.l.RUnlock()

	current := s.identity
	return current
}
// disconnect tears the session down. It cancels the session context, closes
// the connection, closes the send and stop channels, and drops the connection
// reference. Any later read/write then panics; those panics are recovered
// and reported as net.ErrClosed.
//
// It must run at most once, because closing the channels twice would panic.
// Close guards it with sync.Once.
func (s *session) disconnect() {
	// Cancel context
	s.ctxCancel()

	// Close the connection BEFORE taking the write lock. read() holds the
	// read lock for as long as ReadMessage blocks (up to PingTimeout), so
	// locking first would stall teardown, and block write() behind the
	// pending writer, until the read deadline fires. Closing the connection
	// unblocks the pending read. gorilla/websocket permits Close to be called
	// concurrently with the other connection methods.
	_ = s.conn.Close()

	s.l.Lock()
	defer s.l.Unlock()

	s.logger.Info("disconnected")

	close(s.send)
	close(s.stop)
	s.conn = nil
}
// Handle runs the session and blocks until its write loop ends.
//
// Handle starts two goroutines:
//   - a watchdog that checks the session 5 seconds after Handle is called.
//     If it is still unidentified, the watchdog queues
//     closingUnidentifiedConn and closes it.
//   - a reader that runs readLoop and closes the session when it ends. If
//     readLoop reports net.ErrClosed, the reader skips closing.
//
// The write loop runs on the calling goroutine. Failures are logged, not
// returned; Handle always returns nil.
func (s *session) Handle() error {
	go func() {
		time.Sleep(5 * time.Second)
		if s.Identity() != nil {
			return
		}

		_, _ = s.Write(closingUnidentifiedConn)
		s.logger.Info("closing unidentified connection")
		s.Close()
	}()

	go func() {
		err := s.readLoop()
		switch {
		case err == nil:
			// clean exit, fall through to Close
		case errors.Is(err, net.ErrClosed):
			// read reports net.ErrClosed after recovering from a panic
			return
		default:
			s.logger.Error("read failure", zap.Error(err))
		}
		s.Close()
	}()

	// write reports net.ErrClosed after recovering from a panic; not worth logging
	if err := s.writeLoop(); err != nil && !errors.Is(err, net.ErrClosed) {
		s.logger.Error("write failure", zap.Error(err))
	}

	return nil
}
// Close disconnects the session and unregisters it from its server.
// Only the first call does any work; later and concurrent calls are no-ops.
func (s *session) Close() {
	teardown := func() {
		s.disconnect()
		s.server.RemoveSession(s)
	}

	s.once.Do(teardown)
}
// readLoop reads and processes client messages until the connection ends or
// a message cannot be processed.
//
// It arms the read deadline with config.PingTimeout and re-arms it on every
// pong, so a client that stops answering pings is eventually dropped.
//
// Returns nil when the connection was closed normally, the read error when
// reading fails, or the processing error when a message is rejected.
func (s *session) readLoop() (err error) {
	if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
		return err
	}

	s.conn.SetPongHandler(func(string) error {
		return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
	})

	s.remoteAddr = s.conn.RemoteAddr().String()

	var raw []byte
	for {
		if raw, err = s.read(); err != nil {
			return err
		}

		if raw == nil {
			// read() suppresses normal-closure and closed-connection errors
			// and reports them as (nil, nil), so the connection is gone.
			// Re-reading would spin on the failed connection until
			// gorilla/websocket panics on repeated reads. Stop here so the
			// caller can close the session.
			return nil
		}

		if err = s.procRawMessage(raw); err != nil {
			return err
		}
	}
}
// read blocks for the next message from the connection while holding the
// read lock.
//
// Errors pass through errHandler, so normal closures and closed-connection
// errors come back as (nil, nil). A panic, e.g. reading after disconnect set
// conn to nil, is recovered and reported as net.ErrClosed.
func (s *session) read() (raw []byte, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		s.logger.Debug("recovering from websocket read panic", zap.Any("recovered-error", recovered))
		err = net.ErrClosed
	}()

	s.l.RLock()
	defer s.l.RUnlock()

	_, msg, readErr := s.conn.ReadMessage()
	if readErr != nil {
		return nil, errHandler("websocket read failed", readErr)
	}

	return msg, nil
}
// procRawMessage decodes one client message and acts on it.
//
// Only credentials messages are supported. They authenticate the session and
// register it with the server. Any other message fails: with "unauthenticated
// session" if the client has not authenticated, otherwise with "unknown
// message type".
func (s *session) procRawMessage(raw []byte) (err error) {
	var pw payloadWrap
	if err = json.Unmarshal(raw, &pw); err != nil {
		return fmt.Errorf("could not unmarshal session message: %w", err)
	}

	switch {
	case pw.Type == payloadTypeCredentials:
		authPayload := new(payloadAuth)
		if err = pw.UnmarshalValue(authPayload); err != nil {
			return fmt.Errorf("could not unmarshal session payload: %w", err)
		}

		if err = s.authenticate(authPayload); err != nil {
			return fmt.Errorf("unauthorized: %w", err)
		}

		who := s.Identity()
		s.logger.Debug(
			"authenticated",
			logger.Uint64("userID", who.Identity()),
			logger.Uint64s("roles", who.Roles()),
		)

		s.server.StoreSession(s)
		// nothing else is carried by a credentials message
		return nil

	case s.Identity() == nil:
		return fmt.Errorf("unauthenticated session")

	default:
		// at the moment we do not support any other kinds of message types
		return fmt.Errorf("unknown message type '%s'", pw.Type)
	}
}
// writeLoop is the only writer of the connection. It forwards queued
// messages from the send channel, writes a final message from the stop
// channel and then exits, and sends a ping every config.PingPeriod.
//
// It returns nil when either channel is closed or after the stop message
// has been written, and the write error otherwise. On exit it closes the
// session, which also ends readLoop.
func (s *session) writeLoop() error {
	pinger := time.NewTicker(s.config.PingPeriod)
	defer pinger.Stop()
	defer s.Close() // break readLoop

	for {
		var (
			kind    int
			payload []byte
			last    bool
		)

		select {
		case msg, open := <-s.send:
			if !open {
				return nil
			}
			kind, payload = websocket.TextMessage, msg

		case msg, open := <-s.stop:
			if !open {
				return nil
			}
			// shutdown requested; this is the last write
			kind, payload, last = websocket.TextMessage, msg, true

		case <-pinger.C:
			kind = websocket.PingMessage
		}

		if err := s.write(kind, payload); err != nil {
			return err
		}

		if last {
			return nil
		}
	}
}
// write sends one message of type t while holding the read lock, with a
// write deadline of config.Timeout.
//
// Write errors pass through errHandler, so normal closures and
// closed-connection errors are reported as nil. A panic, e.g. writing after
// disconnect set conn to nil, is recovered and reported as net.ErrClosed.
func (s *session) write(t int, msg []byte) (err error) {
	s.l.RLock()
	defer s.l.RUnlock()

	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered))
		err = net.ErrClosed
	}()

	deadline := time.Now().Add(s.config.Timeout)
	if err = s.conn.SetWriteDeadline(deadline); err != nil {
		return fmt.Errorf("deadline error: %w", err)
	}

	err = s.conn.WriteMessage(t, msg)
	return errHandler("websocket write failed", err)
}
// authenticate validates the access token from p and stores the resulting
// identity on the session, then queues the ok acknowledgement.
//
// An already-authenticated session may only re-authenticate as the same
// user. Token validation errors are returned unchanged. An identity mismatch
// or an invalid identity is returned as an error.
func (s *session) authenticate(p *payloadAuth) error {
	identity, err := s.server.tokenValidator(s.ctx, p.AccessToken)
	if err != nil {
		return err
	}

	current := s.Identity()
	if current != nil && current.Identity() != identity.Identity() {
		return fmt.Errorf("identity does not match")
	}

	if !identity.Valid() {
		return fmt.Errorf("invalid identity")
	}

	s.l.Lock()
	defer s.l.Unlock()

	s.identity = identity
	_, _ = s.Write(ok)
	return nil
}
// Write queues p for delivery to the client by the write loop; it implements
// io.Writer.
//
// On success it returns len(p) and a nil error. If the send buffer stays full
// for 2ms, p is dropped and an error is returned. If the session has already
// been disconnected, the send channel is closed and sending on it panics. That
// panic is recovered and reported as net.ErrClosed, so callers never get a
// silent (0, nil) for a message that was not queued.
func (s *session) Write(p []byte) (n int, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered))
			n, err = 0, net.ErrClosed
		}
	}()

	select {
	case s.send <- p:
		return len(p), nil
	case <-time.After(2 * time.Millisecond):
		return 0, fmt.Errorf("write timedout")
	}
}
// errHandler normalizes errors from websocket reads and writes.
//
// It returns nil for a nil error, for normal closures (CloseNormalClosure,
// CloseGoingAway) and for operations on an already closed network
// connection, because those are expected when a session ends. Any other
// error is wrapped as "<prefix>: <err>" and stays unwrappable through %w.
func errHandler(prefix string, err error) error {
	switch {
	case err == nil:
		return nil

	case websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway):
		// normal closing
		return nil

	case errors.Is(err, net.ErrClosed):
		// suppress errors when reading/writing from/to a closed connection
		return nil

	default:
		// pass prefix as an argument rather than splicing it into the format
		// string, so a '%' in it cannot corrupt the message
		return fmt.Errorf("%s: %w", prefix, err)
	}
}