diff --git a/system/internal/repository/credentials.go b/system/internal/repository/credentials.go index 68de68dfd..9a5d804a0 100644 --- a/system/internal/repository/credentials.go +++ b/system/internal/repository/credentials.go @@ -100,7 +100,7 @@ func (r *credentials) Create(c *types.Credentials) (*types.Credentials, error) { func (r *credentials) Update(c *types.Credentials) (*types.Credentials, error) { updatedAt := time.Now() c.UpdatedAt = &updatedAt - return c, r.db().Update(r.tblname, c) + return c, r.db().Replace(r.tblname, c) } func (r *credentials) DeleteByID(id uint64) error { diff --git a/system/internal/service/auth.go b/system/internal/service/auth.go index ceaff9485..3eddf4dda 100644 --- a/system/internal/service/auth.go +++ b/system/internal/service/auth.go @@ -112,6 +112,19 @@ func (svc *auth) External(profile goth.User) (u *types.User, err error) { c.ID) } else if u.Valid() { // Valid user, matching emails. Bingo! + c.LastUsedAt = svc.now() + if c, err = svc.credentials.Update(c); err != nil { + return err + } + + log.Printf( + "updating credential entry (%v, %v) for exisintg user (%v, %v)", + c.ID, + profile.Provider, + u.ID, + u.Email, + ) + return nil } else { // Scenario: linked to an invalid user @@ -161,50 +174,30 @@ func (svc *auth) External(profile goth.User) (u *types.User, err error) { ) } - if c == nil { - c = &types.Credentials{ - Kind: profile.Provider, - OwnerID: u.ID, - Credentials: profile.UserID, - LastUsedAt: svc.now(), - } - - if !profile.ExpiresAt.IsZero() { - // Copy expiration date when provided - c.ExpiresAt = &profile.ExpiresAt - } - - if c, err = svc.credentials.Create(c); err != nil { - return err - } - - log.Printf( - "creating new credential entry (%v, %v) for exisintg user (%v, %v)", - c.ID, - profile.Provider, - u.ID, - u.Email, - ) - } else { - if !profile.ExpiresAt.IsZero() { - // Copy expiration date when provided - c.ExpiresAt = &profile.ExpiresAt - } - - c.LastUsedAt = svc.now() - if c, err = svc.credentials.Update(c); err != nil { - return err - } - - log.Printf( - "updating credential entry (%v, %v) for exisintg user (%v, %v)", - c.ID, - profile.Provider, - u.ID, - u.Email, - ) + c = &types.Credentials{ + Kind: profile.Provider, + OwnerID: u.ID, + Credentials: profile.UserID, + LastUsedAt: svc.now(), } + if !profile.ExpiresAt.IsZero() { + // Copy expiration date when provided + c.ExpiresAt = &profile.ExpiresAt + } + + if c, err = svc.credentials.Create(c); err != nil { + return err + } + + log.Printf( + "creating new credential entry (%v, %v) for exisintg user (%v, %v)", + c.ID, + profile.Provider, + u.ID, + u.Email, + ) + // Owner loaded, carry on. return nil }) diff --git a/system/internal/service/auth_test.go b/system/internal/service/auth_test.go index 2ae1d5452..c9eab757e 100644 --- a/system/internal/service/auth_test.go +++ b/system/internal/service/auth_test.go @@ -56,6 +56,11 @@ func TestAuth_External_Existing(t *testing.T) { Times(1). Return(types.CredentialsSet{c}, nil) + crdRpoMock.EXPECT(). + Update(gomock.Any()). + Times(1). + Return(c, nil) + usrRpoMock := repomock.NewMockUserRepository(mockCtrl) usrRpoMock.EXPECT().FindByID(u.ID).Times(1).Return(u, nil)