3
0
corteza/pkg/websocket/session.go
2021-05-13 14:08:54 +02:00

341 lines
6.4 KiB
Go

package websocket
import (
"context"
"encoding/json"
"fmt"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/errors"
"github.com/cortezaproject/corteza-server/pkg/id"
"github.com/cortezaproject/corteza-server/pkg/options"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"net"
"sync"
"time"
)
// nextID produces the ID given to a newly created session.
//
// It is a variable wrapping id.Next (rather than a direct call) so that
// service tests can replace it.
var nextID = func() uint64 {
	return id.Next()
}
// session holds the state of a single websocket connection.
type session struct {
	// id is taken from nextID when the session is created.
	id uint64

	// once guards Close so that teardown runs a single time.
	once sync.Once

	// conn is the underlying websocket connection; disconnected() closes it
	// and resets the field to nil.
	conn *websocket.Conn

	// ctx is derived from the context given to Session; ctxCancel cancels it
	// and is called from disconnected().
	ctx       context.Context
	ctxCancel context.CancelFunc

	// logger is the server logger, named "session" and tagged with the id.
	logger *zap.Logger

	// send queues outgoing messages; writeLoop drains it.
	send chan []byte

	// stop requests shutdown of writeLoop: the received message is written
	// on a best-effort basis and the loop ends.
	stop chan []byte

	// remoteAddr is filled in by readLoop from the connection's remote address.
	remoteAddr string

	// config is the websocket configuration copied from the server.
	config options.WebsocketOpt

	// identity stays nil until authenticate accepts an access token.
	identity auth.Identifiable

	// server is the websocket server this session was created for.
	server *server
}
// Session builds a new session for conn on behalf of server ws.
//
// The session receives a fresh ID, a cancellable context derived from ctx,
// the server's websocket configuration, buffered send (512) and stop (1)
// channels, and a logger named "session" that carries the session ID.
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
	sessionID := nextID()
	sessionCtx, cancel := context.WithCancel(ctx)
	log := ws.logger.Named("session").With(zap.Uint64("id", sessionID))

	return &session{
		id:        sessionID,
		conn:      conn,
		ctx:       sessionCtx,
		ctxCancel: cancel,
		logger:    log,
		send:      make(chan []byte, 512),
		stop:      make(chan []byte, 1),
		config:    ws.config,
		server:    ws,
	}
}
// connected is called at the start of Handle. It currently only logs the
// remote address of the connection and always returns nil.
func (s *session) connected() (err error) {
	remote := s.conn.RemoteAddr().String()
	s.logger.Info("connected", zap.String("remoteAddr", remote))

	// Presence notifications are disabled; the previous implementation is
	// kept below for reference.
	//
	//// Tell everyone that user has connected
	//if err = s.sendPresence("connected"); err != nil {
	//	return
	//}
	//
	//// Create a heartbeat every minute for this user
	//go func() {
	//	defer sentry.Recover()
	//
	//	t := time.NewTicker(time.Second * 60)
	//	for {
	//		select {
	//		case <-s.ctx.Done():
	//			return
	//		case <-t.C:
	//			_ = s.sendPresence("active")
	//		}
	//	}
	//}()

	return nil
}
// disconnected tears the session down: it cancels the session context, logs
// the event, closes the websocket connection and clears the conn field.
//
// It is invoked through Close, which ensures it runs only once.
func (s *session) disconnected() {
	// Signal everything bound to the session context first.
	s.ctxCancel()

	s.logger.Info("disconnected")

	// The close error is deliberately ignored; the session is going away
	// regardless of the outcome.
	_ = s.conn.Close()
	s.conn = nil
}
//func (s *session) sendPresence(_ string) error {
// return nil
//}
// Handle drives the session and blocks until the write loop ends.
//
// It starts two goroutines: a watchdog that closes the session when it is
// still unauthenticated 5 seconds after start, and the read loop. The write
// loop runs on the calling goroutine.
//
// The returned error is the one produced by connected() or by the write
// loop; read loop failures are only logged.
func (s *session) Handle() (err error) {
	if err = s.connected(); err != nil {
		s.Close()
		return err
	}

	go func() {
		// Close unidentified connections in 5sec
		timer := time.NewTimer(5 * time.Second)
		defer timer.Stop()

		select {
		case <-s.ctx.Done():
			// Session ended before the deadline; nothing left to enforce
			// and no reason to keep this goroutine around.
			return
		case <-timer.C:
		}

		if s.identity == nil {
			// Best effort notification; the connection is closed right after,
			// so a failed or undelivered write is not actionable.
			_, _ = s.Write([]byte(closingUnidentifiedConn))
			s.logger.Info("closing unidentified connection")
			s.Close()
		}
	}()

	go func() {
		// Keep the read error local to this goroutine: assigning it to the
		// named result would race with the write loop assignment below.
		if readErr := s.readLoop(); readErr != nil {
			s.logger.Error("read failure", zap.Error(readErr))
		}

		s.Close()
	}()

	if err = s.writeLoop(); err != nil {
		s.logger.Error("write failure", zap.Error(err))
	}

	return err
}
// Close shuts the session down and removes it from the server.
//
// It may be called any number of times and from several goroutines; the
// teardown itself is executed only on the first call.
func (s *session) Close() {
	teardown := func() {
		s.disconnected()
		s.server.RemoveSession(s)
	}

	s.once.Do(teardown)
}
// readLoop reads messages from the connection and hands each one to
// procRawMessage until reading or processing fails or the session is closed.
//
// The read deadline is set to config.PingTimeout from now and is extended by
// the same amount every time a pong is received.
//
// It returns nil when the session was closed or the read error is one that
// errHandler suppresses; otherwise it returns the read or processing error.
func (s *session) readLoop() (err error) {
	// Work on a snapshot of the connection. disconnected() resets s.conn to
	// nil from another goroutine, so checking the field and then calling
	// through it again can dereference nil.
	conn := s.conn
	if conn == nil {
		return nil
	}

	if err = conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
		return err
	}

	conn.SetPongHandler(func(string) error {
		return conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
	})

	s.remoteAddr = conn.RemoteAddr().String()

	var raw []byte

	for {
		// The session context is cancelled by disconnected(); use it as the
		// "session closed" signal instead of inspecting s.conn.
		if s.ctx.Err() != nil {
			return nil
		}

		if _, raw, err = conn.ReadMessage(); err != nil {
			return errHandler("read failed", err)
		}

		if err = s.procRawMessage(raw); err != nil {
			return err
		}
	}
}
// procRawMessage decodes one raw message received from the client and acts
// on it.
//
// Only credential messages are supported: the payload is decoded, the
// session is authenticated and then stored on the server. Any other message
// type results in an error: "unauthenticated session" when no identity is set
// yet, "unknown message type" otherwise.
func (s *session) procRawMessage(raw []byte) (err error) {
	var wrapped payloadWrap
	if err = json.Unmarshal(raw, &wrapped); err != nil {
		return fmt.Errorf("could not unmarshal session message: %w", err)
	}

	if wrapped.Type != payloadTypeCredentials {
		if s.identity == nil {
			return fmt.Errorf("unauthenticated session")
		}

		// at the moment we do not support any other kinds of message types
		return fmt.Errorf("unknown message type '%s'", wrapped.Type)
	}

	credentials := &payloadAuth{}
	if err = wrapped.UnmarshalValue(credentials); err != nil {
		return fmt.Errorf("could not unmarshal session payload: %w", err)
	}

	if err = s.authenticate(credentials); err != nil {
		return fmt.Errorf("unauthorized: %w", err)
	}

	s.logger.Debug(
		"authenticated",
		zap.Uint64("userID", s.identity.Identity()),
		zap.Uint64s("roles", s.identity.Roles()),
	)

	s.server.StoreSession(s)

	// not expecting anything else
	return nil
}
// writeLoop delivers queued messages to the client and pings it every
// config.PingPeriod.
//
// The loop ends when:
//   - the session context is cancelled (disconnected() does that),
//   - the send channel is closed,
//   - a message arrives on the stop channel (written best-effort first),
//   - a write or ping fails with an error errHandler does not suppress.
//
// On exit the ticker is stopped and the session is closed, which in turn
// makes the read loop's pending read fail.
func (s *session) writeLoop() error {
	ticker := time.NewTicker(s.config.PingPeriod)

	defer func() {
		ticker.Stop()
		s.Close() // break readLoop
	}()

	// Work on a snapshot of the connection. disconnected() resets s.conn to
	// nil from another goroutine, so checking the field and then calling
	// through it again can dereference nil.
	conn := s.conn
	if conn == nil {
		return nil
	}

	// closed reports whether the session context has been cancelled.
	closed := func() bool {
		return s.ctx.Err() != nil
	}

	write := func(msg []byte) error {
		if closed() {
			// Connection closed, nowhere to write
			return nil
		}

		if err := conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
			return fmt.Errorf("deadline error: %w", err)
		}

		if msg == nil {
			return nil
		}

		return conn.WriteMessage(websocket.TextMessage, msg)
	}

	ping := func() error {
		if closed() {
			// Connection closed, nothing to ping
			return nil
		}

		if err := conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
			return err
		}

		return conn.WriteMessage(websocket.PingMessage, nil)
	}

	for {
		// select picks randomly among ready cases; check first so that a
		// closed session never processes further queued messages.
		if closed() {
			return nil
		}

		select {
		case <-s.ctx.Done():
			// Session closed; exit right away instead of waiting for the
			// next ping tick or queued message.
			return nil

		case msg, ok := <-s.send:
			if !ok {
				// channel closed
				return nil
			}

			if err := errHandler("send failed", write(msg)); err != nil {
				return err
			}

		case msg := <-s.stop:
			// Shutdown requested, don't care if the message is delivered
			_ = write(msg)
			return nil

		case <-ticker.C:
			if err := ping(); err != nil {
				return errHandler("ping failed", err)
			}
		}
	}
}
// authenticate validates the access token from p and binds the resulting
// identity to the session.
//
// It fails when the token is rejected, when the token's scope does not
// include "api", when the session already has an identity that differs from
// the token's, or when the identity is not valid. On success the identity is
// stored on the session and the ok message is queued for the client.
func (s *session) authenticate(p *payloadAuth) error {
	tokenClaims, err := s.server.accessToken.Authenticate(p.AccessToken)
	if err != nil {
		return err
	}

	if hasAPIScope := auth.CheckScope(tokenClaims["scope"], "api"); !hasAPIScope {
		return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)")
	}

	// Get identity using JWT claims
	candidate := auth.ClaimsToIdentity(tokenClaims)

	// A session that is already authenticated may only re-authenticate as
	// the same identity.
	if current := s.identity; current != nil && current.Identity() != candidate.Identity() {
		return fmt.Errorf("identity does not match")
	}

	if !candidate.Valid() {
		return fmt.Errorf("invalid identity")
	}

	s.identity = candidate
	s.Write([]byte(ok))

	return nil
}
// Write queues p on the session's send channel.
//
// It returns len(p) once the message is queued. When the channel does not
// accept the message within 2 milliseconds, it gives up and returns an error.
func (s *session) Write(p []byte) (int, error) {
	timeout := time.NewTimer(2 * time.Millisecond)
	defer timeout.Stop()

	select {
	case <-timeout.C:
		return 0, fmt.Errorf("write timedout")
	case s.send <- p:
		return len(p), nil
	}
}
// errHandler filters out errors that are part of a regular connection
// shutdown and wraps everything else.
//
// It returns nil when err is nil, when err is a websocket close error with
// code CloseNormalClosure or CloseGoingAway, or when err matches
// net.ErrClosed. Any other error is returned wrapped as "<wrap>: <err>".
func errHandler(wrap string, err error) error {
	if err == nil {
		return nil
	}

	if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
		// normal closing
		return nil
	}

	if errors.Is(err, net.ErrClosed) {
		// suppress errors when reading/writing from/to a closed connection
		return nil
	}

	// Pass wrap as an argument, not as part of the format string, so that a
	// '%' in it cannot be misread as a formatting verb.
	return fmt.Errorf("%s: %w", wrap, err)
}