Sends pending prompts via websocket
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"github.com/cortezaproject/corteza-server/auth/settings"
|
||||
"github.com/cortezaproject/corteza-server/store"
|
||||
"github.com/cortezaproject/corteza-server/websocket"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
@@ -22,6 +21,10 @@ type (
|
||||
Serve(ctx context.Context)
|
||||
}
|
||||
|
||||
wsServer interface {
|
||||
MountRoutes(chi.Router)
|
||||
}
|
||||
|
||||
authServicer interface {
|
||||
MountHttpRoutes(string, chi.Router)
|
||||
UpdateSettings(*settings.Settings)
|
||||
@@ -48,7 +51,7 @@ type (
|
||||
// Servers
|
||||
HttpServer httpApiServer
|
||||
GrpcServer grpcServer
|
||||
WsServer *websocket.Websocket
|
||||
WsServer wsServer
|
||||
|
||||
AuthService authServicer
|
||||
}
|
||||
|
||||
@@ -319,11 +319,11 @@ func (app *CortezaApp) InitServices(ctx context.Context) (err error) {
|
||||
corredor.Service().SetUserFinder(sysService.DefaultUser)
|
||||
corredor.Service().SetRoleFinder(sysService.DefaultRole)
|
||||
|
||||
app.WsServer = websocket.New(&websocket.Config{
|
||||
app.WsServer = websocket.Websocket(&websocket.Config{
|
||||
Timeout: app.Opt.Websocket.Timeout,
|
||||
PingTimeout: app.Opt.Websocket.PingTimeout,
|
||||
PingPeriod: app.Opt.Websocket.PingPeriod,
|
||||
})
|
||||
}, app.Log)
|
||||
|
||||
if app.Opt.Federation.Enabled {
|
||||
// Initializes federation services
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
systemCommands "github.com/cortezaproject/corteza-server/system/commands"
|
||||
)
|
||||
|
||||
// CLI function initializes basic Corteza subsystems
|
||||
// InitCLI function initializes basic Corteza subsystems
|
||||
// and sets-up the command line interface
|
||||
func (app *CortezaApp) InitCLI() {
|
||||
ctx := cli.Context()
|
||||
|
||||
@@ -98,8 +98,7 @@ func (app *CortezaApp) mountHttpRoutes(r chi.Router) {
|
||||
r.Route("/system", systemRest.MountRoutes)
|
||||
r.Route("/automation", automationRest.MountRoutes)
|
||||
r.Route("/compose", composeRest.MountRoutes)
|
||||
|
||||
app.WsServer.MountRoutes(r)
|
||||
r.Route("/websocket", app.WsServer.MountRoutes)
|
||||
|
||||
if app.Opt.Federation.Enabled {
|
||||
r.Route("/federation", federationRest.MountRoutes)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/sentry"
|
||||
"github.com/cortezaproject/corteza-server/pkg/wfexec"
|
||||
"github.com/cortezaproject/corteza-server/store"
|
||||
"github.com/cortezaproject/corteza-server/websocket"
|
||||
"go.uber.org/zap"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -235,6 +236,7 @@ func (svc *session) spawn(g *wfexec.Graph, workflowID uint64, trace bool) (ses *
|
||||
|
||||
func (svc *session) Watch(ctx context.Context) {
|
||||
gcTicker := time.NewTicker(time.Second)
|
||||
promptTicker := time.NewTicker(time.Second)
|
||||
|
||||
go func() {
|
||||
defer sentry.Recover()
|
||||
@@ -272,6 +274,29 @@ func (svc *session) Watch(ctx context.Context) {
|
||||
//svc.suspendAll(ctx)
|
||||
}()
|
||||
|
||||
// Prompt routine
|
||||
go func() {
|
||||
defer sentry.Recover()
|
||||
defer promptTicker.Stop()
|
||||
defer svc.log.Info("stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-promptTicker.C:
|
||||
activeSessions := websocket.GetActiveSessions()
|
||||
for _, s := range activeSessions {
|
||||
var ctxr = context.Background()
|
||||
pp := svc.PendingPrompts(auth.SetIdentityToContext(ctxr, s.User()))
|
||||
if len(pp) > 0 {
|
||||
_ = s.Send(websocket.Message(websocket.StatusOK, websocket.WorkflowApplication, pp), s.User().Identity())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
svc.log.Debug("watcher initialized")
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ func JWT(secret string, expiry time.Duration) (tkn *token, err error) {
|
||||
return tkn, nil
|
||||
}
|
||||
|
||||
// Verifies JWT and stores it into context
|
||||
// HttpVerifier verifies JWT and stores it into context
|
||||
func (t *token) HttpVerifier() func(http.Handler) http.Handler {
|
||||
return jwtauth.Verifier(t.tokenAuth)
|
||||
}
|
||||
@@ -91,7 +91,7 @@ func (t *token) HttpAuthenticator() func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
ctx = SetIdentityToContext(ctx, claimsToIdentity(claims))
|
||||
ctx = SetIdentityToContext(ctx, ClaimsToIdentity(claims))
|
||||
ctx = context.WithValue(ctx, scopeCtxKey{}, claims["scope"])
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
@@ -102,8 +102,8 @@ func (t *token) HttpAuthenticator() func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// decodes sub & roles claims into identity
|
||||
func claimsToIdentity(c jwt.MapClaims) (i *Identity) {
|
||||
// ClaimsToIdentity decodes sub & roles claims into identity
|
||||
func ClaimsToIdentity(c jwt.MapClaims) (i *Identity) {
|
||||
var (
|
||||
aux interface{}
|
||||
ok bool
|
||||
@@ -140,7 +140,3 @@ func claimsToIdentity(c jwt.MapClaims) (i *Identity) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func ClaimsToIdentity(c jwt.MapClaims) (i *Identity) {
|
||||
return claimsToIdentity(c)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ imports:
|
||||
|
||||
docs:
|
||||
title: Websocket server
|
||||
description: A Websocket server emphasize the trigger events and actions.
|
||||
|
||||
props:
|
||||
- name: Timeout
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package payload
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/cortezaproject/corteza-server/pkg/payload/incoming"
|
||||
)
|
||||
|
||||
func Unmarshal(raw []byte) (*incoming.Payload, error) {
|
||||
var p = &incoming.Payload{}
|
||||
return p, json.Unmarshal(raw, p)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package incoming
|
||||
|
||||
import (
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type (
|
||||
Token struct {
|
||||
AccessToken *string `json:"access_token"`
|
||||
}
|
||||
)
|
||||
|
||||
func (t *Token) ParseWithClaims() (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(*t.AccessToken, nil)
|
||||
if token == nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if ok {
|
||||
return claims, nil
|
||||
} else {
|
||||
return nil, errors.New("Invalid token")
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package incoming
|
||||
|
||||
type Payload struct {
|
||||
// Token is JWT token provided by client as first message,
|
||||
// and will be passed whenever it changes
|
||||
*Token `json:"token"`
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package payload
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/payload/outgoing"
|
||||
"github.com/cortezaproject/corteza-server/pkg/wfexec"
|
||||
)
|
||||
|
||||
func Prompt(p *wfexec.PendingPrompt) *outgoing.Prompt {
|
||||
return &outgoing.Prompt{
|
||||
Ref: p.Ref,
|
||||
SessionID: p.SessionID,
|
||||
CreatedAt: p.CreatedAt,
|
||||
StateID: p.StateID,
|
||||
Payload: p.Payload,
|
||||
}
|
||||
}
|
||||
|
||||
func Prompts(prompts []*wfexec.PendingPrompt) *outgoing.Prompts {
|
||||
ps := make([]*outgoing.Prompt, len(prompts))
|
||||
for k, p := range prompts {
|
||||
ps[k] = Prompt(p)
|
||||
}
|
||||
retval := outgoing.Prompts(ps)
|
||||
return &retval
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package outgoing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type (
|
||||
Error struct {
|
||||
Message string `json:"m"`
|
||||
}
|
||||
)
|
||||
|
||||
func (p *Error) EncodeMessage() ([]byte, error) {
|
||||
return json.Marshal(Payload{Error: p})
|
||||
}
|
||||
|
||||
func NewError(err error) *Error {
|
||||
return &Error{Message: err.Error()}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package outgoing
|
||||
|
||||
type (
|
||||
Payload struct {
|
||||
*Error `json:"error,omitempty"`
|
||||
|
||||
*Prompt `json:"prompt,omitempty"`
|
||||
*Prompts `json:"prompts,omitempty"`
|
||||
}
|
||||
|
||||
// MessageEncoder This is same-same but different as using the json.Marshaler
|
||||
// (this one does not cause json.Marshal to call itself)
|
||||
MessageEncoder interface {
|
||||
EncodeMessage() ([]byte, error)
|
||||
}
|
||||
)
|
||||
@@ -1,18 +0,0 @@
|
||||
package outgoing
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/expr"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
Prompt struct {
|
||||
Ref string `json:"ref"`
|
||||
SessionID uint64 `json:"sessionID,string"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
StateID uint64 `json:"stateID,string"`
|
||||
Payload *expr.Vars `json:"payload"`
|
||||
}
|
||||
|
||||
Prompts []*Prompt
|
||||
)
|
||||
@@ -1,7 +1,36 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusOK = "ok"
|
||||
StatusError = "error"
|
||||
|
||||
WorkflowApplication = "Workflow"
|
||||
)
|
||||
|
||||
type (
|
||||
message struct {
|
||||
Status string `json:"status"`
|
||||
Application string `json:"application"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
MessageEncoder interface {
|
||||
EncodeMessage() ([]byte, error)
|
||||
}
|
||||
)
|
||||
|
||||
func Message(status, application string, data interface{}) *message {
|
||||
return &message{
|
||||
Status: status,
|
||||
Application: application,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (m message) EncodeMessage() ([]byte, error) {
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
@@ -2,25 +2,15 @@ package websocket
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func middlewareAllowedAccess(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//if !service.DefaultAccessControl.CanAccess(r.Context()) {
|
||||
// http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
// return
|
||||
//}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (ws Websocket) MountRoutes(r chi.Router) {
|
||||
// MountRoutes initialize route for websocket
|
||||
// No middleware used, since anyone can open connection and
|
||||
// send first message with valid JWT token,
|
||||
// If it's valid then we keep the connection open or close it
|
||||
func (ws *websocket) MountRoutes(r chi.Router) {
|
||||
// Initialize handlers & controllers.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Route("/websocket", func(r chi.Router) {
|
||||
//r.Use(middlewareAllowedAccess)
|
||||
r.Get("/", ws.Open)
|
||||
})
|
||||
r.Get("/", ws.Open)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,25 +2,27 @@ package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cortezaproject/corteza-server/pkg/id"
|
||||
"github.com/cortezaproject/corteza-server/pkg/logger"
|
||||
"github.com/cortezaproject/corteza-server/pkg/payload/outgoing"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
gWebsocket "github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
)
|
||||
|
||||
var sessions = make(map[uint64]*session)
|
||||
|
||||
type (
|
||||
Session struct {
|
||||
session struct {
|
||||
id uint64
|
||||
once sync.Once
|
||||
Conn *websocket.Conn
|
||||
conn *gWebsocket.Conn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
@@ -38,9 +40,9 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func (*Session) New(ctx context.Context, config *Config, conn *websocket.Conn) *Session {
|
||||
s := &Session{
|
||||
Conn: conn,
|
||||
func Session(ctx context.Context, config *Config, conn *gWebsocket.Conn) *session {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
send: make(chan []byte, 512),
|
||||
stop: make(chan []byte, 1),
|
||||
@@ -53,15 +55,19 @@ func (*Session) New(ctx context.Context, config *Config, conn *websocket.Conn) *
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Session) log(fields ...zapcore.Field) *zap.Logger {
|
||||
func (s *session) log(fields ...zapcore.Field) *zap.Logger {
|
||||
return s.logger.With(fields...)
|
||||
}
|
||||
|
||||
func (s *Session) Context() context.Context {
|
||||
func (s *session) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *Session) connected() (err error) {
|
||||
func (s *session) User() auth.Identifiable {
|
||||
return s.user
|
||||
}
|
||||
|
||||
func (s *session) connected() (err error) {
|
||||
// Tell everyone that user has connected
|
||||
if err = s.sendPresence("connected"); err != nil {
|
||||
return
|
||||
@@ -77,7 +83,7 @@ func (s *Session) connected() (err error) {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
s.sendPresence("")
|
||||
_ = s.sendPresence("")
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -85,7 +91,7 @@ func (s *Session) connected() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) disconnected() {
|
||||
func (s *session) disconnected() {
|
||||
// Tell everyone that user has disconnected
|
||||
_ = s.sendPresence("disconnected")
|
||||
|
||||
@@ -93,68 +99,68 @@ func (s *Session) disconnected() {
|
||||
s.ctxCancel()
|
||||
|
||||
// Close connection
|
||||
s.Conn.Close()
|
||||
s.Conn = nil
|
||||
s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
|
||||
// Sends user presence information to all subscribers
|
||||
//
|
||||
// It sends "connected", "disconnected" and "" activity kinds
|
||||
func (s *Session) sendPresence(kind string) error {
|
||||
connections := store.CountConnections(s.user.Identity())
|
||||
if kind == "disconnected" {
|
||||
connections--
|
||||
}
|
||||
// sendPresence sends user presence: "connected", "disconnected" and "" activity kinds
|
||||
func (s *session) sendPresence(kind string) error {
|
||||
//connections := store.CountConnections(s.user.Identity())
|
||||
//if kind == "disconnected" {
|
||||
// connections--
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) Handle() (err error) {
|
||||
func (s *session) Handle() (err error) {
|
||||
if err = s.connected(); err != nil {
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go s.readLoop()
|
||||
go func() {
|
||||
_ = s.readLoop()
|
||||
}()
|
||||
return s.writeLoop()
|
||||
}
|
||||
|
||||
func (s *Session) Close() {
|
||||
func (s *session) Close() {
|
||||
s.once.Do(func() {
|
||||
s.disconnected()
|
||||
store.Delete(s.id)
|
||||
s.Delete()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) readLoop() (err error) {
|
||||
func (s *session) readLoop() (err error) {
|
||||
defer func() {
|
||||
s.Close()
|
||||
}()
|
||||
|
||||
if err = s.Conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
||||
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.Conn.SetPongHandler(func(string) error {
|
||||
return s.Conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
||||
s.conn.SetPongHandler(func(string) error {
|
||||
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
||||
})
|
||||
|
||||
s.remoteAddr = s.Conn.RemoteAddr().String()
|
||||
s.remoteAddr = s.conn.RemoteAddr().String()
|
||||
|
||||
for {
|
||||
_, raw, err := s.Conn.ReadMessage()
|
||||
_, raw, err := s.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "s.readLoop")
|
||||
}
|
||||
|
||||
if err = s.dispatch(raw); err != nil {
|
||||
s.log(zap.Error(err)).Error("could not dispatch")
|
||||
_ = s.sendReply(outgoing.NewError(err))
|
||||
//_ = s.send(outgoing.NewError(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) writeLoop() error {
|
||||
func (s *session) writeLoop() error {
|
||||
ticker := time.NewTicker(s.config.PingPeriod)
|
||||
|
||||
defer func() {
|
||||
@@ -163,34 +169,34 @@ func (s *Session) writeLoop() error {
|
||||
}()
|
||||
|
||||
write := func(msg []byte) (err error) {
|
||||
if s.Conn == nil {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nowhere to write
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.Conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if msg != nil && s.Conn != nil {
|
||||
return s.Conn.WriteMessage(websocket.TextMessage, msg)
|
||||
if msg != nil && s.conn != nil {
|
||||
return s.conn.WriteMessage(gWebsocket.TextMessage, msg)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ping := func() (err error) {
|
||||
if s.Conn == nil {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nothing to ping
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.Conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.Conn != nil {
|
||||
return s.Conn.WriteMessage(websocket.PingMessage, nil)
|
||||
if s.conn != nil {
|
||||
return s.conn.WriteMessage(gWebsocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -220,3 +226,90 @@ func (s *Session) writeLoop() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) dispatch(raw []byte) error {
|
||||
var p, err = Unmarshal(raw)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Session.incoming: payload malformed")
|
||||
}
|
||||
|
||||
ctx := s.Context()
|
||||
|
||||
if p.Auth != nil {
|
||||
return s.authenticate(ctx, p.Auth)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Save() *session {
|
||||
if s.id == 0 {
|
||||
s.id = id.Next()
|
||||
}
|
||||
|
||||
if s.user != nil {
|
||||
if _, ok := sessions[s.user.Identity()]; !ok {
|
||||
sessions[s.user.Identity()] = s
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) Get(userID uint64) *session {
|
||||
if sess, ok := sessions[userID]; ok {
|
||||
return sess
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Delete() error {
|
||||
if s.id == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.user != nil {
|
||||
delete(sessions, s.user.Identity())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends message to user to ones we want to
|
||||
// if len(userIDs) == 0 -- it sends to everyone
|
||||
func (s *session) Send(m *message, userIDs ...uint64) error {
|
||||
pb, err := m.EncodeMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendsToAll := len(userIDs) == 0
|
||||
userIDMap := make(map[uint64]uint64)
|
||||
for _, userID := range userIDs {
|
||||
userIDMap[userID] = userID
|
||||
}
|
||||
|
||||
for _, sess := range sessions {
|
||||
_, validUser := userIDMap[sess.user.Identity()]
|
||||
if sendsToAll || (!sendsToAll && validUser) {
|
||||
_ = sess.sendBytes(pb)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendBytes sends byte to channel or timout
|
||||
func (s *session) sendBytes(p []byte) error {
|
||||
select {
|
||||
case s.send <- p:
|
||||
case <-time.After(2 * time.Millisecond):
|
||||
s.logger.Warn("websocket.sendBytes send timeout")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetActiveSessions() map[uint64]*session {
|
||||
return sessions
|
||||
}
|
||||
|
||||
@@ -3,10 +3,9 @@ package websocket
|
||||
import (
|
||||
"context"
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
"github.com/cortezaproject/corteza-server/pkg/payload/incoming"
|
||||
)
|
||||
|
||||
func (s *Session) authenticate(ctx context.Context, p *incoming.Token) error {
|
||||
func (s *session) authenticate(ctx context.Context, p *Auth) error {
|
||||
// Get JWT claims
|
||||
claims, err := p.ParseWithClaims()
|
||||
if err != nil {
|
||||
@@ -18,13 +17,11 @@ func (s *Session) authenticate(ctx context.Context, p *incoming.Token) error {
|
||||
identity := auth.ClaimsToIdentity(claims)
|
||||
|
||||
// Update the existing ws sessions if exists or create new one
|
||||
if store.CountConnections(identity.Identity()) > 0 {
|
||||
store.Walk(func(session *Session) {
|
||||
session.user = identity
|
||||
})
|
||||
} else {
|
||||
if existingSession := s.Get(identity.Identity()); existingSession == nil {
|
||||
s.user = identity
|
||||
store.Save(s)
|
||||
s.Save()
|
||||
} else {
|
||||
existingSession.user = identity
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1,23 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/payload"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *Session) dispatch(raw []byte) error {
|
||||
var p, err = payload.Unmarshal(raw)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Session.incoming: payload malformed")
|
||||
}
|
||||
|
||||
ctx := s.Context()
|
||||
|
||||
switch {
|
||||
// Access token
|
||||
case p.Token != nil:
|
||||
return s.authenticate(ctx, p.Token)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/logger"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SendReply sends message only on this session, no need to enqueue item
|
||||
func (s *Session) sendReply(p MessageEncoder) error {
|
||||
pb, err := p.EncodeMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.sendBytes(pb)
|
||||
}
|
||||
|
||||
func (s *Session) sendBytes(p []byte) error {
|
||||
select {
|
||||
case s.send <- p:
|
||||
case <-time.After(2 * time.Millisecond):
|
||||
logger.Default().Warn("websocket.sendBytes send timeout")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/id"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var store *Store
|
||||
|
||||
type (
|
||||
Store struct {
|
||||
sync.RWMutex
|
||||
|
||||
Sessions map[uint64]*Session
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
store = NewStore()
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
return &Store{sync.RWMutex{}, make(map[uint64]*Session)}
|
||||
}
|
||||
|
||||
func (s *Store) Save(session *Session) *Session {
|
||||
session.id = id.Next()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.Sessions[session.id] = session
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *Store) Get(id uint64) *Session {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return s.Sessions[id]
|
||||
}
|
||||
|
||||
func (s *Store) Delete(id uint64) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.Sessions, id)
|
||||
}
|
||||
|
||||
func (s *Store) Walk(callback func(*Session)) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
for _, sess := range s.Sessions {
|
||||
callback(sess)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) CountConnections(userID uint64) (count uint) {
|
||||
s.Walk(func(session *Session) {
|
||||
if session.user.Identity() == userID {
|
||||
count++
|
||||
}
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetConnectedUsers gets all connected user to ws session
|
||||
func GetConnectedUsers() []uint64 {
|
||||
var chk = map[uint64]bool{}
|
||||
|
||||
store.Walk(func(session *Session) {
|
||||
chk[session.user.Identity()] = true
|
||||
})
|
||||
|
||||
var out = make([]uint64, 0)
|
||||
for ID := range chk {
|
||||
out = append(out, ID)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
36
websocket/types.go
Normal file
36
websocket/types.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type (
|
||||
Payload struct {
|
||||
// Auth is JWT token provided by client as first message,
|
||||
// and will be passed whenever it changes
|
||||
*Auth `json:"auth"`
|
||||
}
|
||||
|
||||
Auth struct {
|
||||
AccessToken *string `json:"access_token"`
|
||||
}
|
||||
)
|
||||
|
||||
func (a *Auth) ParseWithClaims() (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(*a.AccessToken, nil)
|
||||
if token == nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if ok {
|
||||
return claims, nil
|
||||
} else {
|
||||
return nil, errors.New("Invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
func Unmarshal(raw []byte) (p *Payload, err error) {
|
||||
return p, json.Unmarshal(raw, p)
|
||||
}
|
||||
@@ -1,18 +1,16 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/cortezaproject/corteza-server/pkg/api"
|
||||
"github.com/cortezaproject/corteza-server/pkg/logger"
|
||||
"github.com/gorilla/websocket"
|
||||
gWebsocket "github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
// Handles websocket requests from peers
|
||||
upgrader = websocket.Upgrader{
|
||||
// upgrader handles websocket requests from peers
|
||||
upgrader = gWebsocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
|
||||
@@ -22,37 +20,37 @@ var (
|
||||
)
|
||||
|
||||
type (
|
||||
Websocket struct {
|
||||
websocket struct {
|
||||
config *Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
)
|
||||
|
||||
func New(config *Config) *Websocket {
|
||||
ws := &Websocket{
|
||||
func Websocket(config *Config, logger *zap.Logger) *websocket {
|
||||
return &websocket{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return ws
|
||||
}
|
||||
|
||||
func (ws Websocket) Open(w http.ResponseWriter, r *http.Request) {
|
||||
func (ws *websocket) Open(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if _, ok := err.(websocket.HandshakeError); ok {
|
||||
fmt.Println("ws: need a websocket handshake")
|
||||
if _, ok := err.(gWebsocket.HandshakeError); ok {
|
||||
ws.logger.Error("ws: need a websocket handshake")
|
||||
api.Send(w, r, errors.Wrap(err, "ws: need a websocket handshake"))
|
||||
return
|
||||
} else if err != nil {
|
||||
fmt.Println("ws: failed to upgrade connection")
|
||||
ws.logger.Error("ws: failed to upgrade connection")
|
||||
api.Send(w, r, errors.Wrap(err, "ws: failed to upgrade connection"))
|
||||
return
|
||||
}
|
||||
|
||||
session := (&Session{}).New(ctx, ws.config, conn)
|
||||
session := Session(ctx, ws.config, conn)
|
||||
|
||||
if err := session.Handle(); err != nil {
|
||||
logger.Default().
|
||||
ws.logger.
|
||||
WithOptions(zap.AddStacktrace(zap.PanicLevel)).
|
||||
Warn("websocket session handler error", zap.Error(err))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user