341 lines
6.4 KiB
Go
341 lines
6.4 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/cortezaproject/corteza-server/pkg/auth"
|
|
"github.com/cortezaproject/corteza-server/pkg/errors"
|
|
"github.com/cortezaproject/corteza-server/pkg/id"
|
|
"github.com/cortezaproject/corteza-server/pkg/options"
|
|
"github.com/gorilla/websocket"
|
|
"go.uber.org/zap"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// active sessions of users
|
|
var (
|
|
// wrapper around nextID that will aid service testing
|
|
nextID = func() uint64 {
|
|
return id.Next()
|
|
}
|
|
)
|
|
|
|
type (
|
|
session struct {
|
|
id uint64
|
|
once sync.Once
|
|
conn *websocket.Conn
|
|
|
|
ctx context.Context
|
|
ctxCancel context.CancelFunc
|
|
|
|
logger *zap.Logger
|
|
|
|
send chan []byte
|
|
stop chan []byte
|
|
|
|
remoteAddr string
|
|
|
|
config options.WebsocketOpt
|
|
|
|
identity auth.Identifiable
|
|
|
|
server *server
|
|
}
|
|
)
|
|
|
|
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
|
|
s := &session{
|
|
id: nextID(),
|
|
conn: conn,
|
|
config: ws.config,
|
|
send: make(chan []byte, 512),
|
|
stop: make(chan []byte, 1),
|
|
server: ws,
|
|
}
|
|
|
|
s.ctx, s.ctxCancel = context.WithCancel(ctx)
|
|
|
|
s.logger = ws.logger.
|
|
Named("session").
|
|
With(
|
|
zap.Uint64("id", s.id),
|
|
)
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *session) connected() (err error) {
|
|
s.logger.Info("connected", zap.String("remoteAddr", s.conn.RemoteAddr().String()))
|
|
|
|
//// Tell everyone that user has connected
|
|
//if err = s.sendPresence("connected"); err != nil {
|
|
// return
|
|
//}
|
|
//
|
|
//
|
|
//// Create a heartbeat every minute for this user
|
|
//go func() {
|
|
// defer sentry.Recover()
|
|
//
|
|
// t := time.NewTicker(time.Second * 60)
|
|
// for {
|
|
// select {
|
|
// case <-s.ctx.Done():
|
|
// return
|
|
// case <-t.C:
|
|
// _ = s.sendPresence("active")
|
|
// }
|
|
// }
|
|
//}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *session) disconnected() {
|
|
// Cancel context
|
|
s.ctxCancel()
|
|
|
|
s.logger.Info("disconnected")
|
|
|
|
// Close connection
|
|
_ = s.conn.Close()
|
|
s.conn = nil
|
|
}
|
|
|
|
//func (s *session) sendPresence(_ string) error {
|
|
// return nil
|
|
//}
|
|
|
|
func (s *session) Handle() (err error) {
|
|
if err = s.connected(); err != nil {
|
|
s.Close()
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
// Close unidentified connections in 5sec
|
|
<-time.NewTimer(time.Second * 5).C
|
|
if s.identity == nil {
|
|
s.Write([]byte(closingUnidentifiedConn))
|
|
s.logger.Info("closing unidentified connection")
|
|
s.Close()
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
if err = s.readLoop(); err != nil {
|
|
s.logger.Error("read failure", zap.Error(err))
|
|
}
|
|
s.Close()
|
|
}()
|
|
|
|
if err = s.writeLoop(); err != nil {
|
|
s.logger.Error("write failure", zap.Error(err))
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (s *session) Close() {
|
|
s.once.Do(func() {
|
|
s.disconnected()
|
|
s.server.RemoveSession(s)
|
|
})
|
|
}
|
|
|
|
func (s *session) readLoop() (err error) {
|
|
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
|
return
|
|
}
|
|
|
|
s.conn.SetPongHandler(func(string) error {
|
|
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
|
})
|
|
|
|
s.remoteAddr = s.conn.RemoteAddr().String()
|
|
|
|
var (
|
|
raw []byte
|
|
)
|
|
|
|
for {
|
|
if s.conn == nil {
|
|
return nil
|
|
}
|
|
|
|
if _, raw, err = s.conn.ReadMessage(); err != nil {
|
|
return errHandler("read failed", err)
|
|
}
|
|
|
|
if err = s.procRawMessage(raw); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *session) procRawMessage(raw []byte) (err error) {
|
|
pw := payloadWrap{}
|
|
if err = json.Unmarshal(raw, &pw); err != nil {
|
|
return fmt.Errorf("could not unmarshal session message: %w", err)
|
|
}
|
|
|
|
if pw.Type == payloadTypeCredentials {
|
|
authPayload := &payloadAuth{}
|
|
if err = pw.UnmarshalValue(authPayload); err != nil {
|
|
return fmt.Errorf("could not unmarshal session payload: %w", err)
|
|
}
|
|
|
|
if err = s.authenticate(authPayload); err != nil {
|
|
return fmt.Errorf("unauthorized: %w", err)
|
|
}
|
|
|
|
s.logger.Debug(
|
|
"authenticated",
|
|
zap.Uint64("userID", s.identity.Identity()),
|
|
zap.Uint64s("roles", s.identity.Roles()),
|
|
)
|
|
|
|
s.server.StoreSession(s)
|
|
|
|
// not expecting anything else
|
|
return
|
|
}
|
|
|
|
if s.identity == nil {
|
|
return fmt.Errorf("unauthenticated session")
|
|
}
|
|
|
|
// at the moment we do not support any other kinds of message types
|
|
return fmt.Errorf("unknown message type '%s'", pw.Type)
|
|
}
|
|
|
|
func (s *session) writeLoop() error {
|
|
ticker := time.NewTicker(s.config.PingPeriod)
|
|
|
|
defer func() {
|
|
ticker.Stop()
|
|
s.Close() // break readLoop
|
|
}()
|
|
|
|
write := func(msg []byte) (err error) {
|
|
if s.conn == nil {
|
|
// Connection closed, nowhere to write
|
|
return
|
|
}
|
|
|
|
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
|
return fmt.Errorf("deadline error: %w", err)
|
|
}
|
|
|
|
if msg != nil && s.conn != nil {
|
|
return s.conn.WriteMessage(websocket.TextMessage, msg)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
ping := func() (err error) {
|
|
if s.conn == nil {
|
|
// Connection closed, nothing to ping
|
|
return
|
|
}
|
|
|
|
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
|
return
|
|
}
|
|
|
|
if s.conn != nil {
|
|
return s.conn.WriteMessage(websocket.PingMessage, nil)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
for {
|
|
if s.conn == nil {
|
|
return nil
|
|
}
|
|
|
|
select {
|
|
case msg, ok := <-s.send:
|
|
if !ok {
|
|
// channel closed
|
|
return nil
|
|
}
|
|
|
|
if err := errHandler("send failed", write(msg)); err != nil {
|
|
return err
|
|
}
|
|
|
|
case msg := <-s.stop:
|
|
// Shutdown requested, don't care if the message is delivered
|
|
_ = write(msg)
|
|
return nil
|
|
|
|
case <-ticker.C:
|
|
if err := ping(); err != nil {
|
|
return errHandler("ping failed", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *session) authenticate(p *payloadAuth) error {
|
|
claims, err := s.server.accessToken.Authenticate(p.AccessToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !auth.CheckScope(claims["scope"], "api") {
|
|
return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)")
|
|
}
|
|
|
|
// Get identity using JWT claims
|
|
identity := auth.ClaimsToIdentity(claims)
|
|
|
|
if s.identity != nil {
|
|
if s.identity.Identity() != identity.Identity() {
|
|
return fmt.Errorf("identity does not match")
|
|
}
|
|
}
|
|
|
|
if !identity.Valid() {
|
|
return fmt.Errorf("invalid identity")
|
|
}
|
|
|
|
s.identity = identity
|
|
s.Write([]byte(ok))
|
|
return nil
|
|
}
|
|
|
|
// sendBytes sends byte to channel or timeout
|
|
func (s *session) Write(p []byte) (int, error) {
|
|
select {
|
|
case s.send <- p:
|
|
return len(p), nil
|
|
case <-time.After(2 * time.Millisecond):
|
|
return 0, fmt.Errorf("write timedout")
|
|
}
|
|
}
|
|
|
|
func errHandler(wrap string, err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
// normal closing
|
|
return nil
|
|
}
|
|
|
|
if errors.Is(err, net.ErrClosed) {
|
|
// suppress errors when reading/writing from/to a closed connection
|
|
return nil
|
|
}
|
|
return fmt.Errorf(wrap+": %w", err)
|
|
}
|