3
0

Replace r.fetchSet with rh.FetchAll

This commit is contained in:
Denis Arh 2019-10-19 16:58:37 +02:00
parent fffeb84abb
commit 93ceb60a42
2 changed files with 11 additions and 29 deletions

View File

@ -3,7 +3,6 @@ package repository
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/titpetric/factory"
"github.com/cortezaproject/corteza-server/pkg/auth"
@ -46,21 +45,3 @@ func (r *repository) db() *factory.DB {
}
return DB(r.ctx)
}
// Fetches single row from table
func (r repository) fetchSet(set interface{}, q squirrel.SelectBuilder) (err error) {
var (
sql string
args []interface{}
)
if sql, args, err = q.ToSql(); err != nil {
return
}
if err = r.db().Select(set, sql, args...); err != nil {
return
}
return
}

View File

@ -7,6 +7,7 @@ import (
"github.com/titpetric/factory"
"github.com/cortezaproject/corteza-server/messaging/types"
"github.com/cortezaproject/corteza-server/pkg/rh"
)
type (
@ -73,14 +74,14 @@ func (r *unread) Count(userID, channelID uint64, threadIDs ...uint64) (types.Unr
var (
uu = types.UnreadSet{}
q = squirrel.
Select().
From(r.table()).
Columns(
Select(
"rel_channel",
"rel_last_message",
"rel_user",
"rel_reply_to",
"count")
"count",
).
From(r.table())
)
if userID > 0 {
@ -97,7 +98,7 @@ func (r *unread) Count(userID, channelID uint64, threadIDs ...uint64) (types.Unr
q = q.Where(squirrel.Eq{"rel_reply_to": threadIDs})
}
return uu, r.fetchSet(&uu, q)
return uu, rh.FetchAll(r.db(), q, &uu)
}
// CountReplies counts unread thread info
@ -116,13 +117,13 @@ func (r unread) CountThreads(userID, channelID uint64) (types.UnreadSet, error)
temp = []*u{}
q = squirrel.
Select().
From(r.table()).
Columns(
Select(
"rel_channel",
"rel_user",
"sum(count) AS count",
"sum(CASE WHEN count > 0 THEN 1 ELSE 0 END) AS total").
"sum(CASE WHEN count > 0 THEN 1 ELSE 0 END) AS total",
).
From(r.table()).
Where("rel_reply_to > 0 AND count > 0").
GroupBy("rel_channel", "rel_user")
)
@ -135,7 +136,7 @@ func (r unread) CountThreads(userID, channelID uint64) (types.UnreadSet, error)
q = q.Where("rel_channel = ?", channelID)
}
err = r.fetchSet(&temp, q)
err = rh.FetchAll(r.db(), q, &temp)
if err != nil {
return nil, err
}