diff --git a/compose/internal/repository/record.go b/compose/internal/repository/record.go index b80315e58..8d88966eb 100644 --- a/compose/internal/repository/record.go +++ b/compose/internal/repository/record.go @@ -28,7 +28,7 @@ type ( Update(record *types.Record) (*types.Record, error) Delete(record *types.Record) error - LoadValues(IDs ...uint64) (rvs types.RecordValueSet, err error) + LoadValues(fieldNames []string, IDs []uint64) (rvs types.RecordValueSet, err error) DeleteValues(record *types.Record) error UpdateValues(recordID uint64, rvs types.RecordValueSet) (err error) } @@ -294,7 +294,7 @@ func (r record) UpdateValues(recordID uint64, rvs types.RecordValueSet) (err err } -func (r record) LoadValues(IDs ...uint64) (rvs types.RecordValueSet, err error) { +func (r record) LoadValues(fieldNames []string, IDs []uint64) (rvs types.RecordValueSet, err error) { if len(IDs) == 0 { return } @@ -302,10 +302,11 @@ func (r record) LoadValues(IDs ...uint64) (rvs types.RecordValueSet, err error) var sql = "SELECT record_id, name, value, ref, place, deleted_at " + " FROM compose_record_value " + " WHERE record_id IN (?) " + + " AND name IN (?) " + " AND deleted_at IS NULL " + " ORDER BY record_id, place" - if sql, args, err := sqlx.In(sql, IDs); err != nil { + if sql, args, err := sqlx.In(sql, IDs, fieldNames); err != nil { return nil, err } else { return rvs, r.db().Select(&rvs, sql, args...) diff --git a/compose/internal/service/record.go b/compose/internal/service/record.go index 8635c9e5c..8d926e12d 100644 --- a/compose/internal/service/record.go +++ b/compose/internal/service/record.go @@ -34,6 +34,7 @@ type ( CanReadRecord(context.Context, *types.Module) bool CanUpdateRecord(context.Context, *types.Module) bool CanDeleteRecord(context.Context, *types.Module) bool + CanReadRecordValue(context.Context, *types.ModuleField) bool } RecordService interface { @@ -97,7 +98,7 @@ func (svc record) FindByID(namespaceID, recordID uint64) (r *types.Record, err e return nil, ErrNoReadPermissions.withStack() } - if err = svc.preloadValues(r); err != nil { + if err = svc.preloadValues(m, r); err != nil { return } @@ -140,7 +141,7 @@ func (svc record) Find(filter types.RecordFilter) (set types.RecordSet, f types. return } - if err = svc.preloadValues(set...); err != nil { + if err = svc.preloadValues(m, set...); err != nil { return } @@ -178,7 +179,7 @@ func (svc record) Create(mod *types.Record) (r *types.Record, err error) { return } - if err = svc.preloadValues(r); err != nil { + if err = svc.preloadValues(m, r); err != nil { return } @@ -294,8 +295,8 @@ func (svc record) sanitizeValues(module *types.Module, values types.RecordValueS }) } -func (svc record) preloadValues(rr ...*types.Record) error { - if rvs, err := svc.recordRepo.LoadValues(types.RecordSet(rr).IDs()...); err != nil { +func (svc record) preloadValues(m *types.Module, rr ...*types.Record) error { + if rvs, err := svc.recordRepo.LoadValues(svc.readableFields(m), types.RecordSet(rr).IDs()); err != nil { return err } else { return types.RecordSet(rr).Walk(func(r *types.Record) error { @@ -304,3 +305,18 @@ func (svc record) preloadValues(rr ...*types.Record) error { }) } } + +// readableFields creates a slice of module fields that current user has permission to read +func (svc record) readableFields(m *types.Module) []string { + ff := make([]string, 0) + + _ = m.Fields.Walk(func(f *types.ModuleField) error { + if svc.ac.CanReadRecordValue(svc.ctx, f) { + ff = append(ff, f.Name) + } + + return nil + }) + + return ff +}