Skip to content

Commit

Permalink
Merge pull request #18 from speakeasy-api/add-interrupts
Browse files Browse the repository at this point in the history
  • Loading branch information
TristanSpeakEasy committed Mar 28, 2024
2 parents 4ec8f0c + ea09f86 commit 33610e2
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 67 deletions.
48 changes: 16 additions & 32 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ var (
ErrInvalidArg = errors.New("invalid argument")
// ErrTemplateCompilation is returned when a template fails to compile.
ErrTemplateCompilation = errors.New("template compilation failed")
// ErrFunctionNotFound Function does not exist in script.
ErrFunctionNotFound = errors.New("failed to find function")
)

// CallContext is the context that is passed to go functions when called from js.
Expand Down Expand Up @@ -190,7 +188,7 @@ func (e *Engine) Init(ctx context.Context, data any) error {
// RunScript runs the provided script file within the environment initialized by Init.
// This is useful for setting up the environment with global variables and functions,
// or running code that is not directly related to templating but might setup the environment for templating.
func (e *Engine) RunScript(scriptFile string) error {
func (e *Engine) RunScript(ctx context.Context, scriptFile string) error {
if e.vm == nil {
return ErrNotInitialized
}
Expand All @@ -200,7 +198,7 @@ func (e *Engine) RunScript(scriptFile string) error {
return fmt.Errorf("failed to read script file: %w", err)
}

if _, err := e.vm.Run(scriptFile, string(script)); err != nil {
if _, err := e.vm.Run(ctx, scriptFile, string(script)); err != nil {
return err
}

Expand All @@ -209,53 +207,39 @@ func (e *Engine) RunScript(scriptFile string) error {

// RunFunction will run the named function if it already exists within the environment, for example if it was defined in a script run by RunScript.
// The provided args will be passed to the function, and the result will be returned.
func (e *Engine) RunFunction(fnName string, args ...any) (goja.Value, error) {
func (e *Engine) RunFunction(ctx context.Context, fnName string, args ...any) (goja.Value, error) {
if e.vm == nil {
return nil, ErrNotInitialized
}

fn, ok := goja.AssertFunction(e.vm.Get(fnName))
if !ok {
return nil, fmt.Errorf("%w: %s", ErrFunctionNotFound, fnName)
}

gojaArgs := make([]goja.Value, len(args))
for i, arg := range args {
gojaArgs[i] = e.vm.ToValue(arg)
}
val, err := fn(goja.Undefined(), gojaArgs...)
if err != nil {
return nil, err
}

return val, nil
return e.vm.RunFunction(ctx, fnName, args...)
}

// TemplateFile runs the provided template file, with the provided data and writes the result to the provided outFile.
func (e *Engine) TemplateFile(templateFile string, outFile string, data any) error {
func (e *Engine) TemplateFile(ctx context.Context, templateFile string, outFile string, data any) error {
if e.vm == nil {
return ErrNotInitialized
}

return e.templator.TemplateFile(e.vm, templateFile, outFile, data)
return e.templator.TemplateFile(ctx, e.vm, templateFile, outFile, data)
}

// TemplateString runs the provided template file, with the provided data and returns the rendered result.
func (e *Engine) TemplateString(templateFilePath string, data any) (string, error) {
func (e *Engine) TemplateString(ctx context.Context, templateFilePath string, data any) (string, error) {
if e.vm == nil {
return "", ErrNotInitialized
}

return e.templator.TemplateString(e.vm, templateFilePath, data)
return e.templator.TemplateString(ctx, e.vm, templateFilePath, data)
}

// TemplateStringInput runs the provided template string, with the provided data and returns the rendered result.
func (e *Engine) TemplateStringInput(name, template string, data any) (string, error) {
func (e *Engine) TemplateStringInput(ctx context.Context, name, template string, data any) (string, error) {
if e.vm == nil {
return "", ErrNotInitialized
}

return e.templator.TemplateStringInput(e.vm, name, template, data)
return e.templator.TemplateStringInput(ctx, e.vm, name, template, data)
}

//nolint:funlen
Expand Down Expand Up @@ -308,7 +292,7 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) {
span.End()
}()

err = e.templator.TemplateFile(v, templateFile, outFile, data)
err = e.templator.TemplateFile(ctx, v, templateFile, outFile, data)
if err != nil {
return "", err
}
Expand All @@ -318,7 +302,7 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) {
}(v)
e.templator.TmplFuncs["templateString"] = func(v *vm.VM) func(string, any) (string, error) {
return func(templateFile string, data any) (string, error) {
templated, err := e.templator.TemplateString(v, templateFile, data)
templated, err := e.templator.TemplateString(ctx, v, templateFile, data)
if err != nil {
return "", err
}
Expand All @@ -328,7 +312,7 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) {
}(v)
e.templator.TmplFuncs["templateStringInput"] = func(v *vm.VM) func(string, string, any) (string, error) {
return func(name, template string, data any) (string, error) {
templated, err := e.templator.TemplateStringInput(v, name, template, data)
templated, err := e.templator.TemplateStringInput(ctx, v, name, template, data)
if err != nil {
return "", err
}
Expand All @@ -347,11 +331,11 @@ func (e *Engine) init(ctx context.Context, data any) (*vm.VM, error) {
}
}(v)

if _, err := v.Run("initCreateComputedContextObject", `function createComputedContextObject() { return {}; }`); err != nil {
if _, err := v.Run(ctx, "initCreateComputedContextObject", `function createComputedContextObject() { return {}; }`); err != nil {
return nil, utils.HandleJSError("failed to init createComputedContextObject", err)
}

globalComputed, err := v.Run("globalCreateComputedContextObject", `createComputedContextObject();`)
globalComputed, err := v.Run(ctx, "globalCreateComputedContextObject", `createComputedContextObject();`)
if err != nil {
return nil, utils.HandleJSError("failed to init globalComputed", err)
}
Expand Down Expand Up @@ -396,7 +380,7 @@ func (e *Engine) require(call CallContext) goja.Value {
panic(vm.NewGoError(err))
}

if _, err := vm.Run(scriptPath, string(script)); err != nil {
if _, err := vm.Run(call.Ctx, scriptPath, string(script)); err != nil {
panic(vm.NewGoError(err))
}

Expand Down
2 changes: 1 addition & 1 deletion engine_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestEngine_RunScript_Success(t *testing.T) {
})
require.NoError(t, err)

err = e.RunScript("scripts/test.js")
err = e.RunScript(context.Background(), "scripts/test.js")
require.NoError(t, err)

assert.Empty(t, expectedFiles, "not all expected files were written")
Expand Down
11 changes: 6 additions & 5 deletions internal/template/mocks/template_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 14 additions & 13 deletions internal/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package template

import (
"bytes"
"context"
"fmt"
"regexp"
"strconv"
Expand Down Expand Up @@ -46,7 +47,7 @@ type tmplContext struct {
type VM interface {
Get(name string) goja.Value
Set(name string, value any) error
Run(name string, src string, opts ...vm.Option) (goja.Value, error)
Run(ctx context.Context, name string, src string, opts ...vm.Option) (goja.Value, error)
ToObject(val goja.Value) *goja.Object
}

Expand All @@ -67,8 +68,8 @@ func (t *Templator) SetContextData(contextData any, globalComputed goja.Value) {
}

// TemplateFile will template a file and write the output to outFile.
func (t *Templator) TemplateFile(vm VM, templateFile, outFile string, inputData any) error {
output, err := t.TemplateString(vm, templateFile, inputData)
func (t *Templator) TemplateFile(ctx context.Context, vm VM, templateFile, outFile string, inputData any) error {
output, err := t.TemplateString(ctx, vm, templateFile, inputData)
if err != nil {
return err
}
Expand Down Expand Up @@ -97,26 +98,26 @@ func (c *inlineScriptContext) render(call goja.FunctionCall) goja.Value {
}

// TemplateString will template the provided file and return the output as a string.
func (t *Templator) TemplateString(vm VM, templatePath string, inputData any) (out string, err error) {
func (t *Templator) TemplateString(ctx context.Context, vm VM, templatePath string, inputData any) (out string, err error) {
data, err := t.ReadFunc(templatePath)
if err != nil {
return "", fmt.Errorf("failed to read template file: %w", err)
}

return t.TemplateStringInput(vm, templatePath, string(data), inputData)
return t.TemplateStringInput(ctx, vm, templatePath, string(data), inputData)
}

// TemplateStringInput will template the provided input string and return the output as a string.
//
//nolint:funlen
func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputData any) (out string, err error) {
func (t *Templator) TemplateStringInput(ctx context.Context, vm VM, name string, input string, inputData any) (out string, err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("failed to render template: %s", e)
}
}()

localComputed, err := vm.Run("localCreateComputedContextObject", `createComputedContextObject();`)
localComputed, err := vm.Run(ctx, "localCreateComputedContextObject", `createComputedContextObject();`)
if err != nil {
return "", utils.HandleJSError("failed to create local computed context", err)
}
Expand All @@ -133,7 +134,7 @@ func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputD
}
if numRecursions > 0 {
numIterations = numRecursions + 1
localRecursiveComputed, err = vm.Run("recursiveCreateComputedContextObject", `createComputedContextObject();`)
localRecursiveComputed, err = vm.Run(ctx, "recursiveCreateComputedContextObject", `createComputedContextObject();`)
if err != nil {
return "", utils.HandleJSError("failed to create recursive computed context", err)
}
Expand All @@ -152,7 +153,7 @@ func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputD
return "", fmt.Errorf("failed to set context: %w", err)
}

evaluated, replacedLines, err := t.evaluateInlineScripts(vm, name, input)
evaluated, replacedLines, err := t.evaluateInlineScripts(ctx, vm, name, input)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -194,7 +195,7 @@ func (t *Templator) TemplateStringInput(vm VM, name string, input string, inputD
return out, nil
}

func (t *Templator) evaluateInlineScripts(vm VM, templatePath, content string) (string, int, error) {
func (t *Templator) evaluateInlineScripts(ctx context.Context, vm VM, templatePath, content string) (string, int, error) {
replacedLines := 0

evaluated, err := utils.ReplaceAllStringSubmatchFunc(sjsRegex, content, func(match []string) (string, error) {
Expand All @@ -203,7 +204,7 @@ func (t *Templator) evaluateInlineScripts(vm VM, templatePath, content string) (
return match[0], nil
}

output, err := t.execSJSBlock(vm, match[2], templatePath, findJSBlockLineNumber(content, match[2]))
output, err := t.execSJSBlock(ctx, vm, match[2], templatePath, findJSBlockLineNumber(content, match[2]))
if err != nil {
return "", err
}
Expand All @@ -219,15 +220,15 @@ func (t *Templator) evaluateInlineScripts(vm VM, templatePath, content string) (
return evaluated, replacedLines, nil
}

func (t *Templator) execSJSBlock(v VM, js, templatePath string, jsBlockLineNumber int) (string, error) {
func (t *Templator) execSJSBlock(ctx context.Context, v VM, js, templatePath string, jsBlockLineNumber int) (string, error) {
currentRender := v.Get("render")

c := newInlineScriptContext()
if err := v.Set("render", c.render); err != nil {
return "", fmt.Errorf("failed to set render function: %w", err)
}

if _, err := v.Run(templatePath, js, vm.WithStartingLineNumber(jsBlockLineNumber)); err != nil {
if _, err := v.Run(ctx, templatePath, js, vm.WithStartingLineNumber(jsBlockLineNumber)); err != nil {
return "", fmt.Errorf("failed to run inline script in %s:\n```sjs\n%ssjs```\n%w", templatePath, js, err)
}

Expand Down
23 changes: 12 additions & 11 deletions internal/template/template_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package template_test

import (
"context"
"testing"

"github.com/dop251/goja"
Expand Down Expand Up @@ -47,19 +48,19 @@ func TestTemplator_TemplateFile_Success(t *testing.T) {

vm := mocks.NewMockVM(ctrl)

context := &template.Context{
ctx := &template.Context{
Global: tt.fields.contextData,
GlobalComputed: goja.Undefined(),
Local: tt.args.inputData,
LocalComputed: goja.Undefined(),
RecursiveComputed: goja.Undefined(),
}
o := goja.New()
contextVal := o.ToValue(context)
contextVal := o.ToValue(ctx)

vm.EXPECT().Run("localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1)
vm.EXPECT().Run(context.Background(), "localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1)
vm.EXPECT().Get("context").Return(goja.Undefined()).Times(2)
vm.EXPECT().Set("context", context).Return(nil).Times(1)
vm.EXPECT().Set("context", ctx).Return(nil).Times(1)
vm.EXPECT().Get("context").Return(contextVal).Times(1)
vm.EXPECT().ToObject(contextVal).Return(contextVal.ToObject(o)).Times(1)
vm.EXPECT().Set("context", goja.Undefined()).Return(nil).Times(1)
Expand All @@ -76,7 +77,7 @@ func TestTemplator_TemplateFile_Success(t *testing.T) {
},
}
tr.SetContextData(tt.fields.contextData, goja.Undefined())
err := tr.TemplateFile(vm, tt.args.templatePath, tt.args.outFile, tt.args.inputData)
err := tr.TemplateFile(context.Background(), vm, tt.args.templatePath, tt.args.outFile, tt.args.inputData)
require.NoError(t, err)
})
}
Expand Down Expand Up @@ -147,24 +148,24 @@ func TestTemplator_TemplateString_Success(t *testing.T) {

vm := mocks.NewMockVM(ctrl)

context := &template.Context{
ctx := &template.Context{
Global: tt.fields.contextData,
GlobalComputed: goja.Undefined(),
Local: tt.args.inputData,
LocalComputed: goja.Undefined(),
RecursiveComputed: goja.Undefined(),
}
o := goja.New()
contextVal := o.ToValue(context)
contextVal := o.ToValue(ctx)

vm.EXPECT().Run("localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1)
vm.EXPECT().Run(context.Background(), "localCreateComputedContextObject", `createComputedContextObject();`).Return(goja.Undefined(), nil).Times(1)
vm.EXPECT().Get("context").Return(goja.Undefined()).Times(2)
vm.EXPECT().Set("context", context).Return(nil).Times(1)
vm.EXPECT().Set("context", ctx).Return(nil).Times(1)

if tt.fields.includedJS != "" {
vm.EXPECT().Get("render").Return(goja.Undefined()).Times(1)
vm.EXPECT().Set("render", gomock.Any()).Return(nil).Times(1)
vm.EXPECT().Run("test", tt.fields.includedJS, gomock.Any()).Return(goja.Undefined(), nil).Times(1)
vm.EXPECT().Run(context.Background(), "test", tt.fields.includedJS, gomock.Any()).Return(goja.Undefined(), nil).Times(1)
vm.EXPECT().Set("render", goja.Undefined()).Return(nil).Times(1)
}

Expand All @@ -180,7 +181,7 @@ func TestTemplator_TemplateString_Success(t *testing.T) {
TmplFuncs: tt.fields.tmplFuncs,
}
tr.SetContextData(tt.fields.contextData, goja.Undefined())
out, err := tr.TemplateString(vm, tt.args.templatePath, tt.args.inputData)
out, err := tr.TemplateString(context.Background(), vm, tt.args.templatePath, tt.args.inputData)
require.NoError(t, err)
assert.Equal(t, tt.wantOut, out)
})
Expand Down
Loading

0 comments on commit 33610e2

Please sign in to comment.