// corteza/server/pkg/wfexec/session.go

package wfexec
import (
	"context"
	"fmt"
	"runtime/debug"
	"sync"
	"time"

	"github.com/cortezaproject/corteza/server/pkg/auth"
	"github.com/cortezaproject/corteza/server/pkg/expr"
	"github.com/cortezaproject/corteza/server/pkg/id"
	"github.com/cortezaproject/corteza/server/pkg/logger"
	"go.uber.org/zap"
)
type (
Session struct {
// Session identifier
id uint64
workflowID uint64
// steps graph
g *Graph
started time.Time
// state channel (i.e. the work queue)
qState chan *State
// crash channel
qErr chan error
// semaphore that limits concurrent executions
execLock chan struct{}
// delayed states (waiting for the right time)
delayed map[uint64]*delayed
// prompted states (waiting for user input)
prompted map[uint64]*prompted
// how often we check for delayed states and how often the idle state is checked in Wait()
workerIntervalSuspended time.Duration
workerIntervalWaiter time.Duration
// only one worker routine per session
workerLock chan struct{}
// holds final result
result *expr.Vars
err error
mux sync.RWMutex
// debug logger
log *zap.Logger
dumpStacktraceOnPanic bool
eventHandler StateChangeHandler
// This keeps track of workflow calls
callStack []uint64
}
StateChangeHandler func(SessionStatus, *State, *Session) error
SessionOpt func(*Session)
Frame struct {
CreatedAt time.Time `json:"createdAt"`
SessionID uint64 `json:"sessionID"`
StateID uint64 `json:"stateID"`
Input *expr.Vars `json:"input"`
Scope *expr.Vars `json:"scope"`
Results *expr.Vars `json:"results"`
ParentID uint64 `json:"parentID"`
StepID uint64 `json:"stepID"`
NextSteps []uint64 `json:"nextSteps"`
// Time elapsed from the first step to the start of this step, in milliseconds
ElapsedTime uint `json:"elapsedTime"`
// Time it took to execute this step, in milliseconds
StepTime uint `json:"stepTime"`
Action string `json:"action,omitempty"`
Error string `json:"error,omitempty"`
}
// ExecRequest is passed to Exec() functions and contains all information
// for step execution
ExecRequest struct {
SessionID uint64
StateID uint64
// Current input received on session resume
Input *expr.Vars
// Current scope
Scope *expr.Vars
// Helps with gateway join/merge steps
// that need info about the step they're currently merging
Parent Step
}
SessionStatus int
callStackCtxKey struct{}
)
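// illustrativeStep is a minimal Step implementation sketch, for documentation
// only (hypothetical; it assumes the Step interface declared elsewhere in this
// package requires ID, SetID and Exec). Returning *expr.Vars from Exec is the
// common success path: the session merges the result into the scope and
// continues with the step's children from the graph.
type illustrativeStep struct {
	id uint64
}

func (s *illustrativeStep) ID() uint64         { return s.id }
func (s *illustrativeStep) SetID(stepID uint64) { s.id = stepID }

func (s *illustrativeStep) Exec(_ context.Context, _ *ExecRequest) (ExecResponse, error) {
	// do the work and return results to be merged into the session scope
	return expr.NewVars(map[string]interface{}{"done": true})
}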
const (
sessionStateChanBuf = 512
sessionConcurrentExec = 32
)
const (
SessionActive SessionStatus = iota
SessionPrompted
SessionDelayed
SessionFailed
SessionCompleted
SessionCanceled
)
var (
// wrapper around id.Next() that will aid service testing
nextID = func() uint64 {
return id.Next()
}
// wrapper around time.Now() that will aid service testing
now = func() *time.Time {
c := time.Now()
return &c
}
errCanceled = fmt.Errorf("canceled")
)
func (s SessionStatus) String() string {
switch s {
case SessionActive:
return "active"
case SessionPrompted:
return "prompted"
case SessionDelayed:
return "delayed"
case SessionFailed:
return "failed"
case SessionCompleted:
return "completed"
case SessionCanceled:
return "aborted"
}
return "UNKNOWN-SESSION-STATUS"
}
func NewSession(ctx context.Context, g *Graph, oo ...SessionOpt) *Session {
s := &Session{
g: g,
id: nextID(),
started: *now(),
qState: make(chan *State, sessionStateChanBuf),
qErr: make(chan error, 1),
execLock: make(chan struct{}, sessionConcurrentExec),
delayed: make(map[uint64]*delayed),
prompted: make(map[uint64]*prompted),
		// set higher, since resuming suspended states depends on external interaction
		workerIntervalSuspended: time.Millisecond * 100,
		// keep small, since Wait() polls the session status in a tight loop
		workerIntervalWaiter: time.Millisecond,
workerLock: make(chan struct{}, 1),
log: zap.NewNop(),
eventHandler: func(SessionStatus, *State, *Session) error {
// noop
return nil
},
}
for _, o := range oo {
o(s)
}
s.log = s.log.
With(logger.Uint64("sessionID", s.id))
s.callStack = append(s.callStack, s.id)
go s.worker(ctx)
return s
}
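// exampleRunSession is an illustrative sketch (not used by the package) of
// typical session usage: build a session over a prepared graph, execute its
// entry step (one without parents) and wait for an idle or terminal status.
// The graph, entry step and logger are assumed to be prepared by the caller.
func exampleRunSession(ctx context.Context, g *Graph, entry Step, log *zap.Logger) (*expr.Vars, error) {
	ses := NewSession(ctx, g,
		SetWorkflowID(1),
		SetLogger(log),
	)
	// the initial scope may be nil; Exec substitutes an empty one
	if err := ses.Exec(ctx, entry, nil); err != nil {
		return nil, err
	}
	// Wait returns once the session fails, gets delayed or completes
	if err := ses.Wait(ctx); err != nil {
		return nil, err
	}
	return ses.Result(), nil
}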
func (s *Session) Status() SessionStatus {
s.mux.RLock()
defer s.mux.RUnlock()
switch {
case s.err != nil:
if s.err == errCanceled {
return SessionCanceled
}
return SessionFailed
case len(s.prompted) > 0:
return SessionPrompted
case len(s.delayed) > 0:
return SessionDelayed
case s.result == nil:
return SessionActive
default:
return SessionCompleted
}
}
func (s *Session) ID() uint64 {
return s.id
}
func (s *Session) Idle() bool {
return s.Status() != SessionActive
}
func (s *Session) Error() error {
s.mux.RLock()
defer s.mux.RUnlock()
return s.err
}
func (s *Session) Result() *expr.Vars {
s.mux.RLock()
defer s.mux.RUnlock()
return s.result
}
func (s *Session) Exec(ctx context.Context, step Step, scope *expr.Vars) error {
s.mux.RLock()
defer s.mux.RUnlock()
err := func() error {
if s.g.Len() == 0 {
return fmt.Errorf("refusing to execute without steps")
}
if len(s.g.Parents(step)) > 0 {
return fmt.Errorf("cannot execute step with parents")
}
return nil
}()
if err != nil {
		// push the error to the error queue to stop the worker;
		// the worker records it as the session error, which updates the session status
s.qErr <- err
return err
}
if scope == nil {
scope, _ = expr.NewVars(nil)
}
return s.enqueue(ctx, NewState(s, auth.GetIdentityFromContext(ctx), nil, step, scope))
}
// UserPendingPrompts returns all of the owner's pending prompts on this session
func (s *Session) UserPendingPrompts(ownerId uint64) (out []*PendingPrompt) {
if ownerId == 0 {
return
}
defer s.mux.RUnlock()
s.mux.RLock()
out = make([]*PendingPrompt, 0, len(s.prompted))
for _, p := range s.prompted {
if p.ownerId != ownerId {
continue
}
pending := p.toPending()
pending.SessionID = s.id
out = append(out, pending)
}
return
}
// AllPendingPrompts returns all pending prompts for all users
func (s *Session) AllPendingPrompts() (out []*PendingPrompt) {
defer s.mux.RUnlock()
s.mux.RLock()
return s.pendingPrompts(s.prompted)
}
// UnsentPendingPrompts returns unsent pending prompts for all users
func (s *Session) UnsentPendingPrompts() (out []*PendingPrompt) {
defer s.mux.RUnlock()
s.mux.RLock()
aux := s.pendingPrompts(s.prompted)
for _, p := range aux {
if p.Original.sent {
continue
}
out = append(out, p)
}
return
}
func (s *Session) pendingPrompts(prompted map[uint64]*prompted) (out []*PendingPrompt) {
out = make([]*PendingPrompt, 0, len(prompted))
for _, p := range prompted {
pending := p.toPending()
pending.SessionID = s.id
out = append(out, pending)
}
return
}
func (s *Session) Resume(ctx context.Context, stateId uint64, input *expr.Vars) (*ResumedPrompt, error) {
defer s.mux.Unlock()
s.mux.Lock()
var (
i = auth.GetIdentityFromContext(ctx)
p, has = s.prompted[stateId]
)
if !has {
return nil, fmt.Errorf("unexisting state")
}
if i == nil || p.ownerId != i.Identity() {
return nil, fmt.Errorf("state access denied")
}
delete(s.prompted, stateId)
// setting received input to state
p.state.input = input
if err := s.enqueue(ctx, p.state); err != nil {
return nil, err
}
return p.toResumed(), nil
}
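// exampleResumePrompts is an illustrative sketch of resuming prompted states
// (it assumes PendingPrompt exposes the prompted StateID, as Resume expects,
// and that ctx carries the prompt owner's identity). The input key "answer"
// is hypothetical and workflow-specific.
func exampleResumePrompts(ctx context.Context, ses *Session, ownerID uint64) error {
	for _, p := range ses.UserPendingPrompts(ownerID) {
		input, err := expr.NewVars(map[string]interface{}{"answer": true})
		if err != nil {
			return err
		}
		if _, err = ses.Resume(ctx, p.StateID, input); err != nil {
			return err
		}
	}
	return nil
}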
func (s *Session) canEnqueue(st *State) error {
if st == nil {
return fmt.Errorf("state is nil")
}
// a nil step is only valid on a state that is already completed (final states are special)
if st.step == nil && st.completed == nil {
return fmt.Errorf("state step is nil")
}
return nil
}
func (s *Session) enqueue(ctx context.Context, st *State) error {
if err := s.canEnqueue(st); err != nil {
return err
}
if st.stateId == 0 {
st.stateId = nextID()
}
select {
case <-ctx.Done():
return ctx.Err()
case s.qState <- st:
s.log.Debug("add step to queue")
return nil
}
}
// Wait does not wait for the whole workflow to complete but only until one of:
// - context timeout
// - idle state
// - error in error queue
func (s *Session) Wait(ctx context.Context) error {
return s.WaitUntil(ctx, SessionFailed, SessionDelayed, SessionCompleted)
}
// WaitUntil blocks until the workflow session reaches one of the expected statuses
func (s *Session) WaitUntil(ctx context.Context, expected ...SessionStatus) error {
indexed := make(map[SessionStatus]bool)
for _, status := range expected {
indexed[status] = true
}
// already at the expected status
if indexed[s.Status()] {
return s.err
}
s.log.Debug(
"waiting for status change",
zap.Any("expecting", expected),
zap.Duration("interval", s.workerIntervalWaiter),
)
waitCheck := time.NewTicker(s.workerIntervalWaiter)
defer waitCheck.Stop()
for {
select {
case <-waitCheck.C:
status := s.Status()
if indexed[status] {
s.log.Debug("waiting complete", zap.Stringer("status", status))
// nothing in the pipeline
return s.err
}
case <-ctx.Done():
s.log.Debug("wait context canceled", zap.Error(ctx.Err()))
s.Cancel()
return s.err
}
}
}
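// Illustrative: a caller that expects user interaction can wait for the prompt
// (or a terminal status) instead of waiting for full completion.
func exampleWaitForPrompt(ctx context.Context, ses *Session) error {
	return ses.WaitUntil(ctx, SessionPrompted, SessionFailed, SessionCompleted)
}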
func (s *Session) worker(ctx context.Context) {
defer s.Stop()
// making sure only one worker routine runs per session
defer close(s.workerLock)
s.workerLock <- struct{}{}
workerTicker := time.NewTicker(s.workerIntervalSuspended)
defer workerTicker.Stop()
for {
select {
case <-ctx.Done():
s.log.Debug("worker context done", zap.Error(ctx.Err()))
return
case <-workerTicker.C:
s.queueScheduledSuspended()
case st := <-s.qState:
s.log.Debug("pulled state from queue", logger.Uint64("stateID", st.stateId))
if st.step == nil {
// When there are any suspended steps we shouldn't kill the worker
// as those need to be processed.
if s.Suspended() {
break
}
s.log.Debug("done, setting results and stopping the worker")
// Make sure we're serving a non-nil value
s.mux.Lock()
if st.scope.IsEmpty() {
s.result = &expr.Vars{}
} else {
s.result = st.scope
}
s.mux.Unlock()
// Call event handler with completed status
err := s.eventHandler(SessionCompleted, st, s)
if err != nil {
err = fmt.Errorf(
"workflow %d state change handler failed: %w",
s.workflowID,
err,
)
s.log.Error(err.Error())
s.err = err
return
}
return
}
			// acquire a semaphore slot to bound the number of concurrent step executions;
			// this blocks once sessionConcurrentExec executions are already in flight
s.execLock <- struct{}{}
go func() {
defer func() {
				// release the semaphore slot acquired above so the
				// next queued step can start executing
<-s.execLock
}()
var (
err error
log = s.log.With(logger.Uint64("stateID", st.stateId))
)
nxt, err := s.exec(ctx, log, st)
if err != nil && st.err == nil {
// exec failed without setting a state error; record the returned error on the state
st.err = err
}
st.completed = now()
status := s.Status()
if st.err != nil {
st.err = fmt.Errorf(
"workflow %d step %d execution failed: %w",
s.workflowID,
st.step.ID(),
st.err,
)
s.mux.Lock()
// when an error handler is defined, the error was handled and should not kill the workflow
if !st.errHandled {
// We need to force failed session status
// because it's not set early enough to pick it up with s.Status()
status = SessionFailed
// pushing step execution error into error queue
// to break worker loop
s.qErr <- st.err
}
s.mux.Unlock()
}
s.log.Debug(
"executed",
logger.Uint64("stateID", st.stateId),
zap.Stringer("status", status),
zap.Error(st.err),
)
err = s.eventHandler(status, st, s)
if err != nil {
err = fmt.Errorf(
"workflow %d state change handler failed: %w",
s.workflowID,
err,
)
s.log.Error(err.Error())
s.err = err
return
}
for _, n := range nxt {
if n.step != nil {
log.Debug("next step queued", logger.Uint64("nextStepId", n.step.ID()))
} else {
log.Debug("next step queued", logger.Uint64("nextStepId", 0))
}
if err = s.enqueue(ctx, n); err != nil {
log.Error("unable to enqueue", zap.Error(err))
return
}
}
}()
case err := <-s.qErr:
s.mux.Lock()
defer s.mux.Unlock()
if err == nil {
// stop worker
return
}
// set final error on session
s.err = err
return
}
}
}
func (s *Session) Cancel() {
s.log.Warn("canceling")
s.qErr <- errCanceled
s.eventHandler(SessionCanceled, nil, s)
}
func (s *Session) Stop() {
s.log.Debug("stopping worker")
s.qErr <- nil
}
func (s *Session) Delayed() bool {
defer s.mux.RUnlock()
s.mux.RLock()
return len(s.delayed) > 0
}
func (s *Session) Prompted() bool {
defer s.mux.RUnlock()
s.mux.RLock()
return len(s.prompted) > 0
}
// Suspended returns true if the workflow has either delayed or prompted steps
func (s *Session) Suspended() bool {
return s.Delayed() || s.Prompted()
}
func (s *Session) queueScheduledSuspended() {
defer s.mux.Unlock()
s.mux.Lock()
for id, sus := range s.delayed {
if !sus.resumeAt.IsZero() && sus.resumeAt.After(*now()) {
continue
}
delete(s.delayed, id)
// Set state input when step is resumed
sus.state.input = &expr.Vars{}
sus.state.input.Set("resumed", true)
sus.state.input.Set("resumeAt", sus.resumeAt)
s.qState <- sus.state
}
}
// executes single step, resolves response and schedule following steps for execution
func (s *Session) exec(ctx context.Context, log *zap.Logger, st *State) (nxt []*State, err error) {
st.created = *now()
defer func() {
reason := recover()
if reason == nil {
return
}
var perr error
// normalize error and set it to state
switch reason := reason.(type) {
case error:
perr = fmt.Errorf("step %d crashed: %w", st.step.ID(), reason)
default:
perr = fmt.Errorf("step %d crashed: %v", st.step.ID(), reason)
}
if s.dumpStacktraceOnPanic {
fmt.Printf("Error: %v\n", perr)
println(string(debug.Stack()))
}
s.qErr <- perr
}()
var (
result ExecResponse
scope = (&expr.Vars{}).MustMerge(st.scope)
currLoop = st.loopCurr()
)
if st.step != nil {
log = log.With(logger.Uint64("stepID", st.step.ID()))
}
{
if currLoop != nil && currLoop.Is(st.step) {
result = currLoop
} else {
				// push logger to context but raise the stacktrace level to panic
				// to prevent overly verbose traces
				ctx = logger.ContextWithValue(ctx, log.WithOptions(zap.AddStacktrace(zap.PanicLevel)))
stepCtx := SetContextCallStack(ctx, s.callStack)
result, st.err = st.step.Exec(stepCtx, st.MakeRequest())
if iterator, isIterator := result.(Iterator); isIterator && st.err == nil {
// Exec fn returned an iterator, adding loop to stack
st.newLoop(iterator)
if err = iterator.Start(ctx, scope); err != nil {
return
}
}
}
if st.err != nil {
if st.errHandler == nil {
// no error handler set
return nil, st.err
}
		// handle the error with the error-handling step
		// set by one of the previous steps
log.Warn("step execution error handled",
logger.Uint64("errorHandlerStepId", st.errHandler.ID()),
zap.Error(st.err),
)
err = setErrorHandlerResultsToScope(scope, st.results, st.err, st.step.ID())
if err != nil {
return nil, err
}
// copy error handler & disable it on state to prevent inf. loop
// in case of another error in the error-handling branch
eh := st.errHandler
st.errHandler = nil
st.errHandled = true
return []*State{st.Next(eh, scope)}, nil
}
switch l := result.(type) {
case Iterator:
st.action = "iterator initialized"
			// the step returned an iterator; advance it to get
			// the first next step (nil means the loop is done)
			var n Step
n, result, st.err = l.Next(ctx, scope)
if st.err != nil {
return nil, st.err
}
if n == nil {
st.next = st.loopEnd()
} else {
st.next = Steps{n}
}
}
log.Debug("step executed", zap.String("resultType", fmt.Sprintf("%T", result)))
switch result := result.(type) {
case *expr.Vars:
// most common (successful) result
// session will continue with configured child steps
st.results = result
scope = scope.MustMerge(st.results)
case *errHandler:
st.action = "error handler initialized"
// this step sets error handling step on current state
// and continues on the current path
st.errHandler = result.handler
st.results = st.results.MustMerge(result.results)
			// find the first child step that's not the error handler
			// and use it as the next step
for _, c := range s.g.Children(st.step) {
if c != st.errHandler {
st.next = Steps{c}
break
}
}
case *loopBreak:
st.action = "loop break"
if currLoop == nil {
return nil, fmt.Errorf("break step not inside a loop")
}
// jump out of the loop
st.next = st.loopEnd()
log.Debug("breaking from iterator")
case *loopContinue:
st.action = "loop continue"
if currLoop == nil {
return nil, fmt.Errorf("continue step not inside a loop")
}
// jump back to iterator
st.next = Steps{currLoop.Iterator()}
log.Debug("continuing with next iteration")
case *partial:
st.action = "partial"
			// *partial is returned when a step needs to be executed again;
			// used mainly for the join gateway step, which is called multiple times (once for each parent path)
return
case *termination:
st.action = "termination"
// terminate all activities, all delayed tasks and exit right away
log.Debug("termination", zap.Int("delayed", len(s.delayed)))
s.mux.Lock()
s.delayed = nil
s.mux.Unlock()
return []*State{FinalState(s, scope)}, nil
case *delayed:
st.action = "delayed"
log.Debug("session delayed", zap.Time("at", result.resumeAt))
result.state = st
s.mux.Lock()
s.delayed[st.stateId] = result
s.mux.Unlock()
return
case *resumed:
st.action = "resumed"
log.Debug("session resumed")
case *prompted:
st.action = "prompted"
if result.ownerId == 0 {
return nil, fmt.Errorf("without an owner")
}
result.state = st
s.mux.Lock()
s.prompted[st.stateId] = result
s.mux.Unlock()
return
case Steps:
st.action = "next-steps"
// session continues with set of specified steps
// steps MUST be configured in a graph as step's children
st.next = result
case Step:
st.action = "next-step"
// session continues with a specified step
// step MUST be configured in a graph as step's child
st.next = Steps{result}
default:
return nil, fmt.Errorf("unknown exec response type %T", result)
}
}
if len(st.next) == 0 {
		// step's exec did not return next steps (only gateway steps, iterators and loop controls usually do that)
//
// rely on graph and get next (children) steps from there
st.next = s.g.Children(st.step)
} else {
// children returned from step's exec
// do a quick sanity check
cc := s.g.Children(st.step)
if len(cc) > 0 && !cc.Contains(st.next...) {
return nil, fmt.Errorf("inconsistent relationship")
}
}
if currLoop != nil && len(st.next) == 0 {
// gracefully handling last step of iteration branch
// that does not point back to the iterator step
st.next = Steps{currLoop.Iterator()}
log.Debug("last step in iteration branch, going back", logger.Uint64("backStepId", st.next[0].ID()))
}
if len(st.next) == 0 {
log.Debug("zero paths, finalizing")
// using state to transport results and complete the worker loop
return []*State{FinalState(s, scope)}, nil
}
nxt = make([]*State, len(st.next))
for i, step := range st.next {
nn := st.Next(step, scope)
if err = s.canEnqueue(nn); err != nil {
log.Error("unable to queue", zap.Error(err))
return
}
nxt[i] = nn
}
return nxt, nil
}
func SetWorkerIntervalSuspended(i time.Duration) SessionOpt {
return func(s *Session) {
s.workerIntervalSuspended = i
}
}
func SetWorkerIntervalWaiter(i time.Duration) SessionOpt {
return func(s *Session) {
s.workerIntervalWaiter = i
}
}
func SetHandler(fn StateChangeHandler) SessionOpt {
return func(s *Session) {
s.eventHandler = fn
}
}
func SetWorkflowID(workflowID uint64) SessionOpt {
return func(s *Session) {
s.workflowID = workflowID
}
}
func SetLogger(log *zap.Logger) SessionOpt {
return func(s *Session) {
s.log = log
}
}
func SetDumpStacktraceOnPanic(dump bool) SessionOpt {
return func(s *Session) {
s.dumpStacktraceOnPanic = dump
}
}
func SetCallStack(id ...uint64) SessionOpt {
return func(s *Session) {
s.callStack = id
}
}
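// exampleSessionOptions is an illustrative sketch of option wiring, e.g. in
// tests: tightened intervals make suspended-state checks and Wait() polls fire
// more often, and a handler observes every state change. Values are arbitrary.
func exampleSessionOptions(ctx context.Context, g *Graph, log *zap.Logger) *Session {
	return NewSession(ctx, g,
		SetWorkflowID(1),
		SetLogger(log),
		SetWorkerIntervalSuspended(10*time.Millisecond),
		SetWorkerIntervalWaiter(100*time.Microsecond),
		SetHandler(func(status SessionStatus, st *State, ses *Session) error {
			log.Debug("session state changed", zap.Stringer("status", status))
			return nil
		}),
	)
}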
func (ss Steps) hash() map[Step]bool {
out := make(map[Step]bool)
for _, s := range ss {
out[s] = true
}
return out
}
func (ss Steps) Contains(steps ...Step) bool {
hash := ss.hash()
for _, s1 := range steps {
if !hash[s1] {
return false
}
}
return true
}
func (ss Steps) IDs() []uint64 {
if len(ss) == 0 {
return nil
}
var ids = make([]uint64, len(ss))
for i := range ss {
ids[i] = ss[i].ID()
}
return ids
}
// SetContextCallStack stores the workflow call stack on the context
func SetContextCallStack(ctx context.Context, ss []uint64) context.Context {
return context.WithValue(ctx, callStackCtxKey{}, ss)
}
// GetContextCallStack returns the workflow call stack stored on the context, if any
func GetContextCallStack(ctx context.Context) []uint64 {
v := ctx.Value(callStackCtxKey{})
if v == nil {
return nil
}
return v.([]uint64)
}
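// exampleGuardRecursion is an illustrative sketch of how the call stack can be
// used (e.g. by a step that invokes sub-workflows) to refuse recursive
// invocation of a workflow that is already on the stack.
func exampleGuardRecursion(ctx context.Context, workflowID uint64) error {
	for _, callerID := range GetContextCallStack(ctx) {
		if callerID == workflowID {
			return fmt.Errorf("refusing to recursively invoke workflow %d", workflowID)
		}
	}
	return nil
}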
// setErrorHandlerResultsToScope copies details of the failing step's error into the scope;
// the error handler's results hold the *names* of the scope variables that should receive
// the error value, its message and the failing step's ID
func setErrorHandlerResultsToScope(scope *expr.Vars, result *expr.Vars, e error, stepID uint64) (err error) {
var (
ehr = struct {
Error string `json:"error"`
ErrorMessage string `json:"errorMessage"`
ErrorStepID string `json:"errorStepID"`
}{}
)
err = result.Decode(&ehr)
if err != nil {
return
}
if len(ehr.Error) > 0 {
_ = expr.Assign(scope, ehr.Error, expr.Must(expr.NewAny(e)))
}
if len(ehr.ErrorMessage) > 0 {
_ = expr.Assign(scope, ehr.ErrorMessage, expr.Must(expr.NewString(e.Error())))
}
if len(ehr.ErrorStepID) > 0 {
_ = expr.Assign(scope, ehr.ErrorStepID, expr.Must(expr.NewInteger(stepID)))
}
return
}
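// exampleErrorHandlerTargets is an illustrative sketch of the contract above:
// the error handler's results name the scope variables that receive the error
// value and its message. The target names ("failure", "failureText") are
// hypothetical.
func exampleErrorHandlerTargets(scope *expr.Vars, stepErr error, stepID uint64) error {
	targets, err := expr.NewVars(map[string]interface{}{
		"error":        "failure",     // scope var that receives the error value
		"errorMessage": "failureText", // scope var that receives the message text
	})
	if err != nil {
		return err
	}
	return setErrorHandlerResultsToScope(scope, targets, stepErr, stepID)
}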