138 lines
3.0 KiB
Go
138 lines
3.0 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"github.com/cortezaproject/corteza/server/pkg/auth"
|
|
"github.com/cortezaproject/corteza/server/pkg/errors"
|
|
"github.com/cortezaproject/corteza/server/pkg/options"
|
|
"github.com/cortezaproject/corteza/server/pkg/slice"
|
|
"github.com/gorilla/websocket"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type (
|
|
tokenValidator func(context.Context, string) (auth.Identifiable, error)
|
|
|
|
server struct {
|
|
config options.WebsocketOpt
|
|
logger *zap.Logger
|
|
|
|
// user id => session id => session
|
|
sessions map[uint64]map[uint64]io.Writer
|
|
|
|
// keep lock on session map changes
|
|
l sync.RWMutex
|
|
|
|
tokenValidator tokenValidator
|
|
}
|
|
)
|
|
|
|
var (
|
|
// upgrader handles websocket requests from peers
|
|
upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
|
|
// Allow connections from any Origin
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
)
|
|
|
|
func Server(logger *zap.Logger, config options.WebsocketOpt, tv tokenValidator) *server {
|
|
if !config.LogEnabled {
|
|
logger = zap.NewNop()
|
|
}
|
|
|
|
return &server{
|
|
config: config,
|
|
logger: logger.Named("websocket"),
|
|
sessions: make(map[uint64]map[uint64]io.Writer),
|
|
tokenValidator: tv,
|
|
}
|
|
}
|
|
|
|
func (ws *server) Open(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if _, ok := err.(websocket.HandshakeError); ok {
|
|
errors.ProperlyServeHTTP(w, r, errors.Internal("need a websocket handshake"), false)
|
|
return
|
|
} else if err != nil {
|
|
errors.ProperlyServeHTTP(w, r, errors.Internal("failed to upgrade connection").Wrap(err), false)
|
|
return
|
|
}
|
|
|
|
// init new session
|
|
//
|
|
// session will add itself back to server's session map when
|
|
// ready (if user authenticates itself)
|
|
ses := Session(ctx, ws, conn)
|
|
|
|
if err = ses.Handle(); err != nil {
|
|
ws.logger.Warn("websocket session handler error", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// Send delivers payload to one, more or all users
|
|
//
|
|
// Omit userIDs to deliver to ALL users
|
|
func (ws *server) Send(t string, payload interface{}, userIDs ...uint64) error {
|
|
pb, err := MarshalPayload(t, payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var (
|
|
sendToAll = len(userIDs) == 0
|
|
uMap = slice.ToUint64BoolMap(userIDs)
|
|
)
|
|
|
|
ws.l.RLock()
|
|
defer ws.l.RUnlock()
|
|
|
|
for uid := range ws.sessions {
|
|
if sendToAll || (!sendToAll && uMap[uid]) {
|
|
for _, s := range ws.sessions[uid] {
|
|
_, err = s.Write(pb)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ws *server) StoreSession(s *session) {
|
|
ws.l.Lock()
|
|
defer ws.l.Unlock()
|
|
if i := s.Identity(); i != nil {
|
|
ws.storeSession(s, i.Identity(), s.id)
|
|
}
|
|
}
|
|
|
|
func (ws *server) storeSession(w io.Writer, uid, sid uint64) {
|
|
if ws.sessions[uid] == nil {
|
|
ws.sessions[uid] = make(map[uint64]io.Writer)
|
|
|
|
}
|
|
|
|
ws.sessions[uid][sid] = w
|
|
}
|
|
|
|
func (ws *server) RemoveSession(s *session) {
|
|
ws.l.Lock()
|
|
defer ws.l.Unlock()
|
|
if i := s.Identity(); i != nil {
|
|
uid := i.Identity()
|
|
delete(ws.sessions[uid], s.id)
|
|
|
|
if len(ws.sessions[uid]) == 0 {
|
|
delete(ws.sessions, uid)
|
|
}
|
|
}
|
|
}
|