diff --git a/engine.go b/engine.go index 3a20d5e..f1b89c9 100644 --- a/engine.go +++ b/engine.go @@ -33,8 +33,6 @@ var ( ErrInvalidArg = errors.New("invalid argument") // ErrTemplateCompilation is returned when a template fails to compile. ErrTemplateCompilation = errors.New("template compilation failed") - // ErrFunctionNotFound Function does not exist in script. - ErrFunctionNotFound = errors.New("failed to find function") ) // CallContext is the context that is passed to go functions when called from js. @@ -190,7 +188,7 @@ func (e *Engine) Init(ctx context.Context, data any) error { // RunScript runs the provided script file within the environment initialized by Init. // This is useful for setting up the environment with global variables and functions, // or running code that is not directly related to templating but might setup the environment for templating. -func (e *Engine) RunScript(scriptFile string) error { +func (e *Engine) RunScript(ctx context.Context, scriptFile string) error { if e.vm == nil { return ErrNotInitialized } @@ -200,7 +198,7 @@ func (e *Engine) RunScript(scriptFile string) error { return fmt.Errorf("failed to read script file: %w", err) } - if _, err := e.vm.Run(scriptFile, string(script)); err != nil { + if _, err := e.vm.Run(ctx, scriptFile, string(script)); err != nil { return err } @@ -209,53 +207,39 @@ func (e *Engine) RunScript(scriptFile string) error { // RunFunction will run the named function if it already exists within the environment, for example if it was defined in a script run by RunScript. // The provided args will be passed to the function, and the result will be returned. -func (e *Engine) RunFunction(fnName string, args ...any) (goja.Value, error) { +func (e *Engine) RunFunction(ctx context.Context, fnName string, args ...any) (goja.Value, error) { if e.vm == nil { return nil, ErrNotInitialized } - fn, ok := goja.AssertFunction(e.vm.Get(fnName)) - if !ok { - return nil, fmt.Errorf("%w: %s", ErrFunctionNotFound, fnName) - } - - gojaArgs := make([]goja.Value, len(args)) - for i, arg := range args { - gojaArgs[i] = e.vm.ToValue(arg) - } - val, err := fn(goja.Undefined(), gojaArgs...) - if err != nil { - return nil, err - } - - return val, nil + return e.vm.RunFunction(ctx, fnName, args...) } // TemplateFile runs the provided template file, with the provided data and writes the result to the provided outFile. -func (e *Engine) TemplateFile(templateFile string, outFile string, data any) error { +func (e *Engine) TemplateFile(ctx context.Context, templateFile string, outFile string, data any) error { if e.vm == nil { return ErrNotInitialized } - return e.templator.TemplateFile(e.vm, templateFile, outFile, data) + return e.templator.TemplateFile(ctx, e.vm, templateFile, outFile, data) } // TemplateString runs the provided template file, with the provided data and returns the rendered result. -func (e *Engine) TemplateString(templateFilePath string, data any) (string, error) { +func (e *Engine) TemplateString(ctx context.Context, templateFilePath string, data any) (string, error) { if e.vm == nil { return "", ErrNotInitialized } - return e.templator.TemplateString(e.vm, templateFilePath, data) + return e.templator.TemplateString(ctx, e.vm, templateFilePath, data) } // TemplateStringInput runs the provided template string, with the provided data and returns the rendered result. -func (e *Engine) TemplateStringInput(name, template string, data any) (string, error) { +func (e *Engine) TemplateStringInput(ctx context.Context, name, template string, data any) (string, error) { if e.vm == nil { return "", ErrNotInitialized } - return e.templator.TemplateStringInput(e.vm, name, template, data) + return e.templator.TemplateStringInput(ctx, e.vm, name, template, data) } //nolint:funlen @@ -308,7 +292,7 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) { span.End() }() - err = e.templator.TemplateFile(v, templateFile, outFile, data) + err = e.templator.TemplateFile(ctx, v, templateFile, outFile, data) if err != nil { return "", err } @@ -318,7 +302,7 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) { }(v) e.templator.TmplFuncs["templateString"] = func(v *vm.VM) func(string, any) (string, error) { return func(templateFile string, data any) (string, error) { - templated, err := e.templator.TemplateString(v, templateFile, data) + templated, err := e.templator.TemplateString(ctx, v, templateFile, data) if err != nil { return "", err } @@ -328,7 +312,7 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) { }(v) e.templator.TmplFuncs["templateStringInput"] = func(v *vm.VM) func(string, string, any) (string, error) { return func(name, template string, data any) (string, error) { - templated, err := e.templator.TemplateStringInput(v, name, template, data) + templated, err := e.templator.TemplateStringInput(ctx, v, name, template, data) if err != nil { return "", err } @@ -347,11 +331,11 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) { } }(v) - if _, err := v.Run("initCreateComputedContextObject", `function createComputedContextObject() { return {}; }`); err != nil { + if _, err := v.Run(ctx, "initCreateComputedContextObject", `function createComputedContextObject() { return {}; }`); err != nil { return nil, utils.HandleJSError("failed to init createComputedContextObject", err) } - globalComputed, err := v.Run("globalCreateComputedContextObject", `createComputedContextObject();`) + globalComputed, err := v.Run(ctx, "globalCreateComputedContextObject", `createComputedContextObject();`) if err != nil { return nil, utils.HandleJSError("failed to init globalComputed", err) } @@ -396,7 +380,7 @@ func (e *Engine) require(call CallContext) goja.Value { panic(vm.NewGoError(err)) } - if _, err := vm.Run(scriptPath, string(script)); err != nil { + if _, err := vm.Run(call.Ctx, scriptPath, string(script)); err != nil { panic(vm.NewGoError(err)) } diff --git a/engine_integration_test.go b/engine_integration_test.go index a779f21..1796792 100644 --- a/engine_integration_test.go +++ b/engine_integration_test.go @@ -58,7 +58,7 @@ func TestEngine_RunScript_Success(t *testing.T) { }) require.NoError(t, err) - err = e.RunScript("scripts/test.js") + err = e.RunScript(context.Background(), "scripts/test.js") require.NoError(t, err) assert.Empty(t, expectedFiles, "not all expected files were written") diff --git a/internal/template/mocks/template_mock.go b/internal/template/mocks/template_mock.go index 87ecdf5..e63dcef 100644 --- a/internal/template/mocks/template_mock.go +++ b/internal/template/mocks/template_mock.go @@ -5,6 +5,7 @@ package mocks import ( + context "context" reflect "reflect" goja "github.com/dop251/goja" @@ -50,10 +51,10 @@ func (mr *MockVMMockRecorder) Get(arg0 interface{}) *gomock.Call { } // Run mocks base method. -func (m *MockVM) Run(arg0, arg1 string, arg2 ...vm.Option) (goja.Value, error) { +func (m *MockVM) Run(arg0 context.Context, arg1, arg2 string, arg3 ...vm.Option) (goja.Value, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { + varargs := []interface{}{arg0, arg1, arg2} + for _, a := range arg3 { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Run", varargs...) @@ -63,9 +64,9 @@ func (m *MockVM) Run(arg0, arg1 string, arg2 ...vm.Option) (goja.Value, error) { } // Run indicates an expected call of Run. -func (mr *MockVMMockRecorder) Run(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockVMMockRecorder) Run(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) + varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockVM)(nil).Run), varargs...) } diff --git a/internal/template/template.go b/internal/template/template.go index d24876f..c248ae0 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -5,6 +5,7 @@ package template import ( "bytes" + "context" "fmt" "regexp" "strconv" @@ -46,7 +47,7 @@ type tmplContext struct { type VM interface { Get(name string) goja.Value Set(name string, value any) error - Run(name string, src string, opts ...vm.Option) (goja.Value, error) + Run(ctx context.Context, name string, src string, opts ...vm.Option) (goja.Value, error) ToObject(val goja.Value) *goja.Object } @@ -67,8 +68,8 @@ func (t *Templator) SetContextData(contextData any, globalComputed goja.Value) { } // TemplateFile will template a file and write the output to outFile. -func (t *Templator) TemplateFile(vm VM, templateFile, outFile string, inputData any) error { - output, err := t.TemplateString(vm, templateFile, inputData) +func (t *Templator) TemplateFile(ctx context.Context, vm VM, templateFile, outFile string, inputData any) error { + output, err := t.TemplateString(ctx, vm, templateFile, inputData) if err != nil { return err } @@ -97,26 +98,26 @@ func (c *inlineScriptContext) render(call goja.FunctionCall) goja.Value { } // TemplateString will template the provided file and return the output as a string. -func (t *Templator) TemplateString(vm VM, templatePath string, inputData any) (out string, err error) { +func (t *Templator) TemplateString(ctx context.Context, vm VM, templatePath string, inputData any) (out string, err error) { data, err := t.ReadFunc(templatePath) if err != nil { return "", fmt.Errorf("failed to read template file: %w", err) } - return t.TemplateStringInput(vm, templatePath, string(data), inputData) + return t.TemplateStringInput(ctx, vm, templatePath, string(data), inputData) } // TemplateStringInput will template the provided input string and return the output as a string. // //nolint:funlen -func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputData any) (out string, err error) { +func (t *Templator) TemplateStringInput(ctx context.Context, vm VM, name string, input string, inputData any) (out string, err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("failed to render template: %s", e) } }() - localComputed, err := vm.Run("localCreateComputedContextObject", `createComputedContextObject();`) + localComputed, err := vm.Run(ctx, "localCreateComputedContextObject", `createComputedContextObject();`) if err != nil { return "", utils.HandleJSError("failed to create local computed context", err) } @@ -133,7 +134,7 @@ func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputD } if numRecursions > 0 { numIterations = numRecursions + 1 - localRecursiveComputed, err = vm.Run("recursiveCreateComputedContextObject", `createComputedContextObject();`) + localRecursiveComputed, err = vm.Run(ctx, "recursiveCreateComputedContextObject", `createComputedContextObject();`) if err != nil { return "", utils.HandleJSError("failed to create recursive computed context", err) } @@ -152,7 +153,7 @@ func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputD return "", fmt.Errorf("failed to set context: %w", err) } - evaluated, replacedLines, err := t.evaluateInlineScripts(vm, name, input) + evaluated, replacedLines, err := t.evaluateInlineScripts(ctx, vm, name, input) if err != nil { return "", err } @@ -194,7 +195,7 @@ func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputD return out, nil } -func (t *Templator) evaluateInlineScripts(vm VM, templatePath, content string) (string, int, error) { +func (t *Templator) evaluateInlineScripts(ctx context.Context, vm VM, templatePath, content string) (string, int, error) { replacedLines := 0 evaluated, err := utils.ReplaceAllStringSubmatchFunc(sjsRegex, content, func(match []string) (string, error) { @@ -203,7 +204,7 @@ func (t *Templator) evaluateInlineScripts(vm VM, templatePath, content string) ( return match[0], nil } - output, err := t.execSJSBlock(vm, match[2], templatePath, findJSBlockLineNumber(content, match[2])) + output, err := t.execSJSBlock(ctx, vm, match[2], templatePath, findJSBlockLineNumber(content, match[2])) if err != nil { return "", err } @@ -219,7 +220,7 @@ func (t *Templator) evaluateInlineScripts(vm VM, templatePath, content string) ( return evaluated, replacedLines, nil } -func (t *Templator) execSJSBlock(v VM, js, templatePath string, jsBlockLineNumber int) (string, error) { +func (t *Templator) execSJSBlock(ctx context.Context, v VM, js, templatePath string, jsBlockLineNumber int) (string, error) { currentRender := v.Get("render") c := newInlineScriptContext() @@ -227,7 +228,7 @@ func (t *Templator) execSJSBlock(v VM, js, templatePath string, jsBlockLineNumbe return "", fmt.Errorf("failed to set render function: %w", err) } - if _, err := v.Run(templatePath, js, vm.WithStartingLineNumber(jsBlockLineNumber)); err != nil { + if _, err := v.Run(ctx, templatePath, js, vm.WithStartingLineNumber(jsBlockLineNumber)); err != nil { return "", fmt.Errorf("failed to run inline script in %s:\n```sjs\n%ssjs```\n%w", templatePath, js, err) } diff --git a/internal/template/template_test.go b/internal/template/template_test.go index ed211e2..028c863 100644 --- a/internal/template/template_test.go +++ b/internal/template/template_test.go @@ -1,6 +1,7 @@ package template_test import ( + "context" "testing" "github.com/dop251/goja" @@ -47,7 +48,7 @@ func TestTemplator_TemplateFile_Success(t *testing.T) { vm := mocks.NewMockVM(ctrl) - context := &template.Context{ + ctx := &template.Context{ Global: tt.fields.contextData, GlobalComputed: goja.Undefined(), Local: tt.args.inputData, @@ -55,11 +56,11 @@ func TestTemplator_TemplateFile_Success(t *testing.T) { RecursiveComputed: goja.Undefined(), } o := goja.New() - contextVal := o.ToValue(context) + contextVal := o.ToValue(ctx) - vm.EXPECT().Run("localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1) + vm.EXPECT().Run(context.Background(), "localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1) vm.EXPECT().Get("context").Return(goja.Undefined()).Times(2) - vm.EXPECT().Set("context", context).Return(nil).Times(1) + vm.EXPECT().Set("context", ctx).Return(nil).Times(1) vm.EXPECT().Get("context").Return(contextVal).Times(1) vm.EXPECT().ToObject(contextVal).Return(contextVal.ToObject(o)).Times(1) vm.EXPECT().Set("context", goja.Undefined()).Return(nil).Times(1) @@ -76,7 +77,7 @@ func TestTemplator_TemplateFile_Success(t *testing.T) { }, } tr.SetContextData(tt.fields.contextData, goja.Undefined()) - err := tr.TemplateFile(vm, tt.args.templatePath, tt.args.outFile, tt.args.inputData) + err := tr.TemplateFile(context.Background(), vm, tt.args.templatePath, tt.args.outFile, tt.args.inputData) require.NoError(t, err) }) } @@ -147,7 +148,7 @@ func TestTemplator_TemplateString_Success(t *testing.T) { vm := mocks.NewMockVM(ctrl) - context := &template.Context{ + ctx := &template.Context{ Global: tt.fields.contextData, GlobalComputed: goja.Undefined(), Local: tt.args.inputData, @@ -155,16 +156,16 @@ func TestTemplator_TemplateString_Success(t *testing.T) { RecursiveComputed: goja.Undefined(), } o := goja.New() - contextVal := o.ToValue(context) + contextVal := o.ToValue(ctx) - vm.EXPECT().Run("localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1) + vm.EXPECT().Run(context.Background(), "localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1) vm.EXPECT().Get("context").Return(goja.Undefined()).Times(2) - vm.EXPECT().Set("context", context).Return(nil).Times(1) + vm.EXPECT().Set("context", ctx).Return(nil).Times(1) if tt.fields.includedJS != "" { vm.EXPECT().Get("render").Return(goja.Undefined()).Times(1) vm.EXPECT().Set("render", gomock.Any()).Return(nil).Times(1) - vm.EXPECT().Run("test", tt.fields.includedJS, gomock.Any()).Return(goja.Undefined(), nil).Times(1) + vm.EXPECT().Run(context.Background(), "test", tt.fields.includedJS, gomock.Any()).Return(goja.Undefined(), nil).Times(1) vm.EXPECT().Set("render", goja.Undefined()).Return(nil).Times(1) } @@ -180,7 +181,7 @@ func TestTemplator_TemplateString_Success(t *testing.T) { TmplFuncs: tt.fields.tmplFuncs, } tr.SetContextData(tt.fields.contextData, goja.Undefined()) - out, err := tr.TemplateString(vm, tt.args.templatePath, tt.args.inputData) + out, err := tr.TemplateString(context.Background(), vm, tt.args.templatePath, tt.args.inputData) require.NoError(t, err) assert.Equal(t, tt.wantOut, out) }) diff --git a/internal/vm/vm.go b/internal/vm/vm.go index d8f8a5b..136ecac 100644 --- a/internal/vm/vm.go +++ b/internal/vm/vm.go @@ -2,11 +2,13 @@ package vm import ( + "context" "errors" "fmt" "regexp" "strconv" "strings" + "time" "github.com/dop251/goja" "github.com/dop251/goja_nodejs/console" @@ -24,6 +26,12 @@ var ( ErrCompilation = errors.New("script compilation failed") // ErrRuntime is returned when a script fails to run. ErrRuntime = errors.New("script runtime failure") + // ErrFunctionNotFound Function does not exist in script. + ErrFunctionNotFound = errors.New("failed to find function") +) + +const ( + sleepThreshold = 50 * time.Millisecond ) var lineNumberRegex = regexp.MustCompile(` \(*([^ ]+):([0-9]+):([0-9]+)\([0-9]+\)`) @@ -69,7 +77,7 @@ func New() (*VM, error) { } // Run runs a script in the VM. -func (v *VM) Run(name string, src string, opts ...Option) (goja.Value, error) { +func (v *VM) Run(ctx context.Context, name string, src string, opts ...Option) (goja.Value, error) { options := &Options{} for _, opt := range opts { opt(options) @@ -91,7 +99,24 @@ func (v *VM) Run(name string, src string, opts ...Option) (goja.Value, error) { } } + done := make(chan bool) + + go func(done chan bool) { + running := true + for running { + select { + case <-ctx.Done(): + v.Runtime.Interrupt("halt") + case <-done: + running = false + default: + time.Sleep(sleepThreshold) + } + } + }(done) + res, err := v.Runtime.RunProgram(p.prog) + done <- true if err == nil { return res, nil } @@ -107,6 +132,44 @@ func (v *VM) Run(name string, src string, opts ...Option) (goja.Value, error) { return nil, fmt.Errorf("failed to run script %s: %w", fixedStackTrace, ErrRuntime) } +// RunFunction will run the named function if it already exists within the environment, for example if it was defined in a script run by RunScript. +// The provided args will be passed to the function, and the result will be returned. +func (v *VM) RunFunction(ctx context.Context, fnName string, args ...any) (goja.Value, error) { + fn, ok := goja.AssertFunction(v.Get(fnName)) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrFunctionNotFound, fnName) + } + + gojaArgs := make([]goja.Value, len(args)) + for i, arg := range args { + gojaArgs[i] = v.ToValue(arg) + } + + done := make(chan bool) + + go func(done chan bool) { + running := true + for running { + select { + case <-ctx.Done(): + v.Runtime.Interrupt("halt") + case <-done: + running = false + default: + time.Sleep(sleepThreshold) + } + } + }(done) + + val, err := fn(goja.Undefined(), gojaArgs...) + done <- true + if err != nil { + return nil, err + } + + return val, nil +} + // ToObject converts a value to an object. func (v *VM) ToObject(val goja.Value) *goja.Object { return val.ToObject(v.Runtime) diff --git a/internal/vm/vm_test.go b/internal/vm/vm_test.go index b6eaae3..d3e41ab 100644 --- a/internal/vm/vm_test.go +++ b/internal/vm/vm_test.go @@ -1,6 +1,7 @@ package vm_test import ( + "context" "testing" "github.com/speakeasy-api/easytemplate/internal/vm" @@ -21,6 +22,6 @@ function test(input: Test): string { test({ Name: "test" });` - _, err = v.Run("test", typeScript) + _, err = v.Run(context.Background(), "test", typeScript) assert.Equal(t, "failed to run script Error: test error\n\tat test (test:5:7:*(3))\n\tat test:8:5:*(6)\n: script runtime failure", err.Error()) } diff --git a/templating.go b/templating.go index 676faed..a181aad 100644 --- a/templating.go +++ b/templating.go @@ -19,7 +19,7 @@ func (e *Engine) templateFileJS(call CallContext) goja.Value { )) defer span.End() - if err := e.templator.TemplateFile(call.VM, templateFile, outFile, inputData); err != nil { + if err := e.templator.TemplateFile(call.Ctx, call.VM, templateFile, outFile, inputData); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) span.End() @@ -40,7 +40,7 @@ func (e *Engine) templateStringJS(call CallContext) goja.Value { )) defer span.End() - output, err := e.templator.TemplateString(call.VM, templateFile, inputData) + output, err := e.templator.TemplateString(call.Ctx, call.VM, templateFile, inputData) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -63,7 +63,7 @@ func (e *Engine) templateStringInputJS(call CallContext) goja.Value { )) defer span.End() - output, err := e.templator.TemplateStringInput(call.VM, name, input, inputData) + output, err := e.templator.TemplateStringInput(call.Ctx, call.VM, name, input, inputData) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error())