Refactor & move websocket code under pkg
This commit is contained in:
parent
bcf83b25e8
commit
70dedcaaba
@ -29,11 +29,11 @@ import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/rbac"
|
||||
"github.com/cortezaproject/corteza-server/pkg/scheduler"
|
||||
"github.com/cortezaproject/corteza-server/pkg/sentry"
|
||||
"github.com/cortezaproject/corteza-server/pkg/websocket"
|
||||
"github.com/cortezaproject/corteza-server/store"
|
||||
sysService "github.com/cortezaproject/corteza-server/system/service"
|
||||
sysEvent "github.com/cortezaproject/corteza-server/system/service/event"
|
||||
"github.com/cortezaproject/corteza-server/system/types"
|
||||
"github.com/cortezaproject/corteza-server/websocket"
|
||||
"go.uber.org/zap"
|
||||
gomail "gopkg.in/mail.v2"
|
||||
)
|
||||
@ -252,7 +252,7 @@ func (app *CortezaApp) InitServices(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
app.WsServer = websocket.Websocket(app.Log, app.Opt.Websocket)
|
||||
app.WsServer = websocket.Server(app.Log, app.Opt.Websocket)
|
||||
|
||||
ctx = actionlog.RequestOriginToContext(ctx, actionlog.RequestOrigin_APP_Init)
|
||||
defer sentry.Recover()
|
||||
|
||||
@ -212,7 +212,16 @@ func (svc *session) Resume(sessionID, stateID uint64, i auth.Identifiable, input
|
||||
return errors.NotFound("session not found")
|
||||
}
|
||||
|
||||
return ses.Resume(ctx, stateID, input)
|
||||
resPrompt, err := ses.Resume(ctx, stateID, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = svc.promptSender.Send("workflowSessionResumed", resPrompt, resPrompt.OwnerId); err != nil {
|
||||
svc.log.Error("failed to send prompt resume status to user", zap.Error(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// spawns a new session
|
||||
@ -367,7 +376,7 @@ func (svc *session) stateChangeHandler(ctx context.Context) wfexec.StateChangeHa
|
||||
// Send the pending prompts to user
|
||||
if svc.promptSender != nil {
|
||||
for _, pp := range s.AllPendingPrompts() {
|
||||
if err := svc.promptSender.Send("ok", pp, pp.OwnerId); err != nil {
|
||||
if err := svc.promptSender.Send("workflowSessionPrompt", pp, pp.OwnerId); err != nil {
|
||||
svc.log.Error("failed to send prompt to user", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,7 +103,7 @@ func (s Session) Exec(ctx context.Context, step wfexec.Step, input *expr.Vars) e
|
||||
return s.session.Exec(ctx, step, input)
|
||||
}
|
||||
|
||||
func (s Session) Resume(ctx context.Context, stateID uint64, input *expr.Vars) error {
|
||||
func (s Session) Resume(ctx context.Context, stateID uint64, input *expr.Vars) (*wfexec.ResumedPrompt, error) {
|
||||
return s.session.Resume(ctx, stateID, input)
|
||||
}
|
||||
|
||||
|
||||
44
pkg/websocket/payload.go
Normal file
44
pkg/websocket/payload.go
Normal file
@ -0,0 +1,44 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type (
|
||||
// Auth is JWT token provided by client as first message,
|
||||
// and will be passed whenever it changes
|
||||
payloadAuth struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
|
||||
payloadWrap struct {
|
||||
Type string `json:"@type"`
|
||||
Value json.RawMessage `json:"@value"`
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
payloadTypeCredentials = "credentials"
|
||||
)
|
||||
|
||||
var (
|
||||
closingUnidentifiedConn, _ = MarshalPayload("error", "closing unidentified connection")
|
||||
ok, _ = MarshalPayload("message", "authenticated")
|
||||
)
|
||||
|
||||
func (p payloadWrap) UnmarshalValue(m interface{}) error {
|
||||
return json.Unmarshal(p.Value, m)
|
||||
}
|
||||
|
||||
func MarshalPayload(t string, m interface{}) ([]byte, error) {
|
||||
var (
|
||||
err error
|
||||
w = payloadWrap{Type: t}
|
||||
)
|
||||
|
||||
if w.Value, err = json.Marshal(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(w)
|
||||
}
|
||||
@ -8,7 +8,7 @@ import (
|
||||
// No middleware used, since anyone can open connection and
|
||||
// send first message with valid JWT token,
|
||||
// If it's valid then we keep the connection open or close it
|
||||
func (ws *websocket) MountRoutes(r chi.Router) {
|
||||
func (ws *server) MountRoutes(r chi.Router) {
|
||||
// Initialize handlers & controllers.
|
||||
r.Get("/", ws.Open)
|
||||
}
|
||||
136
pkg/websocket/server.go
Normal file
136
pkg/websocket/server.go
Normal file
@ -0,0 +1,136 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
"github.com/cortezaproject/corteza-server/pkg/errors"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/cortezaproject/corteza-server/pkg/slice"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
// upgrader handles websocket requests from peers
|
||||
upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
|
||||
// Allow connections from any Origin
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
server struct {
|
||||
config options.WebsocketOpt
|
||||
logger *zap.Logger
|
||||
|
||||
// user id => session id => session
|
||||
sessions map[uint64]map[uint64]io.Writer
|
||||
|
||||
accessToken interface {
|
||||
Authenticate(string) (jwt.MapClaims, error)
|
||||
}
|
||||
|
||||
// keep lock on session map changes
|
||||
l sync.RWMutex
|
||||
}
|
||||
)
|
||||
|
||||
func Server(logger *zap.Logger, config options.WebsocketOpt) *server {
|
||||
if !config.LogEnabled {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &server{
|
||||
config: config,
|
||||
logger: logger.WithOptions(zap.AddStacktrace(zap.PanicLevel)).Named("websocket"),
|
||||
accessToken: auth.DefaultJwtHandler,
|
||||
sessions: make(map[uint64]map[uint64]io.Writer),
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *server) Open(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if _, ok := err.(websocket.HandshakeError); ok {
|
||||
errors.ProperlyServeHTTP(w, r, errors.Internal("need a websocket handshake"), false)
|
||||
return
|
||||
} else if err != nil {
|
||||
errors.ProperlyServeHTTP(w, r, errors.Internal("failed to upgrade connection").Wrap(err), false)
|
||||
return
|
||||
}
|
||||
|
||||
// init new session
|
||||
//
|
||||
// session will add itself back to server's session map when
|
||||
// ready (if user authenticates itself)
|
||||
ses := Session(ctx, ws, conn)
|
||||
|
||||
if err = ses.Handle(); err != nil {
|
||||
ws.logger.Warn("websocket session handler error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Send delivers payload to one, more or all users
|
||||
//
|
||||
// Omit userIDs to deliver to ALL users
|
||||
func (ws *server) Send(t string, payload interface{}, userIDs ...uint64) error {
|
||||
pb, err := MarshalPayload(t, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
sendToAll = len(userIDs) == 0
|
||||
uMap = slice.ToUint64BoolMap(userIDs)
|
||||
)
|
||||
|
||||
ws.l.RLock()
|
||||
defer ws.l.RUnlock()
|
||||
|
||||
for uid := range ws.sessions {
|
||||
if sendToAll || (!sendToAll && uMap[uid]) {
|
||||
for _, s := range ws.sessions[uid] {
|
||||
_, err = s.Write(pb)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *server) StoreSession(s *session) {
|
||||
ws.l.Lock()
|
||||
defer ws.l.Unlock()
|
||||
if s.identity != nil {
|
||||
ws.storeSession(s, s.identity.Identity(), s.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *server) storeSession(w io.Writer, uid, sid uint64) {
|
||||
if ws.sessions[uid] == nil {
|
||||
ws.sessions[uid] = make(map[uint64]io.Writer)
|
||||
|
||||
}
|
||||
|
||||
ws.sessions[uid][sid] = w
|
||||
}
|
||||
|
||||
func (ws *server) RemoveSession(s *session) {
|
||||
ws.l.Lock()
|
||||
defer ws.l.Unlock()
|
||||
if s.identity != nil {
|
||||
uid := s.identity.Identity()
|
||||
delete(ws.sessions[uid], s.id)
|
||||
|
||||
if len(ws.sessions[uid]) == 0 {
|
||||
delete(ws.sessions, uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
63
pkg/websocket/server_test.go
Normal file
63
pkg/websocket/server_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWebsocketSend_NoSessions(t *testing.T) {
|
||||
var (
|
||||
req = require.New(t)
|
||||
ws = Server(zap.NewNop(), options.WebsocketOpt{})
|
||||
)
|
||||
|
||||
req.NoError(ws.Send("msg", "msg"))
|
||||
req.NoError(ws.Send("msg", "msg", 1))
|
||||
req.NoError(ws.Send("msg", "msg", 1, 2))
|
||||
req.NoError(ws.Send("msg", "msg", 1, 2, 3))
|
||||
}
|
||||
|
||||
func TestWebsocketSend_ExistingSessions(t *testing.T) {
|
||||
var (
|
||||
req = require.New(t)
|
||||
ws = Server(zap.NewNop(), options.WebsocketOpt{})
|
||||
|
||||
s1User uint64 = 100
|
||||
s1ID uint64 = 101
|
||||
|
||||
s2User uint64 = 200
|
||||
s2ID uint64 = 201
|
||||
|
||||
s1 = &bytes.Buffer{}
|
||||
s2 = &bytes.Buffer{}
|
||||
)
|
||||
|
||||
ws.storeSession(s1, s1User, s1ID)
|
||||
ws.storeSession(s2, s2User, s2ID)
|
||||
|
||||
req.Empty(s1)
|
||||
req.Empty(s2)
|
||||
|
||||
req.NoError(ws.Send("msg", "msg", 0))
|
||||
req.Empty(s1)
|
||||
req.Empty(s2)
|
||||
|
||||
req.NoError(ws.Send("msg", "msg1", s1User))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}`)
|
||||
req.Equal(s2.String(), "")
|
||||
|
||||
req.NoError(ws.Send("msg", "msg2", s2User))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}`)
|
||||
req.Equal(s2.String(), `{"msg":"msg2"}`)
|
||||
|
||||
req.NoError(ws.Send("both", "msg3", s1User, s2User))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}`)
|
||||
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}`)
|
||||
|
||||
req.NoError(ws.Send("all", "msg4"))
|
||||
req.Equal(s1.String(), `{"msg":"msg1"}{"both":"msg3"}{"all":"msg4"}`)
|
||||
req.Equal(s2.String(), `{"msg":"msg2"}{"both":"msg3"}{"all":"msg4"}`)
|
||||
}
|
||||
340
pkg/websocket/session.go
Normal file
340
pkg/websocket/session.go
Normal file
@ -0,0 +1,340 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
"github.com/cortezaproject/corteza-server/pkg/errors"
|
||||
"github.com/cortezaproject/corteza-server/pkg/id"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// active sessions of users
|
||||
var (
|
||||
// wrapper around nextID that will aid service testing
|
||||
nextID = func() uint64 {
|
||||
return id.Next()
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
session struct {
|
||||
id uint64
|
||||
once sync.Once
|
||||
conn *websocket.Conn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
logger *zap.Logger
|
||||
|
||||
send chan []byte
|
||||
stop chan []byte
|
||||
|
||||
remoteAddr string
|
||||
|
||||
config options.WebsocketOpt
|
||||
|
||||
identity auth.Identifiable
|
||||
|
||||
server *server
|
||||
}
|
||||
)
|
||||
|
||||
func Session(ctx context.Context, ws *server, conn *websocket.Conn) *session {
|
||||
s := &session{
|
||||
id: nextID(),
|
||||
conn: conn,
|
||||
config: ws.config,
|
||||
send: make(chan []byte, 512),
|
||||
stop: make(chan []byte, 1),
|
||||
server: ws,
|
||||
}
|
||||
|
||||
s.ctx, s.ctxCancel = context.WithCancel(ctx)
|
||||
|
||||
s.logger = ws.logger.
|
||||
Named("session").
|
||||
With(
|
||||
zap.Uint64("id", s.id),
|
||||
)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) connected() (err error) {
|
||||
s.logger.Info("connected", zap.String("remoteAddr", s.conn.RemoteAddr().String()))
|
||||
|
||||
//// Tell everyone that user has connected
|
||||
//if err = s.sendPresence("connected"); err != nil {
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//
|
||||
//// Create a heartbeat every minute for this user
|
||||
//go func() {
|
||||
// defer sentry.Recover()
|
||||
//
|
||||
// t := time.NewTicker(time.Second * 60)
|
||||
// for {
|
||||
// select {
|
||||
// case <-s.ctx.Done():
|
||||
// return
|
||||
// case <-t.C:
|
||||
// _ = s.sendPresence("active")
|
||||
// }
|
||||
// }
|
||||
//}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) disconnected() {
|
||||
// Cancel context
|
||||
s.ctxCancel()
|
||||
|
||||
s.logger.Info("disconnected")
|
||||
|
||||
// Close connection
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
|
||||
//func (s *session) sendPresence(_ string) error {
|
||||
// return nil
|
||||
//}
|
||||
|
||||
func (s *session) Handle() (err error) {
|
||||
if err = s.connected(); err != nil {
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Close unidentified connections in 5sec
|
||||
<-time.NewTimer(time.Second * 5).C
|
||||
if s.identity == nil {
|
||||
s.Write([]byte(closingUnidentifiedConn))
|
||||
s.logger.Info("closing unidentified connection")
|
||||
s.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err = s.readLoop(); err != nil {
|
||||
s.logger.Error("read failure", zap.Error(err))
|
||||
}
|
||||
s.Close()
|
||||
}()
|
||||
|
||||
if err = s.writeLoop(); err != nil {
|
||||
s.logger.Error("write failure", zap.Error(err))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *session) Close() {
|
||||
s.once.Do(func() {
|
||||
s.disconnected()
|
||||
s.server.RemoveSession(s)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *session) readLoop() (err error) {
|
||||
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conn.SetPongHandler(func(string) error {
|
||||
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
||||
})
|
||||
|
||||
s.remoteAddr = s.conn.RemoteAddr().String()
|
||||
|
||||
var (
|
||||
raw []byte
|
||||
)
|
||||
|
||||
for {
|
||||
if s.conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, raw, err = s.conn.ReadMessage(); err != nil {
|
||||
return errHandler("read failed", err)
|
||||
}
|
||||
|
||||
if err = s.procRawMessage(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) procRawMessage(raw []byte) (err error) {
|
||||
pw := payloadWrap{}
|
||||
if err = json.Unmarshal(raw, &pw); err != nil {
|
||||
return fmt.Errorf("could not unmarshal session message: %w", err)
|
||||
}
|
||||
|
||||
if pw.Type == payloadTypeCredentials {
|
||||
authPayload := &payloadAuth{}
|
||||
if err = pw.UnmarshalValue(authPayload); err != nil {
|
||||
return fmt.Errorf("could not unmarshal session payload: %w", err)
|
||||
}
|
||||
|
||||
if err = s.authenticate(authPayload); err != nil {
|
||||
return fmt.Errorf("unauthorized: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug(
|
||||
"authenticated",
|
||||
zap.Uint64("userID", s.identity.Identity()),
|
||||
zap.Uint64s("roles", s.identity.Roles()),
|
||||
)
|
||||
|
||||
s.server.StoreSession(s)
|
||||
|
||||
// not expecting anything else
|
||||
return
|
||||
}
|
||||
|
||||
if s.identity == nil {
|
||||
return fmt.Errorf("unauthenticated session")
|
||||
}
|
||||
|
||||
// at the moment we do not support any other kinds of message types
|
||||
return fmt.Errorf("unknown message type '%s'", pw.Type)
|
||||
}
|
||||
|
||||
func (s *session) writeLoop() error {
|
||||
ticker := time.NewTicker(s.config.PingPeriod)
|
||||
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
s.Close() // break readLoop
|
||||
}()
|
||||
|
||||
write := func(msg []byte) (err error) {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nowhere to write
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return fmt.Errorf("deadline error: %w", err)
|
||||
}
|
||||
|
||||
if msg != nil && s.conn != nil {
|
||||
return s.conn.WriteMessage(websocket.TextMessage, msg)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ping := func() (err error) {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nothing to ping
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.conn != nil {
|
||||
return s.conn.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if s.conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case msg, ok := <-s.send:
|
||||
if !ok {
|
||||
// channel closed
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := errHandler("send failed", write(msg)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case msg := <-s.stop:
|
||||
// Shutdown requested, don't care if the message is delivered
|
||||
_ = write(msg)
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
if err := ping(); err != nil {
|
||||
return errHandler("ping failed", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) authenticate(p *payloadAuth) error {
|
||||
claims, err := s.server.accessToken.Authenticate(p.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !auth.CheckScope(claims["scope"], "api") {
|
||||
return fmt.Errorf("client does not allow use of websockets (missing 'api' scope)")
|
||||
}
|
||||
|
||||
// Get identity using JWT claims
|
||||
identity := auth.ClaimsToIdentity(claims)
|
||||
|
||||
if s.identity != nil {
|
||||
if s.identity.Identity() != identity.Identity() {
|
||||
return fmt.Errorf("identity does not match")
|
||||
}
|
||||
}
|
||||
|
||||
if !identity.Valid() {
|
||||
return fmt.Errorf("invalid identity")
|
||||
}
|
||||
|
||||
s.identity = identity
|
||||
s.Write([]byte(ok))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendBytes sends byte to channel or timeout
|
||||
func (s *session) Write(p []byte) (int, error) {
|
||||
select {
|
||||
case s.send <- p:
|
||||
return len(p), nil
|
||||
case <-time.After(2 * time.Millisecond):
|
||||
return 0, fmt.Errorf("write timedout")
|
||||
}
|
||||
}
|
||||
|
||||
func errHandler(wrap string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
// normal closing
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
// suppress errors when reading/writing from/to a closed connection
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(wrap+": %w", err)
|
||||
}
|
||||
53
pkg/websocket/session_test.go
Normal file
53
pkg/websocket/session_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSession_procRawMessage(t *testing.T) {
|
||||
var (
|
||||
req = require.New(t)
|
||||
s = session{server: Server(nil, options.WebsocketOpt{})}
|
||||
jwtHandler, err = auth.JWT("secret", time.Minute)
|
||||
|
||||
userID uint64 = 123
|
||||
)
|
||||
|
||||
req.NoError(err)
|
||||
s.server.accessToken = jwtHandler
|
||||
|
||||
jwt := jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
|
||||
req.Nil(s.identity)
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte(`{"auth":{}}`)), "unauthorized: token contains an invalid number of segments")
|
||||
req.Nil(s.identity)
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": ""}}`)), "unauthorized: token contains an invalid number of segments")
|
||||
req.Nil(s.identity)
|
||||
|
||||
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
|
||||
req.NotNil(s.identity)
|
||||
req.Equal(userID, s.identity.Identity())
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte("{}")), "empty payload")
|
||||
req.Equal(userID, s.identity.Identity())
|
||||
|
||||
// Repeat with the same user
|
||||
jwt = jwtHandler.Encode(auth.NewIdentity(userID, 456, 789))
|
||||
|
||||
req.NoError(s.procRawMessage([]byte(`{"auth":{"access_token": "` + jwt + `"}}`)))
|
||||
req.NotNil(s.identity)
|
||||
req.Equal(userID, s.identity.Identity())
|
||||
|
||||
// Try to authenticate on an existing authenticated session as a different user
|
||||
jwt = jwtHandler.Encode(auth.NewIdentity(userID+1, 456, 789))
|
||||
|
||||
req.EqualError(s.procRawMessage([]byte(`{"auth":{"access_token": "`+jwt+`"}}`)), "unauthorized: identity does not match")
|
||||
|
||||
}
|
||||
@ -29,6 +29,11 @@ type (
|
||||
Payload *expr.Vars `json:"payload"`
|
||||
OwnerId uint64 `json:"-"`
|
||||
}
|
||||
|
||||
ResumedPrompt struct {
|
||||
StateID uint64 `json:"stateID,string"`
|
||||
OwnerId uint64 `json:"-"`
|
||||
}
|
||||
)
|
||||
|
||||
func Prompt(ownerId uint64, ref string, payload *expr.Vars) *prompted {
|
||||
@ -44,3 +49,10 @@ func (p *prompted) toPending() *PendingPrompt {
|
||||
OwnerId: p.ownerId,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prompted) toResumed() *ResumedPrompt {
|
||||
return &ResumedPrompt{
|
||||
StateID: p.state.stateId,
|
||||
OwnerId: p.ownerId,
|
||||
}
|
||||
}
|
||||
|
||||
@ -288,7 +288,7 @@ func (s *Session) AllPendingPrompts() (out []*PendingPrompt) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) error {
|
||||
func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) (*ResumedPrompt, error) {
|
||||
defer s.mux.Unlock()
|
||||
s.mux.Lock()
|
||||
|
||||
@ -297,11 +297,11 @@ func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars)
|
||||
p, has = s.prompted[stateId]
|
||||
)
|
||||
if !has {
|
||||
return fmt.Errorf("unexisting state")
|
||||
return nil, fmt.Errorf("unexisting state")
|
||||
}
|
||||
|
||||
if i == nil || p.ownerId != i.Identity() {
|
||||
return fmt.Errorf("state access denied")
|
||||
return nil, fmt.Errorf("state access denied")
|
||||
}
|
||||
|
||||
delete(s.prompted, stateId)
|
||||
@ -309,7 +309,11 @@ func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars)
|
||||
// setting received input to state
|
||||
p.state.input = input
|
||||
|
||||
return s.enqueue(ctx, p.state)
|
||||
if err := s.enqueue(ctx, p.state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.toResumed(), nil
|
||||
}
|
||||
|
||||
func (s *Session) enqueue(ctx context.Context, st *State) error {
|
||||
|
||||
@ -278,7 +278,10 @@ func (svc reminder) Delete(ctx context.Context, ID uint64) (err error) {
|
||||
|
||||
func (svc reminder) Watch(ctx context.Context) {
|
||||
if svc.reminderSender != nil {
|
||||
rTicker := time.NewTicker(time.Second)
|
||||
var (
|
||||
interval = time.Second
|
||||
rTicker = time.NewTicker(interval)
|
||||
)
|
||||
|
||||
go func() {
|
||||
defer sentry.Recover()
|
||||
@ -289,26 +292,21 @@ func (svc reminder) Watch(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case t := <-rTicker.C:
|
||||
case <-rTicker.C:
|
||||
// Get scheduled reminders of users
|
||||
rr, _, err := svc.Find(ctx, types.ReminderFilter{
|
||||
ExcludeDismissed: true,
|
||||
ScheduledOnly: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
svc.log.Error("failed to get reminders of users", zap.Error(err))
|
||||
}
|
||||
|
||||
// sendReminderNow checks time is equal to current time or not
|
||||
sendReminderNow := func(tt time.Time) bool {
|
||||
timeLayout := time.RFC3339
|
||||
return tt.Format(timeLayout) == t.Format(timeLayout)
|
||||
}
|
||||
|
||||
// Send scheduled reminders to users
|
||||
_ = rr.Walk(func(r *types.Reminder) error {
|
||||
if r.RemindAt != nil && sendReminderNow(*r.RemindAt) {
|
||||
if err := svc.reminderSender.Send("ok", r, r.AssignedTo); err != nil {
|
||||
if r.RemindAt != nil && now().Round(interval) == r.RemindAt.Round(interval) {
|
||||
if err := svc.reminderSender.Send("reminder", r, r.AssignedTo); err != nil {
|
||||
svc.log.Error("failed to send reminder to user", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,309 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cortezaproject/corteza-server/pkg/id"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gWebsocket "github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/cortezaproject/corteza-server/pkg/auth"
|
||||
)
|
||||
|
||||
// active sessions of users
|
||||
var sessions = make(map[uint64][]*session)
|
||||
|
||||
type (
|
||||
session struct {
|
||||
id uint64
|
||||
once sync.Once
|
||||
conn *gWebsocket.Conn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
logger *zap.Logger
|
||||
|
||||
send chan []byte
|
||||
stop chan []byte
|
||||
|
||||
remoteAddr string
|
||||
|
||||
config options.WebsocketOpt
|
||||
|
||||
user auth.Identifiable
|
||||
}
|
||||
)
|
||||
|
||||
func Session(ctx context.Context, logger *zap.Logger, config options.WebsocketOpt, conn *gWebsocket.Conn) *session {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
send: make(chan []byte, 512),
|
||||
stop: make(chan []byte, 1),
|
||||
}
|
||||
|
||||
s.ctx, s.ctxCancel = context.WithCancel(ctx)
|
||||
|
||||
s.logger = logger
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) log(fields ...zapcore.Field) *zap.Logger {
|
||||
return s.logger.With(fields...)
|
||||
}
|
||||
|
||||
func (s *session) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *session) User() auth.Identifiable {
|
||||
return s.user
|
||||
}
|
||||
|
||||
func (s *session) connected() (err error) {
|
||||
// Tell everyone that user has connected
|
||||
if err = s.sendPresence("connected"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Create a heartbeat every minute for this user
|
||||
go func() {
|
||||
defer sentry.Recover()
|
||||
|
||||
t := time.NewTicker(time.Second * 60)
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
_ = s.sendPresence("")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) disconnected() {
|
||||
// Tell everyone that user has disconnected
|
||||
_ = s.sendPresence("disconnected")
|
||||
|
||||
// Cancel context
|
||||
s.ctxCancel()
|
||||
|
||||
// Close connection
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
|
||||
// sendPresence sends user presence: "connected", "disconnected" and "" activity kinds
|
||||
func (s *session) sendPresence(kind string) error {
|
||||
//connections := store.CountConnections(s.user.Identity())
|
||||
//if kind == "disconnected" {
|
||||
// connections--
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Handle() (err error) {
|
||||
if err = s.connected(); err != nil {
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = s.readLoop()
|
||||
}()
|
||||
return s.writeLoop()
|
||||
}
|
||||
|
||||
func (s *session) Close() {
|
||||
s.once.Do(func() {
|
||||
s.disconnected()
|
||||
_ = s.Delete()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *session) readLoop() (err error) {
|
||||
defer func() {
|
||||
s.Close()
|
||||
}()
|
||||
|
||||
if err = s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conn.SetPongHandler(func(string) error {
|
||||
return s.conn.SetReadDeadline(time.Now().Add(s.config.PingTimeout))
|
||||
})
|
||||
|
||||
s.remoteAddr = s.conn.RemoteAddr().String()
|
||||
|
||||
for {
|
||||
_, raw, err := s.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "s.readLoop")
|
||||
}
|
||||
|
||||
if err = s.dispatch(raw); err != nil {
|
||||
s.log(zap.Error(err)).Error("could not dispatch")
|
||||
//_ = s.send(outgoing.NewError(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) writeLoop() error {
|
||||
ticker := time.NewTicker(s.config.PingPeriod)
|
||||
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
s.Close() // break readLoop
|
||||
}()
|
||||
|
||||
write := func(msg []byte) (err error) {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nowhere to write
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if msg != nil && s.conn != nil {
|
||||
return s.conn.WriteMessage(gWebsocket.TextMessage, msg)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ping := func() (err error) {
|
||||
if s.conn == nil {
|
||||
// Connection closed, nothing to ping
|
||||
return
|
||||
}
|
||||
|
||||
if err = s.conn.SetWriteDeadline(time.Now().Add(s.config.Timeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.conn != nil {
|
||||
return s.conn.WriteMessage(gWebsocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-s.send:
|
||||
if !ok {
|
||||
// channel closed
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := write(msg); err != nil {
|
||||
return errors.Wrap(err, "writeLoop send")
|
||||
}
|
||||
|
||||
case msg := <-s.stop:
|
||||
// Shutdown requested, don't care if the message is delivered
|
||||
_ = write(msg)
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
if err := ping(); err != nil {
|
||||
return errors.Wrap(err, "writeLoop ping")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) dispatch(raw []byte) error {
|
||||
var p, err = Unmarshal(raw)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Session.incoming: payload malformed")
|
||||
}
|
||||
|
||||
if p.Auth != nil {
|
||||
return s.authenticate(p.Auth)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) authenticate(p *Auth) error {
|
||||
// Get JWT claims
|
||||
claims, err := p.ParseWithClaims()
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Get identity using JWT claims
|
||||
identity := auth.ClaimsToIdentity(claims)
|
||||
s.Save(identity)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Save(identity *auth.Identity) *session {
|
||||
if s.id == 0 {
|
||||
s.id = id.Next()
|
||||
}
|
||||
|
||||
if identity != nil {
|
||||
userID := identity.Identity()
|
||||
existingSessions, ok := sessions[userID]
|
||||
// Add sessions for user
|
||||
if !ok {
|
||||
s.user = identity
|
||||
sessions[userID] = append(sessions[userID], s)
|
||||
}
|
||||
|
||||
// Update the identity in existing sessions
|
||||
for _, sess := range existingSessions {
|
||||
sess.user = identity
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) Get(userID uint64) []*session {
|
||||
if sess, ok := sessions[userID]; ok {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Delete() error {
|
||||
if s.id == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.user != nil {
|
||||
delete(sessions, s.user.Identity())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendBytes sends byte to channel or timout
|
||||
func (s *session) sendBytes(p []byte) error {
|
||||
select {
|
||||
case s.send <- p:
|
||||
case <-time.After(2 * time.Millisecond):
|
||||
s.logger.Warn("websocket.sendBytes send timeout")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1,55 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type (
|
||||
// Auth is JWT token provided by client as first message,
|
||||
// and will be passed whenever it changes
|
||||
Auth struct {
|
||||
AccessToken *string `json:"access_token"`
|
||||
}
|
||||
|
||||
// payload for incoming messages from user
|
||||
payload struct {
|
||||
*Auth `json:"auth"`
|
||||
}
|
||||
|
||||
// response for sending messages to user
|
||||
response struct {
|
||||
Status string `json:"status"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
)
|
||||
|
||||
func (a *Auth) ParseWithClaims() (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(*a.AccessToken, nil)
|
||||
if token == nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if ok {
|
||||
return claims, nil
|
||||
} else {
|
||||
return nil, errors.New("Invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
func Unmarshal(raw []byte) (*payload, error) {
|
||||
var p payload
|
||||
return &p, json.Unmarshal(raw, &p)
|
||||
}
|
||||
|
||||
func Response(status string, data interface{}) *response {
|
||||
return &response{
|
||||
Status: status,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (m response) Marshal() ([]byte, error) {
|
||||
return json.Marshal(m)
|
||||
}
|
||||
@ -1,87 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/cortezaproject/corteza-server/pkg/api"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
gWebsocket "github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
// upgrader handles websocket requests from peers
|
||||
upgrader = gWebsocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
|
||||
// Allow connections from any Origin
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
websocket struct {
|
||||
config options.WebsocketOpt
|
||||
logger *zap.Logger
|
||||
}
|
||||
)
|
||||
|
||||
func Websocket(logger *zap.Logger, config options.WebsocketOpt) *websocket {
|
||||
if !config.LogEnabled {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &websocket{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *websocket) Open(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if _, ok := err.(gWebsocket.HandshakeError); ok {
|
||||
ws.logger.Error("ws: need a websocket handshake")
|
||||
api.Send(w, r, errors.Wrap(err, "ws: need a websocket handshake"))
|
||||
return
|
||||
} else if err != nil {
|
||||
ws.logger.Error("ws: failed to upgrade connection")
|
||||
api.Send(w, r, errors.Wrap(err, "ws: failed to upgrade connection"))
|
||||
return
|
||||
}
|
||||
|
||||
session := Session(ctx, ws.logger, ws.config, conn)
|
||||
|
||||
if err := session.Handle(); err != nil {
|
||||
ws.logger.
|
||||
WithOptions(zap.AddStacktrace(zap.PanicLevel)).
|
||||
Warn("websocket session handler error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Send delivers message to user to ones we want to
|
||||
// if len(userIDs) == 0 -- it delivers to everyone
|
||||
func (ws *websocket) Send(kind string, payload interface{}, userIDs ...uint64) error {
|
||||
pb, err := Response(kind, payload).Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendsToAll := len(userIDs) == 0
|
||||
userIDMap := make(map[uint64]bool)
|
||||
for _, userID := range userIDs {
|
||||
userIDMap[userID] = true
|
||||
}
|
||||
|
||||
for uid, uSessions := range sessions {
|
||||
if sendsToAll || (!sendsToAll && userIDMap[uid]) {
|
||||
for _, sess := range uSessions {
|
||||
_ = sess.sendBytes(pb)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,127 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cortezaproject/corteza-server/pkg/options"
|
||||
gWebsocket "github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// WebsocketServer provide websocket server for testing
|
||||
func WebsocketTestServer(t *testing.T) (*httptest.Server, *gWebsocket.Conn) {
|
||||
// Create test server with the websocket handler.
|
||||
s := httptest.NewServer(http.HandlerFunc(wsOpen))
|
||||
|
||||
// Convert http://.. to ws://..
|
||||
u := "ws" + strings.TrimPrefix(s.URL, "http")
|
||||
|
||||
// Connect to the server
|
||||
conn, _, err := gWebsocket.DefaultDialer.Dial(u, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("WebsocketServer() error while creating connection = %v", err)
|
||||
}
|
||||
|
||||
return s, conn
|
||||
}
|
||||
|
||||
// wsOpen opens websocket connection
|
||||
func wsOpen(w http.ResponseWriter, r *http.Request) {
|
||||
var gUpgrader = gWebsocket.Upgrader{}
|
||||
c, err := gUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func(c *gWebsocket.Conn) {
|
||||
_ = c.Close()
|
||||
}(c)
|
||||
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendingMessageToUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
kind string
|
||||
payload interface{}
|
||||
expectedP string
|
||||
}{
|
||||
{
|
||||
name: "send json",
|
||||
kind: "Json",
|
||||
payload: struct {
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
}{Title: "Websocket", Description: "Testing connection.."},
|
||||
expectedP: `{"status":"ok","data":{"title":"Websocket","description":"Testing connection.."}}`,
|
||||
},
|
||||
{
|
||||
name: "send text",
|
||||
kind: "Text",
|
||||
payload: "testing connectivity",
|
||||
expectedP: "testing connectivity",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var logger *zap.Logger
|
||||
var config options.WebsocketOpt
|
||||
ws := Websocket(logger, config)
|
||||
s, conn := WebsocketTestServer(t)
|
||||
defer s.Close()
|
||||
defer func(ws *gWebsocket.Conn) {
|
||||
err := ws.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("TestSendingMessageToUser() error closing connection = %v", err)
|
||||
}
|
||||
}(conn)
|
||||
|
||||
// Open a session using ws connection
|
||||
wsSession := Session(context.Background(), ws.logger, ws.config, conn)
|
||||
|
||||
var messageType int
|
||||
var data []byte
|
||||
switch tt.kind {
|
||||
case "Text":
|
||||
messageType = gWebsocket.TextMessage
|
||||
data = []byte(fmt.Sprintf("%v", tt.payload))
|
||||
case "Json":
|
||||
res := Response("ok", tt.payload)
|
||||
messageType = gWebsocket.BinaryMessage
|
||||
var err error
|
||||
data, err = res.Marshal()
|
||||
if err != nil {
|
||||
t.Fatalf("TestSendingMessageToUser() error while marshaling payload = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Send message to server, read response and check to see if it's what we expect.
|
||||
if err := wsSession.conn.WriteMessage(messageType, data); err != nil {
|
||||
t.Fatalf("TestSendingMessageToUser() error while sending message = %v", err)
|
||||
}
|
||||
|
||||
_, p, err := wsSession.conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("TestSendingMessageToUser() error while reading message =%v", err)
|
||||
}
|
||||
|
||||
if string(p) != tt.expectedP {
|
||||
t.Fatalf("TestSendingMessageToUser() gotP = %v, want = %v", string(p), tt.expectedP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user