310 lines
5.5 KiB
Go
310 lines
5.5 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"github.com/cortezaproject/corteza-server/pkg/id"
|
|
"github.com/cortezaproject/corteza-server/pkg/options"
|
|
"github.com/getsentry/sentry-go"
|
|
"github.com/pkg/errors"
|
|
"go.uber.org/zap/zapcore"
|
|
"sync"
|
|
"time"
|
|
|
|
gWebsocket "github.com/gorilla/websocket"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/cortezaproject/corteza-server/pkg/auth"
|
|
)
|
|
|
|
// active sessions of users
|
|
var sessions = make(map[uint64][]*session)
|
|
|
|
type (
|
|
session struct {
|
|
id uint64
|
|
once sync.Once
|
|
conn *gWebsocket.Conn
|
|
|
|
ctx context.Context
|
|
ctxCancel context.CancelFunc
|
|
|
|
logger *zap.Logger
|
|
|
|
send chan []byte
|
|
stop chan []byte
|
|
|
|
remoteAddr string
|
|
|
|
config options.WebsocketOpt
|
|
|
|
user auth.Identifiable
|
|
}
|
|
)
|
|
|
|
func Session(ctx context.Context, logger *zap.Logger, config options.WebsocketOpt, conn *gWebsocket.Conn) *session {
|
|
s := &session{
|
|
conn: conn,
|
|
config: config,
|
|
send: make(chan []byte, 512),
|
|
stop: make(chan []byte, 1),
|
|
}
|
|
|
|
s.ctx, s.ctxCancel = context.WithCancel(ctx)
|
|
|
|
s.logger = logger
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *session) log(fields ...zapcore.Field) *zap.Logger {
|
|
return s.logger.With(fields...)
|
|
}
|
|
|
|
func (s *session) Context() context.Context {
|
|
return s.ctx
|
|
}
|
|
|
|
func (s *session) User() auth.Identifiable {
|
|
return s.user
|
|
}
|
|
|
|
func (s *session) connected() (err error) {
|
|
// Tell everyone that user has connected
|
|
if err = s.sendPresence("connected"); err != nil {
|
|
return
|
|
}
|
|
|
|
// Create a heartbeat every minute for this user
|
|
go func() {
|
|
defer sentry.Recover()
|
|
|
|
t := time.NewTicker(time.Second * 60)
|
|
for {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return
|
|
case <-t.C:
|
|
_ = s.sendPresence("")
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *session) disconnected() {
|
|
// Tell everyone that user has disconnected
|
|
_ = s.sendPresence("disconnected")
|
|
|
|
// Cancel context
|
|
s.ctxCancel()
|
|
|
|
// Close connection
|
|
_ = s.conn.Close()
|
|
s.conn = nil
|
|
}
|
|
|
|
// sendPresence sends user presence: "connected", "disconnected" and "" activity kinds
|
|
func (s *session) sendPresence(kind string) error {
|
|
//connections := store.CountConnections(s.user.Identity())
|
|
//if kind == "disconnected" {
|
|
// connections--
|
|
//}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *session) Handle() (err error) {
|
|
if err = s.connected(); err != nil {
|
|
s.Close()
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
_ = s.readLoop()
|
|
}()
|
|
return s.writeLoop()
|
|
}
|
|
|
|
func (s *session) Close() {
|
|
s.once.Do(func() {
|
|
s.disconnected()
|
|
_ = s.Delete()
|
|
})
|
|
}
|
|
|
|
func (s *session) readLoop() (err error) {
|
|
defer func() {
|
|
s.Close()
|
|
}()
|
|
|
|
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
|
return
|
|
}
|
|
|
|
s.conn.SetPongHandler(func(string) error {
|
|
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
|
})
|
|
|
|
s.remoteAddr = s.conn.RemoteAddr().String()
|
|
|
|
for {
|
|
_, raw, err := s.conn.ReadMessage()
|
|
if err != nil {
|
|
return errors.Wrap(err, "s.readLoop")
|
|
}
|
|
|
|
if err = s.dispatch(raw); err != nil {
|
|
s.log(zap.Error(err)).Error("could not dispatch")
|
|
//_ = s.send(outgoing.NewError(err))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *session) writeLoop() error {
|
|
ticker := time.NewTicker(s.config.PingPeriod)
|
|
|
|
defer func() {
|
|
ticker.Stop()
|
|
s.Close() // break readLoop
|
|
}()
|
|
|
|
write := func(msg []byte) (err error) {
|
|
if s.conn == nil {
|
|
// Connection closed, nowhere to write
|
|
return
|
|
}
|
|
|
|
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
|
return
|
|
}
|
|
|
|
if msg != nil && s.conn != nil {
|
|
return s.conn.WriteMessage(gWebsocket.TextMessage, msg)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
ping := func() (err error) {
|
|
if s.conn == nil {
|
|
// Connection closed, nothing to ping
|
|
return
|
|
}
|
|
|
|
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
|
return
|
|
}
|
|
|
|
if s.conn != nil {
|
|
return s.conn.WriteMessage(gWebsocket.PingMessage, nil)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case msg, ok := <-s.send:
|
|
if !ok {
|
|
// channel closed
|
|
return nil
|
|
}
|
|
|
|
if err := write(msg); err != nil {
|
|
return errors.Wrap(err, "writeLoop send")
|
|
}
|
|
|
|
case msg := <-s.stop:
|
|
// Shutdown requested, don't care if the message is delivered
|
|
_ = write(msg)
|
|
return nil
|
|
|
|
case <-ticker.C:
|
|
if err := ping(); err != nil {
|
|
return errors.Wrap(err, "writeLoop ping")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *session) dispatch(raw []byte) error {
|
|
var p, err = Unmarshal(raw)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Session.incoming: payload malformed")
|
|
}
|
|
|
|
if p.Auth != nil {
|
|
return s.authenticate(p.Auth)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *session) authenticate(p *Auth) error {
|
|
// Get JWT claims
|
|
claims, err := p.ParseWithClaims()
|
|
if err != nil {
|
|
s.Close()
|
|
return err
|
|
}
|
|
|
|
// Get identity using JWT claims
|
|
identity := auth.ClaimsToIdentity(claims)
|
|
s.Save(identity)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *session) Save(identity *auth.Identity) *session {
|
|
if s.id == 0 {
|
|
s.id = id.Next()
|
|
}
|
|
|
|
if identity != nil {
|
|
userID := identity.Identity()
|
|
existingSessions, ok := sessions[userID]
|
|
// Add sessions for user
|
|
if !ok {
|
|
s.user = identity
|
|
sessions[userID] = append(sessions[userID], s)
|
|
}
|
|
|
|
// Update the identity in existing sessions
|
|
for _, sess := range existingSessions {
|
|
sess.user = identity
|
|
}
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *session) Get(userID uint64) []*session {
|
|
if sess, ok := sessions[userID]; ok {
|
|
return sess
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *session) Delete() error {
|
|
if s.id == 0 {
|
|
return nil
|
|
}
|
|
|
|
if s.user != nil {
|
|
delete(sessions, s.user.Identity())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// sendBytes sends byte to channel or timout
|
|
func (s *session) sendBytes(p []byte) error {
|
|
select {
|
|
case s.send <- p:
|
|
case <-time.After(2 * time.Millisecond):
|
|
s.logger.Warn("websocket.sendBytes send timeout")
|
|
}
|
|
return nil
|
|
}
|