3
0

250 lines
4.8 KiB
Go

package websocket
import (
"context"
"log"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/crusttech/crust/internal/auth"
"github.com/crusttech/crust/internal/payload"
"github.com/crusttech/crust/internal/payload/outgoing"
"github.com/crusttech/crust/messaging/internal/repository"
"github.com/crusttech/crust/messaging/internal/service"
"github.com/crusttech/crust/messaging/types"
)
type (
	// Session is the per-connection state for one websocket client: the
	// connection itself, its lifecycle context, its channel subscriptions and
	// the outgoing message queues consumed by writeLoop.
	Session struct {
		// Key used with store (Close removes the session via store.Delete(id)).
		id uint64
		// Ensures Close's teardown body runs only once.
		once sync.Once
		// Underlying websocket connection; disconnected closes it and sets it to nil.
		conn *websocket.Conn
		// Session-scoped context derived from the parent in New.
		ctx context.Context
		// Cancels ctx; called by disconnected.
		ctxCancel context.CancelFunc
		// Channel subscriptions; connected adds each channel ID formatted with payload.Uint64toa.
		subs *Subscriptions
		// Outgoing message queue drained by writeLoop (New buffers it for 512 messages).
		send chan []byte
		// Final message for writeLoop: it writes it best-effort and then returns.
		stop chan []byte
		// Peer address, recorded by readLoop.
		remoteAddr string
		// Provides Websocket.PingTimeout, PingPeriod and Timeout used by the read/write loops.
		config *repository.Flags
		// Authenticated identity. Not set in New — presumably assigned by the caller before Handle (verify).
		user auth.Identifiable
		// Service dependencies; New sets them to service.DefaultChannel and service.DefaultMessage.
		svc struct {
			ch  service.ChannelService
			msg service.MessageService
		}
	}
)
// New builds a Session around conn. The session gets its own cancellable
// context derived from ctx, an empty subscription set, a send queue buffered
// for 512 messages, a single-slot stop queue, and the package-default channel
// and message services.
//
// NOTE(review): the (unused) value receiver copies a Session, which contains a
// sync.Once; go vet's copylocks check may flag this. It is kept because
// changing it would break callers of the form Session{}.New(...).
func (Session) New(ctx context.Context, config *repository.Flags, conn *websocket.Conn) *Session {
	sess := new(Session)
	sess.conn = conn
	sess.config = config
	sess.subs = NewSubscriptions()
	sess.send = make(chan []byte, 512)
	sess.stop = make(chan []byte, 1)

	sess.ctx, sess.ctxCancel = context.WithCancel(ctx)
	sess.svc.ch, sess.svc.msg = service.DefaultChannel, service.DefaultMessage

	return sess
}
// Context returns the session-scoped context created in New. It is cancelled
// when the session disconnects or when the parent context ends.
func (sess *Session) Context() context.Context {
	ctx := sess.ctx
	return ctx
}
// connected performs the post-handshake setup for a session:
//
//  1. subscribes the session to every channel returned by the channel service,
//  2. broadcasts a "connected" presence event,
//  3. starts a presence heartbeat that fires once a minute until the session
//     context is cancelled.
//
// A failure to list channels is logged and not treated as fatal; the session
// then simply starts without channel subscriptions. An error from walking the
// channels or from the "connected" broadcast is returned to the caller.
func (sess *Session) connected() (err error) {
	var cc types.ChannelSet

	// Push user info about all channels they have access to.
	if cc, err = sess.svc.ch.With(sess.ctx).Find(&types.ChannelFilter{}); err != nil {
		// Best effort: continue without subscriptions.
		log.Printf("Error: %v", err)
	} else {
		log.Printf("Subscribing %d to %d channels", sess.user.Identity(), len(cc))
		err = cc.Walk(func(c *types.Channel) error {
			// Subscribe this user/session to the channel, keyed by its ID string.
			sess.subs.Add(payload.Uint64toa(c.ID))
			return nil
		})
		if err != nil {
			return err
		}
	}

	// Tell everyone that user has connected.
	if err = sess.sendPresence("connected"); err != nil {
		return err
	}

	// Heartbeat: re-broadcast presence (empty kind) every minute for this user
	// until the session context is cancelled.
	go func() {
		ticker := time.NewTicker(time.Minute)
		// Release the ticker when the session ends; an unstopped ticker keeps
		// its resources alive after this goroutine exits.
		defer ticker.Stop()

		for {
			select {
			case <-sess.ctx.Done():
				return
			case <-ticker.C:
				if hbErr := sess.sendPresence(""); hbErr != nil {
					log.Printf("Error: %v", hbErr)
				}
			}
		}
	}()

	return nil
}
// disconnected performs teardown for a session that is going away. It
// broadcasts a "disconnected" presence event, then cancels the session
// context (which stops the heartbeat goroutine started in connected), then
// closes and clears the underlying connection. The order matters: presence
// is sent while the session is still set up.
//
// It dereferences sess.conn without a nil check, so it must run at most once
// per session; Close invokes it under sync.Once.
//
// NOTE(review): sess.conn is set to nil without synchronization while
// readLoop/writeLoop may still read the field from other goroutines; this is
// a data race under the Go memory model and should be guarded (e.g. a mutex).
func (sess *Session) disconnected() {
	// Tell everyone that user has disconnected (best effort; error ignored)
	_ = sess.sendPresence("disconnected")
	// Cancel context
	sess.ctxCancel()
	// Close connection; the returned error is not checked
	sess.conn.Close()
	sess.conn = nil
}
// sendPresence broadcasts the user's presence to all subscribers via
// sendToAll.
//
// kind is one of "connected", "disconnected" or "" (periodic heartbeat).
// Present is true while the user still has at least one open connection. For
// "disconnected" the current session is subtracted from the count, because
// Close calls disconnected before removing the session from the store.
func (sess *Session) sendPresence(kind string) error {
	active := store.CountConnections(sess.user.Identity())
	if kind == "disconnected" {
		// This session is on its way out; don't count it.
		active--
	}

	activity := &outgoing.Activity{
		UserID:  sess.user.Identity(),
		Kind:    kind,
		Present: active > 0,
	}
	return sess.sendToAll(activity)
}
// Handle drives the session for its whole lifetime. It runs the connect-time
// setup, starts the reader on its own goroutine and then runs the writer on
// the calling goroutine, returning the writer's error. If setup fails, the
// session is closed and the setup error is returned.
func (sess *Session) Handle() error {
	err := sess.connected()
	if err != nil {
		sess.Close()
		return err
	}

	go sess.readLoop()
	return sess.writeLoop()
}
// Close tears the session down exactly once: it runs disconnected (presence
// broadcast, context cancellation, connection close) and then removes the
// session from the store. Later calls, from any goroutine, are no-ops.
func (sess *Session) Close() {
	teardown := func() {
		sess.disconnected()
		store.Delete(sess.id)
	}
	sess.once.Do(teardown)
}
// readLoop reads frames from the websocket connection and dispatches each one
// until a read fails. It always closes the session on exit.
//
// The read deadline starts at PingTimeout and is pushed forward by
// PingTimeout on every pong, so a peer that stops answering pings makes
// ReadMessage fail and ends the loop. Dispatch errors are logged and reported
// back to the client as an error reply, but do not stop the loop.
func (sess *Session) readLoop() (err error) {
	defer sess.Close()

	// Snapshot the connection once. disconnected() may set sess.conn to nil
	// from the writer goroutine, and the pong handler runs during ReadMessage,
	// so re-reading the shared field could dereference nil.
	conn := sess.conn
	if conn == nil {
		// Session already closed; nothing to read.
		return nil
	}

	if err = conn.SetReadDeadline(time.Now().Add(sess.config.Websocket.PingTimeout)); err != nil {
		return errors.Wrap(err, "sess.readLoop")
	}
	conn.SetPongHandler(func(string) error {
		return conn.SetReadDeadline(time.Now().Add(sess.config.Websocket.PingTimeout))
	})

	sess.remoteAddr = conn.RemoteAddr().String()

	for {
		_, raw, readErr := conn.ReadMessage()
		if readErr != nil {
			return errors.Wrap(readErr, "sess.readLoop")
		}

		if dispatchErr := sess.dispatch(raw); dispatchErr != nil {
			log.Printf("Error: %v", dispatchErr)
			sess.sendReply(outgoing.NewError(dispatchErr))
		}
	}
}
// writeLoop writes outgoing frames to the websocket connection. It forwards
// queued messages from sess.send as text frames and sends a ping frame every
// PingPeriod, applying the configured write Timeout as a deadline to each
// write.
//
// It returns nil in three cases:
//   - sess.send is closed,
//   - a stop message has been handled,
//   - the session context is cancelled (e.g. readLoop closed the session).
//
// It returns a wrapped error when a send or ping fails. On exit it stops the
// ping ticker and closes the session, which unblocks readLoop.
func (sess *Session) writeLoop() error {
	ticker := time.NewTicker(sess.config.Websocket.PingPeriod)
	defer func() {
		ticker.Stop()
		sess.Close() // break readLoop
	}()

	// Both helpers read sess.conn exactly once. disconnected() can nil the
	// field from the reader goroutine, so checking it and then dereferencing
	// it again could panic in between. A nil connection means the session is
	// already closed and there is nowhere to write.
	write := func(msg []byte) error {
		conn := sess.conn
		if conn == nil {
			return nil
		}
		if err := conn.SetWriteDeadline(time.Now().Add(sess.config.Websocket.Timeout)); err != nil {
			return err
		}
		if msg == nil {
			// Nothing to send; only the deadline is refreshed.
			return nil
		}
		return conn.WriteMessage(websocket.TextMessage, msg)
	}
	ping := func() error {
		conn := sess.conn
		if conn == nil {
			return nil
		}
		if err := conn.SetWriteDeadline(time.Now().Add(sess.config.Websocket.Timeout)); err != nil {
			return err
		}
		return conn.WriteMessage(websocket.PingMessage, nil)
	}

	for {
		select {
		case <-sess.ctx.Done():
			// The session was closed elsewhere (disconnected cancels ctx) or the
			// parent context ended. Without this case the loop would run forever
			// after readLoop closes the session, because write/ping treat a nil
			// connection as success.
			return nil
		case msg, ok := <-sess.send:
			if !ok {
				// channel closed
				return nil
			}
			if err := write(msg); err != nil {
				return errors.Wrap(err, "writeLoop send")
			}
		case msg := <-sess.stop:
			// Shutdown requested, don't care if the message is delivered
			_ = write(msg)
			return nil
		case <-ticker.C:
			if err := ping(); err != nil {
				return errors.Wrap(err, "writeLoop ping")
			}
		}
	}
}