package websocket import ( "context" "encoding/json" "fmt" "net" "sync" "time" "github.com/cortezaproject/corteza/server/pkg/auth" "github.com/cortezaproject/corteza/server/pkg/errors" "github.com/cortezaproject/corteza/server/pkg/id" "github.com/cortezaproject/corteza/server/pkg/logger" "github.com/cortezaproject/corteza/server/pkg/options" "github.com/gorilla/websocket" "go.uber.org/zap" ) // active sessions of users var ( // wrapper around nextID that will aid service testing nextID = func() uint64 { return id.Next() } ) type ( conection interface { Close() error RemoteAddr() net.Addr WriteMessage(messageType int, data []byte) error SetWriteDeadline(t time.Time) error ReadMessage() (messageType int, p []byte, err error) SetReadDeadline(t time.Time) error SetPongHandler(h func(appData string) error) } session struct { l sync.RWMutex id uint64 once sync.Once conn conection ctx context.Context ctxCancel context.CancelFunc logger *zap.Logger send chan []byte stop chan []byte remoteAddr string config options.WebsocketOpt identity auth.Identifiable server *server } ) func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session { s := &session{ id: nextID(), conn: conn, config: ws.config, send: make(chan []byte, 512), stop: make(chan []byte, 1), server: ws, } s.ctx, s.ctxCancel = context.WithCancel(ctx) s.logger = ws.logger. Named("session"). With( logger.Uint64("id", s.id), ) return s } func (s *session) Identity() auth.Identifiable { s.l.RLock() defer s.l.RUnlock() return s.identity } func (s *session) disconnect() { s.l.Lock() defer s.l.Unlock() // Cancel context s.ctxCancel() s.logger.Info("disconnected") // Close connection _ = s.conn.Close() close(s.send) close(s.stop) s.conn = nil } func (s *session) Handle() error { go func() { // Close unidentified connections in 5sec <-time.NewTimer(time.Second * 5).C if s.Identity() == nil { _, _ = s.Write(closingUnidentifiedConn) s.logger.Info("closing unidentified connection") s.Close() } }() go func() { if err := s.readLoop(); err != nil { if errors.Is(err, net.ErrClosed) { // read will return net.ErrClosed when // recovering from panic return } s.logger.Error("read failure", zap.Error(err)) } s.Close() }() if err := s.writeLoop(); err != nil { if errors.Is(err, net.ErrClosed) { // write will return net.ErrClosed when // recovering from panic return nil } s.logger.Error("write failure", zap.Error(err)) } return nil } func (s *session) Close() { s.once.Do(func() { s.disconnect() s.server.RemoveSession(s) }) } func (s *session) readLoop() (err error) { if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil { return } s.conn.SetPongHandler(func(string) error { return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)) }) s.remoteAddr = s.conn.RemoteAddr().String() var ( raw []byte ) for { if raw, err = s.read(); err != nil { return } if raw == nil { continue } if err = s.procRawMessage(raw); err != nil { return } } } func (s *session) read() (raw []byte, err error) { defer func() { if recovered := recover(); recovered != nil { s.logger.Debug("recovering from websocket read panic", zap.Any("recovered-error", recovered)) err = net.ErrClosed } }() s.l.RLock() defer s.l.RUnlock() if _, raw, err = s.conn.ReadMessage(); err != nil { return nil, errHandler("websocket read failed", err) } return } func (s *session) procRawMessage(raw []byte) (err error) { pw := payloadWrap{} if err = json.Unmarshal(raw, &pw); err != nil { return fmt.Errorf("could not unmarshal session message: %w", err) } if pw.Type == payloadTypeCredentials { authPayload := &payloadAuth{} if err = pw.UnmarshalValue(authPayload); err != nil { return fmt.Errorf("could not unmarshal session payload: %w", err) } if err = s.authenticate(authPayload); err != nil { return fmt.Errorf("unauthorized: %w", err) } i := s.Identity() s.logger.Debug( "authenticated", logger.Uint64("userID", i.Identity()), logger.Uint64s("roles", i.Roles()), ) s.server.StoreSession(s) // not expecting anything else return } if s.Identity() == nil { return fmt.Errorf("unauthenticated session") } // at the moment we do not support any other kinds of message types return fmt.Errorf("unknown message type '%s'", pw.Type) } // reads send & stop channels and sends received messages to websocket connection via write fn() func (s *session) writeLoop() error { ticker := time.NewTicker(s.config.PingPeriod) defer ticker.Stop() defer s.Close() // break readLoop for { select { case msg, ok := <-s.send: if !ok { // channel closed return nil } if err := s.write(websocket.TextMessage, msg); err != nil { return err } // continue with wait & write/ping loop case msg, ok := <-s.stop: if !ok { // channel closed return nil } // Shutdown requested, don't care if the message is delivered if err := s.write(websocket.TextMessage, msg); err != nil { return err } // stopping, break the loop. return nil case <-ticker.C: if err := s.write(websocket.PingMessage, nil); err != nil { return err } // continue with wait & write/ping loop } } } // writes messages to websocket connection func (s *session) write(t int, msg []byte) (err error) { s.l.RLock() defer s.l.RUnlock() defer func() { if recovered := recover(); recovered != nil { s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered)) err = net.ErrClosed } }() if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil { return fmt.Errorf("deadline error: %w", err) } return errHandler("websocket write failed", s.conn.WriteMessage(t, msg)) } func (s *session) authenticate(p *payloadAuth) error { identity, err := s.server.tokenValidator(s.ctx, p.AccessToken) if err != nil { return err } if i := s.Identity(); i != nil { if i.Identity() != identity.Identity() { return fmt.Errorf("identity does not match") } } if !identity.Valid() { return fmt.Errorf("invalid identity") } s.l.Lock() defer s.l.Unlock() s.identity = identity _, _ = s.Write(ok) return nil } // sendBytes sends byte to channel or timeout func (s *session) Write(p []byte) (int, error) { defer func() { if recovered := recover(); recovered != nil { s.logger.Debug("recovering from websocket write panic", zap.Any("recovered-error", recovered)) } }() select { case s.send <- p: return len(p), nil case <-time.After(2 * time.Millisecond): return 0, fmt.Errorf("write timedout") } } func errHandler(prefix string, err error) error { if err == nil { return nil } if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { // normal closing return nil } if errors.Is(err, net.ErrClosed) { // suppress errors when reading/writing from/to a closed connection return nil } return fmt.Errorf(prefix+": %w", err) }