From 8265d4f3c30fe38420bf897df909122e5cb111a7 Mon Sep 17 00:00:00 2001 From: Tit Petric Date: Tue, 21 Aug 2018 13:04:14 +0200 Subject: [PATCH] add(sam): import pubsub service --- sam/service/pubsub.go | 121 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 sam/service/pubsub.go diff --git a/sam/service/pubsub.go b/sam/service/pubsub.go new file mode 100644 index 000000000..53b767e13 --- /dev/null +++ b/sam/service/pubsub.go @@ -0,0 +1,121 @@ +package service + +import ( + "context" + "github.com/gomodule/redigo/redis" + "github.com/pkg/errors" + "time" +) + +type PubSub struct { + ctx context.Context + + redisServerAddr string + + healthCheckInterval time.Duration +} + +func (ps PubSub) New(redisServerAddr string, ctx context.Context) *PubSub { + return &PubSub{ + ctx: ctx, + redisServerAddr: redisServerAddr, + healthCheckInterval: time.Minute, + } +} + +func (ps *PubSub) With(ctx context.Context) *PubSub { + return &PubSub{ + ctx: ctx, + healthCheckInterval: ps.healthCheckInterval, + } +} + +func (ps *PubSub) dial() (redis.Conn, error) { + readTimeout := redis.DialReadTimeout(ps.healthCheckInterval + 10*time.Second) + writeTimeout := redis.DialWriteTimeout(10 * time.Second) + return redis.Dial("tcp", ps.redisServerAddr, readTimeout, writeTimeout) +} + +func (ps *PubSub) Subscribe(onStart func() error, onMessage func(channel string, payload []byte) error, channels ...string) error { + if len(channels) == 0 { + return errors.New("Need to subscribe at least to one channel") + } + + // main redis connection + conn, err := ps.dial() + if err != nil { + return err + } + defer conn.Close() + + // pubsub object + psc := redis.PubSubConn{Conn: conn} + if err := psc.Subscribe(redis.Args{}.AddFlat(channels)...); err != nil { + return err + } + + done := make(chan error, 1) + + // Start a goroutine to receive notifications from the server. + go func() { + for { + switch n := psc.Receive().(type) { + case error: + done <- n + return + case redis.Message: + if err := onMessage(n.Channel, n.Data); err != nil { + done <- err + return + } + case redis.Subscription: + switch n.Count { + case len(channels): + // Notify application when all channels are subscribed. + if err := onStart(); err != nil { + done <- err + return + } + case 0: + // Return from the goroutine when all channels are unsubscribed. + done <- nil + return + } + } + } + }() + + ticker := time.NewTicker(ps.healthCheckInterval) + defer ticker.Stop() + + cleanup := func(err error) error { + psc.Unsubscribe() + return err + } + + for { + select { + case <-ticker.C: + if err := psc.Ping(""); err != nil { + return cleanup(err) + } + case <-ps.ctx.Done(): + return cleanup(ps.ctx.Err()) + case err := <-done: + return err + } + } +} + +func (ps *PubSub) Publish(channel string, payload string) error { + // main redis connection + conn, err := ps.dial() + if err != nil { + return err + } + defer conn.Close() + + // publish payload on channel + _, err = conn.Do("PUBLISH", channel, payload) + return err +}