diff --git a/pkg/composableschemadsl/compiler/compiler.go b/pkg/composableschemadsl/compiler/compiler.go index e238dbe4a8..11a1cc5aca 100644 --- a/pkg/composableschemadsl/compiler/compiler.go +++ b/pkg/composableschemadsl/compiler/compiler.go @@ -86,13 +86,31 @@ type Option func(*config) type ObjectPrefixOption func(*config) +type compilationContext struct { + // The set of definition names that we've seen as we compile. + // If these collide we throw an error. + existingNames *mapz.Set[string] + // The global set of files we've visited in the import process. + // If these collide we short circuit, preventing duplicate imports. + globallyVisitedFiles *mapz.Set[string] + // The set of files that we've visited on a particular leg of the recursion. + // This allows for detection of circular imports. + // NOTE: This depends on an assumption that a depth-first search will always + // find a cycle, even if we're otherwise marking globally visited nodes. + locallyVisitedFiles *mapz.Set[string] +} + // Compile compilers the input schema into a set of namespace definition protos. func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { - names := mapz.NewSet[string]() - return compileImpl(schema, names, prefix, opts...) + cctx := compilationContext{ + existingNames: mapz.NewSet[string](), + globallyVisitedFiles: mapz.NewSet[string](), + locallyVisitedFiles: mapz.NewSet[string](), + } + return compileImpl(schema, cctx, prefix, opts...) } -func compileImpl(schema InputSchema, existingNames *mapz.Set[string], prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { +func compileImpl(schema InputSchema, cctx compilationContext, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { cfg := &config{} prefix(cfg) // required option @@ -109,12 +127,14 @@ func compileImpl(schema InputSchema, existingNames *mapz.Set[string], prefix Obj } compiled, err := translate(translationContext{ - objectTypePrefix: cfg.objectTypePrefix, - mapper: mapper, - schemaString: schema.SchemaString, - skipValidate: cfg.skipValidation, - sourceFolder: cfg.sourceFolder, - existingNames: existingNames, + objectTypePrefix: cfg.objectTypePrefix, + mapper: mapper, + schemaString: schema.SchemaString, + skipValidate: cfg.skipValidation, + sourceFolder: cfg.sourceFolder, + existingNames: cctx.existingNames, + locallyVisitedFiles: cctx.locallyVisitedFiles, + globallyVisitedFiles: cctx.globallyVisitedFiles, }, root) if err != nil { var errorWithNode errorWithNode diff --git a/pkg/composableschemadsl/compiler/importer-test/circular-import/expected.zed b/pkg/composableschemadsl/compiler/importer-test/circular-import/expected.zed new file mode 100644 index 0000000000..ca32874a16 --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/circular-import/expected.zed @@ -0,0 +1,8 @@ +definition user {} + +definition persona {} + +definition resource { + relation viewer: user + permission view = viewer +} \ No newline at end of file diff --git a/pkg/composableschemadsl/compiler/importer-test/circular-import/root.zed b/pkg/composableschemadsl/compiler/importer-test/circular-import/root.zed new file mode 100644 index 0000000000..13c9b0fbaa --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/circular-import/root.zed @@ -0,0 +1,6 @@ +from .subjects import user + +definition resource { + relation viewer: user + permission view = viewer +} diff --git a/pkg/composableschemadsl/compiler/importer-test/circular-import/subjects.zed b/pkg/composableschemadsl/compiler/importer-test/circular-import/subjects.zed new file mode 100644 index 0000000000..2b6888b9e3 --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/circular-import/subjects.zed @@ -0,0 +1,3 @@ +from .user import user + +definition persona {} diff --git a/pkg/composableschemadsl/compiler/importer-test/circular-import/user.zed b/pkg/composableschemadsl/compiler/importer-test/circular-import/user.zed new file mode 100644 index 0000000000..0547015612 --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/circular-import/user.zed @@ -0,0 +1,3 @@ +from .subjects import persona + +definition user {} diff --git a/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/expected.zed b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/expected.zed new file mode 100644 index 0000000000..ca32874a16 --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/expected.zed @@ -0,0 +1,8 @@ +definition user {} + +definition persona {} + +definition resource { + relation viewer: user + permission view = viewer +} \ No newline at end of file diff --git a/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/left.zed b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/left.zed new file mode 100644 index 0000000000..3d52abc77f --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/left.zed @@ -0,0 +1 @@ +from .subjects import user diff --git a/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/right.zed b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/right.zed new file mode 100644 index 0000000000..ee675dc7a9 --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/right.zed @@ -0,0 +1 @@ +from .subjects import persona diff --git a/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/root.zed b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/root.zed new file mode 100644 index 0000000000..0ade575e6d --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/root.zed @@ -0,0 +1,7 @@ +from .left import user +from .right import persona + +definition resource { + relation viewer: user + permission view = viewer +} diff --git a/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/subjects.zed b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/subjects.zed new file mode 100644 index 0000000000..5e3fdcaf3e --- /dev/null +++ b/pkg/composableschemadsl/compiler/importer-test/diamond-shaped/subjects.zed @@ -0,0 +1,2 @@ +definition user {} +definition persona {} diff --git a/pkg/composableschemadsl/compiler/importer.go b/pkg/composableschemadsl/compiler/importer.go index 4aad96ccc9..0cb6c6162c 100644 --- a/pkg/composableschemadsl/compiler/importer.go +++ b/pkg/composableschemadsl/compiler/importer.go @@ -13,38 +13,75 @@ import ( ) type importContext struct { - pathSegments []string - sourceFolder string - names *mapz.Set[string] + pathSegments []string + sourceFolder string + names *mapz.Set[string] + locallyVisitedFiles *mapz.Set[string] + globallyVisitedFiles *mapz.Set[string] } const SchemaFileSuffix = ".zed" +type ErrCircularImport struct { + error + filePath string +} + func importFile(importContext importContext) (*CompiledSchema, error) { relativeFilepath := constructFilePath(importContext.pathSegments) filePath := path.Join(importContext.sourceFolder, relativeFilepath) newSourceFolder := filepath.Dir(filePath) - var schemaBytes []byte + currentLocallyVisitedFiles := importContext.locallyVisitedFiles.Copy() + + if ok := currentLocallyVisitedFiles.Add(filePath); !ok { + // If we've already visited the file on this particular branch walk, it's + // a circular import issue. + return nil, &ErrCircularImport{ + error: fmt.Errorf("circular import detected: %s has been visited on this branch", filePath), + filePath: filePath, + } + } + + if ok := importContext.globallyVisitedFiles.Add(filePath); !ok { + // If the file has already been visited, we short-circuit the import process + // by not reading the schema file in and compiling a schema with an empty string. + // This prevents duplicate definitions from ending up in the output, as well + // as preventing circular imports. + log.Debug().Str("filepath", filePath).Msg("file %s has already been visited in another part of the walk") + return compileImpl(InputSchema{ + Source: input.Source(filePath), + SchemaString: "", + }, + compilationContext{ + existingNames: importContext.names, + locallyVisitedFiles: currentLocallyVisitedFiles, + globallyVisitedFiles: importContext.globallyVisitedFiles, + }, + AllowUnprefixedObjectType(), + SourceFolder(newSourceFolder), + ) + } + schemaBytes, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("failed to read schema file: %w", err) } log.Trace().Str("schema", string(schemaBytes)).Str("file", filePath).Msg("read schema from file") - compiled, err := compileImpl(InputSchema{ + return compileImpl(InputSchema{ Source: input.Source(filePath), SchemaString: string(schemaBytes), }, - importContext.names, + compilationContext{ + existingNames: importContext.names, + locallyVisitedFiles: currentLocallyVisitedFiles, + globallyVisitedFiles: importContext.globallyVisitedFiles, + }, AllowUnprefixedObjectType(), SourceFolder(newSourceFolder), ) - if err != nil { - return nil, err - } - return compiled, nil } func constructFilePath(segments []string) string { diff --git a/pkg/composableschemadsl/compiler/importer_test.go b/pkg/composableschemadsl/compiler/importer_test.go index 03475359e4..8708f6e520 100644 --- a/pkg/composableschemadsl/compiler/importer_test.go +++ b/pkg/composableschemadsl/compiler/importer_test.go @@ -58,6 +58,7 @@ func TestImporter(t *testing.T) { {"nested local import", "nested-local"}, {"nested local import with transitive hop", "nested-local-with-hop"}, {"nested local two layers deep import", "nested-two-layer-local"}, + {"diamond-shaped imports are fine", "diamond-shaped"}, } for _, test := range importerTests { @@ -89,3 +90,23 @@ func TestImporter(t *testing.T) { }) } } + +func TestImportCycleCausesError(t *testing.T) { + t.Parallel() + + workingDir, err := os.Getwd() + require.NoError(t, err) + test := importerTest{"", "circular-import"} + + sourceFolder := path.Join(workingDir, test.relativePath()) + + inputSchema := test.input() + + _, err = compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: inputSchema, + }, compiler.AllowUnprefixedObjectType(), + compiler.SourceFolder(sourceFolder)) + + require.ErrorContains(t, err, "circular import") +} diff --git a/pkg/composableschemadsl/compiler/translator.go b/pkg/composableschemadsl/compiler/translator.go index 66fd63a810..3d4fce70d7 100644 --- a/pkg/composableschemadsl/compiler/translator.go +++ b/pkg/composableschemadsl/compiler/translator.go @@ -2,6 +2,7 @@ package compiler import ( "bufio" + "errors" "fmt" "strings" @@ -19,12 +20,14 @@ import ( ) type translationContext struct { - objectTypePrefix *string - mapper input.PositionMapper - schemaString string - skipValidate bool - existingNames *mapz.Set[string] - sourceFolder string + objectTypePrefix *string + mapper input.PositionMapper + schemaString string + skipValidate bool + existingNames *mapz.Set[string] + locallyVisitedFiles *mapz.Set[string] + globallyVisitedFiles *mapz.Set[string] + sourceFolder string } func (tctx translationContext) prefixedPath(definitionName string) (string, error) { @@ -696,7 +699,6 @@ func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.All func translateImport(tctx translationContext, importNode *dslNode, names *mapz.Set[string]) (*CompiledSchema, error) { // NOTE: this function currently just grabs everything that's in the target file. // TODO: only grab the requested definitions - // TODO: import cycle tracking pathNodes := importNode.List(dslshape.NodeImportPredicatePathSegment) pathSegments := make([]string, 0, len(pathNodes)) @@ -709,9 +711,21 @@ func translateImport(tctx translationContext, importNode *dslNode, names *mapz.S pathSegments = append(pathSegments, segment) } - return importFile(importContext{ - names: names, - pathSegments: pathSegments, - sourceFolder: tctx.sourceFolder, + compiledSchema, err := importFile(importContext{ + names: names, + pathSegments: pathSegments, + sourceFolder: tctx.sourceFolder, + globallyVisitedFiles: tctx.globallyVisitedFiles, + locallyVisitedFiles: tctx.locallyVisitedFiles, }) + if err != nil { + var circularImportError *ErrCircularImport + if errors.As(err, &circularImportError) { + // NOTE: The "%s" is an empty format string to keep with the form of ErrorWithSourcef + return nil, importNode.ErrorWithSourcef(circularImportError.filePath, "%s", circularImportError.error.Error()) + } + return nil, err + } + + return compiledSchema, nil }