diff --git a/pkg/auth/identity.go b/pkg/auth/identity.go index 4d403a762..4869b72be 100644 --- a/pkg/auth/identity.go +++ b/pkg/auth/identity.go @@ -1,5 +1,9 @@ package auth +import ( + "fmt" +) + type ( Identity struct { id uint64 @@ -30,6 +34,10 @@ func (i Identity) Valid() bool { return i.id > 0 } +func (i Identity) String() string { + return fmt.Sprintf("%d", i.id) +} + func NewSuperUserIdentity() *Identity { return NewIdentity(superUserID) } diff --git a/pkg/auth/interfaces.go b/pkg/auth/interfaces.go index b9e779c58..52ecfb169 100644 --- a/pkg/auth/interfaces.go +++ b/pkg/auth/interfaces.go @@ -9,6 +9,7 @@ type ( Identity() uint64 Roles() []uint64 Valid() bool + String() string } TokenEncoder interface { diff --git a/system/service/user.go b/system/service/user.go index a1147c239..d7ad4cd09 100644 --- a/system/service/user.go +++ b/system/service/user.go @@ -3,6 +3,8 @@ package service import ( "context" "io" + "strconv" + "strings" "github.com/pkg/errors" "github.com/titpetric/factory" @@ -79,6 +81,7 @@ type ( FindByEmail(email string) (*types.User, error) FindByHandle(handle string) (*types.User, error) FindByID(id uint64) (*types.User, error) + FindByAny(any string) (*types.User, error) Find(types.UserFilter) (types.UserSet, types.UserFilter, error) Create(input *types.User) (*types.User, error) @@ -154,6 +157,20 @@ func (svc user) FindByHandle(handle string) (*types.User, error) { return svc.proc(svc.user.FindByHandle(handle)) } +func (svc user) FindByAny(any string) (*types.User, error) { + return svc.proc(func() (*types.User, error) { + if id, _ := strconv.ParseUint(any, 10, 64); id > 0 { + return svc.user.FindByID(id) + } + + if strings.Contains(any, "@") { + return svc.user.FindByEmail(any) + } + + return svc.user.FindByHandle(any) + }()) +} + func (svc user) proc(u *types.User, err error) (*types.User, error) { if err != nil { return nil, err diff --git a/system/types/user.go b/system/types/user.go index ed5e4ce7d..38cac4f15 100644 --- a/system/types/user.go +++ b/system/types/user.go @@ -3,6 +3,7 @@ package types import ( "database/sql/driver" "encoding/json" + "fmt" "time" "github.com/pkg/errors" @@ -89,6 +90,10 @@ const ( BotUser UserKind = "bot" ) +func (u User) String() string { + return fmt.Sprintf("%d", u.ID) +} + func (u *User) Valid() bool { return u.ID > 0 && u.SuspendedAt == nil && u.DeletedAt == nil }