3
0

Sends pending prompts via websocket

This commit is contained in:
Vivek Patel
2021-04-21 21:06:52 +05:30
committed by Denis Arh
parent d6b3278b6e
commit 21c9e9010e
23 changed files with 265 additions and 346 deletions

View File

@@ -4,7 +4,6 @@ import (
"context"
"github.com/cortezaproject/corteza-server/auth/settings"
"github.com/cortezaproject/corteza-server/store"
"github.com/cortezaproject/corteza-server/websocket"
"github.com/go-chi/chi"
"github.com/spf13/cobra"
"go.uber.org/zap"
@@ -22,6 +21,10 @@ type (
Serve(ctx context.Context)
}
wsServer interface {
MountRoutes(chi.Router)
}
authServicer interface {
MountHttpRoutes(string, chi.Router)
UpdateSettings(*settings.Settings)
@@ -48,7 +51,7 @@ type (
// Servers
HttpServer httpApiServer
GrpcServer grpcServer
WsServer *websocket.Websocket
WsServer wsServer
AuthService authServicer
}

View File

@@ -319,11 +319,11 @@ func (app *CortezaApp) InitServices(ctx context.Context) (err error) {
corredor.Service().SetUserFinder(sysService.DefaultUser)
corredor.Service().SetRoleFinder(sysService.DefaultRole)
app.WsServer = websocket.New(&websocket.Config{
app.WsServer = websocket.Websocket(&websocket.Config{
Timeout: app.Opt.Websocket.Timeout,
PingTimeout: app.Opt.Websocket.PingTimeout,
PingPeriod: app.Opt.Websocket.PingPeriod,
})
}, app.Log)
if app.Opt.Federation.Enabled {
// Initializes federation services

View File

@@ -11,7 +11,7 @@ import (
systemCommands "github.com/cortezaproject/corteza-server/system/commands"
)
// CLI function initializes basic Corteza subsystems
// InitCLI function initializes basic Corteza subsystems
// and sets-up the command line interface
func (app *CortezaApp) InitCLI() {
ctx := cli.Context()

View File

@@ -98,8 +98,7 @@ func (app *CortezaApp) mountHttpRoutes(r chi.Router) {
r.Route("/system", systemRest.MountRoutes)
r.Route("/automation", automationRest.MountRoutes)
r.Route("/compose", composeRest.MountRoutes)
app.WsServer.MountRoutes(r)
r.Route("/websocket", app.WsServer.MountRoutes)
if app.Opt.Federation.Enabled {
r.Route("/federation", federationRest.MountRoutes)

View File

@@ -11,6 +11,7 @@ import (
"github.com/cortezaproject/corteza-server/pkg/sentry"
"github.com/cortezaproject/corteza-server/pkg/wfexec"
"github.com/cortezaproject/corteza-server/store"
"github.com/cortezaproject/corteza-server/websocket"
"go.uber.org/zap"
"sync"
"time"
@@ -235,6 +236,7 @@ func (svc *session) spawn(g *wfexec.Graph, workflowID uint64, trace bool) (ses *
func (svc *session) Watch(ctx context.Context) {
gcTicker := time.NewTicker(time.Second)
promptTicker := time.NewTicker(time.Second)
go func() {
defer sentry.Recover()
@@ -272,6 +274,29 @@ func (svc *session) Watch(ctx context.Context) {
//svc.suspendAll(ctx)
}()
// Prompt routine
go func() {
defer sentry.Recover()
defer promptTicker.Stop()
defer svc.log.Info("stopped")
for {
select {
case <-ctx.Done():
return
case <-promptTicker.C:
activeSessions := websocket.GetActiveSessions()
for _, s := range activeSessions {
var ctxr = context.Background()
pp := svc.PendingPrompts(auth.SetIdentityToContext(ctxr, s.User()))
if len(pp) > 0 {
_ = s.Send(websocket.Message(websocket.StatusOK, websocket.WorkflowApplication, pp), s.User().Identity())
}
}
}
}
}()
svc.log.Debug("watcher initialized")
}

View File

@@ -45,7 +45,7 @@ func JWT(secret string, expiry time.Duration) (tkn *token, err error) {
return tkn, nil
}
// Verifies JWT and stores it into context
// HttpVerifier verifies JWT and stores it into context
func (t *token) HttpVerifier() func(http.Handler) http.Handler {
return jwtauth.Verifier(t.tokenAuth)
}
@@ -91,7 +91,7 @@ func (t *token) HttpAuthenticator() func(http.Handler) http.Handler {
return
}
ctx = SetIdentityToContext(ctx, claimsToIdentity(claims))
ctx = SetIdentityToContext(ctx, ClaimsToIdentity(claims))
ctx = context.WithValue(ctx, scopeCtxKey{}, claims["scope"])
r = r.WithContext(ctx)
@@ -102,8 +102,8 @@ func (t *token) HttpAuthenticator() func(http.Handler) http.Handler {
}
}
// decodes sub & roles claims into identity
func claimsToIdentity(c jwt.MapClaims) (i *Identity) {
// ClaimsToIdentity decodes sub & roles claims into identity
func ClaimsToIdentity(c jwt.MapClaims) (i *Identity) {
var (
aux interface{}
ok bool
@@ -140,7 +140,3 @@ func claimsToIdentity(c jwt.MapClaims) (i *Identity) {
return
}
func ClaimsToIdentity(c jwt.MapClaims) (i *Identity) {
return claimsToIdentity(c)
}

View File

@@ -3,6 +3,7 @@ imports:
docs:
title: Websocket server
description: A Websocket server emphasize the trigger events and actions.
props:
- name: Timeout

View File

@@ -1,12 +0,0 @@
package payload
import (
"encoding/json"
"github.com/cortezaproject/corteza-server/pkg/payload/incoming"
)
func Unmarshal(raw []byte) (*incoming.Payload, error) {
var p = &incoming.Payload{}
return p, json.Unmarshal(raw, p)
}

View File

@@ -1,25 +0,0 @@
package incoming
import (
"github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
)
type (
Token struct {
AccessToken *string `json:"access_token"`
}
)
func (t *Token) ParseWithClaims() (jwt.MapClaims, error) {
token, err := jwt.Parse(*t.AccessToken, nil)
if token == nil {
return nil, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if ok {
return claims, nil
} else {
return nil, errors.New("Invalid token")
}
}

View File

@@ -1,7 +0,0 @@
package incoming
type Payload struct {
// Token is JWT token provided by client as first message,
// and will be passed whenever it changes
*Token `json:"token"`
}

View File

@@ -1,25 +0,0 @@
package payload
import (
"github.com/cortezaproject/corteza-server/pkg/payload/outgoing"
"github.com/cortezaproject/corteza-server/pkg/wfexec"
)
func Prompt(p *wfexec.PendingPrompt) *outgoing.Prompt {
return &outgoing.Prompt{
Ref: p.Ref,
SessionID: p.SessionID,
CreatedAt: p.CreatedAt,
StateID: p.StateID,
Payload: p.Payload,
}
}
func Prompts(prompts []*wfexec.PendingPrompt) *outgoing.Prompts {
ps := make([]*outgoing.Prompt, len(prompts))
for k, p := range prompts {
ps[k] = Prompt(p)
}
retval := outgoing.Prompts(ps)
return &retval
}

View File

@@ -1,19 +0,0 @@
package outgoing
import (
"encoding/json"
)
type (
Error struct {
Message string `json:"m"`
}
)
func (p *Error) EncodeMessage() ([]byte, error) {
return json.Marshal(Payload{Error: p})
}
func NewError(err error) *Error {
return &Error{Message: err.Error()}
}

View File

@@ -1,16 +0,0 @@
package outgoing
type (
Payload struct {
*Error `json:"error,omitempty"`
*Prompt `json:"prompt,omitempty"`
*Prompts `json:"prompts,omitempty"`
}
// MessageEncoder This is same-same but different as using the json.Marshaler
// (this one does not cause json.Marshal to call itself)
MessageEncoder interface {
EncodeMessage() ([]byte, error)
}
)

View File

@@ -1,18 +0,0 @@
package outgoing
import (
"github.com/cortezaproject/corteza-server/pkg/expr"
"time"
)
type (
Prompt struct {
Ref string `json:"ref"`
SessionID uint64 `json:"sessionID,string"`
CreatedAt time.Time `json:"createdAt"`
StateID uint64 `json:"stateID,string"`
Payload *expr.Vars `json:"payload"`
}
Prompts []*Prompt
)

View File

@@ -1,7 +1,36 @@
package websocket
import (
"encoding/json"
)
const (
StatusOK = "ok"
StatusError = "error"
WorkflowApplication = "Workflow"
)
type (
message struct {
Status string `json:"status"`
Application string `json:"application"`
Data interface{} `json:"data"`
}
MessageEncoder interface {
EncodeMessage() ([]byte, error)
}
)
func Message(status, application string, data interface{}) *message {
return &message{
Status: status,
Application: application,
Data: data,
}
}
func (m message) EncodeMessage() ([]byte, error) {
return json.Marshal(m)
}

View File

@@ -2,25 +2,15 @@ package websocket
import (
"github.com/go-chi/chi"
"net/http"
)
func middlewareAllowedAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//if !service.DefaultAccessControl.CanAccess(r.Context()) {
// http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
// return
//}
next.ServeHTTP(w, r)
})
}
func (ws Websocket) MountRoutes(r chi.Router) {
// MountRoutes initialize route for websocket
// No middleware used, since anyone can open connection and
// send first message with valid JWT token,
// If it's valid then we keep the connection open or close it
func (ws *websocket) MountRoutes(r chi.Router) {
// Initialize handlers & controllers.
r.Group(func(r chi.Router) {
r.Route("/websocket", func(r chi.Router) {
//r.Use(middlewareAllowedAccess)
r.Get("/", ws.Open)
})
r.Get("/", ws.Open)
})
}

View File

@@ -2,25 +2,27 @@ package websocket
import (
"context"
"github.com/cortezaproject/corteza-server/pkg/id"
"github.com/cortezaproject/corteza-server/pkg/logger"
"github.com/cortezaproject/corteza-server/pkg/payload/outgoing"
"github.com/getsentry/sentry-go"
"github.com/pkg/errors"
"go.uber.org/zap/zapcore"
"sync"
"time"
"github.com/gorilla/websocket"
gWebsocket "github.com/gorilla/websocket"
"go.uber.org/zap"
"github.com/cortezaproject/corteza-server/pkg/auth"
)
var sessions = make(map[uint64]*session)
type (
Session struct {
session struct {
id uint64
once sync.Once
Conn *websocket.Conn
conn *gWebsocket.Conn
ctx context.Context
ctxCancel context.CancelFunc
@@ -38,9 +40,9 @@ type (
}
)
func (*Session) New(ctx context.Context, config *Config, conn *websocket.Conn) *Session {
s := &Session{
Conn: conn,
func Session(ctx context.Context, config *Config, conn *gWebsocket.Conn) *session {
s := &session{
conn: conn,
config: config,
send: make(chan []byte, 512),
stop: make(chan []byte, 1),
@@ -53,15 +55,19 @@ func (*Session) New(ctx context.Context, config *Config, conn *websocket.Conn) *
return s
}
func (s *Session) log(fields ...zapcore.Field) *zap.Logger {
func (s *session) log(fields ...zapcore.Field) *zap.Logger {
return s.logger.With(fields...)
}
func (s *Session) Context() context.Context {
func (s *session) Context() context.Context {
return s.ctx
}
func (s *Session) connected() (err error) {
func (s *session) User() auth.Identifiable {
return s.user
}
func (s *session) connected() (err error) {
// Tell everyone that user has connected
if err = s.sendPresence("connected"); err != nil {
return
@@ -77,7 +83,7 @@ func (s *Session) connected() (err error) {
case <-s.ctx.Done():
return
case <-t.C:
s.sendPresence("")
_ = s.sendPresence("")
}
}
}()
@@ -85,7 +91,7 @@ func (s *Session) connected() (err error) {
return nil
}
func (s *Session) disconnected() {
func (s *session) disconnected() {
// Tell everyone that user has disconnected
_ = s.sendPresence("disconnected")
@@ -93,68 +99,68 @@ func (s *Session) disconnected() {
s.ctxCancel()
// Close connection
s.Conn.Close()
s.Conn = nil
s.conn.Close()
s.conn = nil
}
// Sends user presence information to all subscribers
//
// It sends "connected", "disconnected" and "" activity kinds
func (s *Session) sendPresence(kind string) error {
connections := store.CountConnections(s.user.Identity())
if kind == "disconnected" {
connections--
}
// sendPresence sends user presence: "connected", "disconnected" and "" activity kinds
func (s *session) sendPresence(kind string) error {
//connections := store.CountConnections(s.user.Identity())
//if kind == "disconnected" {
// connections--
//}
return nil
}
func (s *Session) Handle() (err error) {
func (s *session) Handle() (err error) {
if err = s.connected(); err != nil {
s.Close()
return
}
go s.readLoop()
go func() {
_ = s.readLoop()
}()
return s.writeLoop()
}
func (s *Session) Close() {
func (s *session) Close() {
s.once.Do(func() {
s.disconnected()
store.Delete(s.id)
s.Delete()
})
}
func (s *Session) readLoop() (err error) {
func (s *session) readLoop() (err error) {
defer func() {
s.Close()
}()
if err = s.Conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
return
}
s.Conn.SetPongHandler(func(string) error {
return s.Conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
s.conn.SetPongHandler(func(string) error {
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
})
s.remoteAddr = s.Conn.RemoteAddr().String()
s.remoteAddr = s.conn.RemoteAddr().String()
for {
_, raw, err := s.Conn.ReadMessage()
_, raw, err := s.conn.ReadMessage()
if err != nil {
return errors.Wrap(err, "s.readLoop")
}
if err = s.dispatch(raw); err != nil {
s.log(zap.Error(err)).Error("could not dispatch")
_ = s.sendReply(outgoing.NewError(err))
//_ = s.send(outgoing.NewError(err))
}
}
}
func (s *Session) writeLoop() error {
func (s *session) writeLoop() error {
ticker := time.NewTicker(s.config.PingPeriod)
defer func() {
@@ -163,34 +169,34 @@ func (s *Session) writeLoop() error {
}()
write := func(msg []byte) (err error) {
if s.Conn == nil {
if s.conn == nil {
// Connection closed, nowhere to write
return
}
if err = s.Conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return
}
if msg != nil && s.Conn != nil {
return s.Conn.WriteMessage(websocket.TextMessage, msg)
if msg != nil && s.conn != nil {
return s.conn.WriteMessage(gWebsocket.TextMessage, msg)
}
return
}
ping := func() (err error) {
if s.Conn == nil {
if s.conn == nil {
// Connection closed, nothing to ping
return
}
if err = s.Conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
return
}
if s.Conn != nil {
return s.Conn.WriteMessage(websocket.PingMessage, nil)
if s.conn != nil {
return s.conn.WriteMessage(gWebsocket.PingMessage, nil)
}
return
@@ -220,3 +226,90 @@ func (s *Session) writeLoop() error {
}
}
}
func (s *session) dispatch(raw []byte) error {
var p, err = Unmarshal(raw)
if err != nil {
return errors.Wrap(err, "Session.incoming: payload malformed")
}
ctx := s.Context()
if p.Auth != nil {
return s.authenticate(ctx, p.Auth)
}
return nil
}
func (s *session) Save() *session {
if s.id == 0 {
s.id = id.Next()
}
if s.user != nil {
if _, ok := sessions[s.user.Identity()]; !ok {
sessions[s.user.Identity()] = s
}
}
return s
}
func (s *session) Get(userID uint64) *session {
if sess, ok := sessions[userID]; ok {
return sess
}
return nil
}
func (s *session) Delete() error {
if s.id == 0 {
return nil
}
if s.user != nil {
delete(sessions, s.user.Identity())
}
return nil
}
// Send sends message to user to ones we want to
// if len(userIDs) == 0 -- it sends to everyone
func (s *session) Send(m *message, userIDs ...uint64) error {
pb, err := m.EncodeMessage()
if err != nil {
return err
}
sendsToAll := len(userIDs) == 0
userIDMap := make(map[uint64]uint64)
for _, userID := range userIDs {
userIDMap[userID] = userID
}
for _, sess := range sessions {
_, validUser := userIDMap[sess.user.Identity()]
if sendsToAll || (!sendsToAll && validUser) {
_ = sess.sendBytes(pb)
}
}
return nil
}
// sendBytes sends byte to channel or timout
func (s *session) sendBytes(p []byte) error {
select {
case s.send <- p:
case <-time.After(2 * time.Millisecond):
s.logger.Warn("websocket.sendBytes send timeout")
}
return nil
}
func GetActiveSessions() map[uint64]*session {
return sessions
}

View File

@@ -3,10 +3,9 @@ package websocket
import (
"context"
"github.com/cortezaproject/corteza-server/pkg/auth"
"github.com/cortezaproject/corteza-server/pkg/payload/incoming"
)
func (s *Session) authenticate(ctx context.Context, p *incoming.Token) error {
func (s *session) authenticate(ctx context.Context, p *Auth) error {
// Get JWT claims
claims, err := p.ParseWithClaims()
if err != nil {
@@ -18,13 +17,11 @@ func (s *Session) authenticate(ctx context.Context, p *incoming.Token) error {
identity := auth.ClaimsToIdentity(claims)
// Update the existing ws sessions if exists or create new one
if store.CountConnections(identity.Identity()) > 0 {
store.Walk(func(session *Session) {
session.user = identity
})
} else {
if existingSession := s.Get(identity.Identity()); existingSession == nil {
s.user = identity
store.Save(s)
s.Save()
} else {
existingSession.user = identity
}
return nil

View File

@@ -1,23 +0,0 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/payload"
"github.com/pkg/errors"
)
func (s *Session) dispatch(raw []byte) error {
var p, err = payload.Unmarshal(raw)
if err != nil {
return errors.Wrap(err, "Session.incoming: payload malformed")
}
ctx := s.Context()
switch {
// Access token
case p.Token != nil:
return s.authenticate(ctx, p.Token)
}
return nil
}

View File

@@ -1,25 +0,0 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/logger"
"time"
)
// SendReply sends message only on this session, no need to enqueue item
func (s *Session) sendReply(p MessageEncoder) error {
pb, err := p.EncodeMessage()
if err != nil {
return err
}
return s.sendBytes(pb)
}
func (s *Session) sendBytes(p []byte) error {
select {
case s.send <- p:
case <-time.After(2 * time.Millisecond):
logger.Default().Warn("websocket.sendBytes send timeout")
}
return nil
}

View File

@@ -1,78 +0,0 @@
package websocket
import (
"github.com/cortezaproject/corteza-server/pkg/id"
"sync"
)
var store *Store
type (
Store struct {
sync.RWMutex
Sessions map[uint64]*Session
}
)
func init() {
store = NewStore()
}
func NewStore() *Store {
return &Store{sync.RWMutex{}, make(map[uint64]*Session)}
}
func (s *Store) Save(session *Session) *Session {
session.id = id.Next()
s.Lock()
defer s.Unlock()
s.Sessions[session.id] = session
return session
}
func (s *Store) Get(id uint64) *Session {
s.RLock()
defer s.RUnlock()
return s.Sessions[id]
}
func (s *Store) Delete(id uint64) {
s.Lock()
defer s.Unlock()
delete(s.Sessions, id)
}
func (s *Store) Walk(callback func(*Session)) {
s.RLock()
defer s.RUnlock()
for _, sess := range s.Sessions {
callback(sess)
}
}
func (s *Store) CountConnections(userID uint64) (count uint) {
s.Walk(func(session *Session) {
if session.user.Identity() == userID {
count++
}
})
return
}
// GetConnectedUsers gets all connected user to ws session
func GetConnectedUsers() []uint64 {
var chk = map[uint64]bool{}
store.Walk(func(session *Session) {
chk[session.user.Identity()] = true
})
var out = make([]uint64, 0)
for ID := range chk {
out = append(out, ID)
}
return out
}

36
websocket/types.go Normal file
View File

@@ -0,0 +1,36 @@
package websocket
import (
"encoding/json"
"github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
)
type (
Payload struct {
// Auth is JWT token provided by client as first message,
// and will be passed whenever it changes
*Auth `json:"auth"`
}
Auth struct {
AccessToken *string `json:"access_token"`
}
)
func (a *Auth) ParseWithClaims() (jwt.MapClaims, error) {
token, err := jwt.Parse(*a.AccessToken, nil)
if token == nil {
return nil, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if ok {
return claims, nil
} else {
return nil, errors.New("Invalid token")
}
}
func Unmarshal(raw []byte) (p *Payload, err error) {
return p, json.Unmarshal(raw, p)
}

View File

@@ -1,18 +1,16 @@
package websocket
import (
"fmt"
"github.com/cortezaproject/corteza-server/pkg/api"
"github.com/cortezaproject/corteza-server/pkg/logger"
"github.com/gorilla/websocket"
gWebsocket "github.com/gorilla/websocket"
"github.com/pkg/errors"
"go.uber.org/zap"
"net/http"
)
var (
// Handles websocket requests from peers
upgrader = websocket.Upgrader{
// upgrader handles websocket requests from peers
upgrader = gWebsocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
@@ -22,37 +20,37 @@ var (
)
type (
Websocket struct {
websocket struct {
config *Config
logger *zap.Logger
}
)
func New(config *Config) *Websocket {
ws := &Websocket{
func Websocket(config *Config, logger *zap.Logger) *websocket {
return &websocket{
config: config,
logger: logger,
}
return ws
}
func (ws Websocket) Open(w http.ResponseWriter, r *http.Request) {
func (ws *websocket) Open(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
conn, err := upgrader.Upgrade(w, r, nil)
if _, ok := err.(websocket.HandshakeError); ok {
fmt.Println("ws: need a websocket handshake")
if _, ok := err.(gWebsocket.HandshakeError); ok {
ws.logger.Error("ws: need a websocket handshake")
api.Send(w, r, errors.Wrap(err, "ws: need a websocket handshake"))
return
} else if err != nil {
fmt.Println("ws: failed to upgrade connection")
ws.logger.Error("ws: failed to upgrade connection")
api.Send(w, r, errors.Wrap(err, "ws: failed to upgrade connection"))
return
}
session := (&Session{}).New(ctx, ws.config, conn)
session := Session(ctx, ws.config, conn)
if err := session.Handle(); err != nil {
logger.Default().
ws.logger.
WithOptions(zap.AddStacktrace(zap.PanicLevel)).
Warn("websocket session handler error", zap.Error(err))
}