diff --git a/command/command.go b/command/command.go index 0813072..3f75c40 100644 --- a/command/command.go +++ b/command/command.go @@ -60,35 +60,9 @@ type Command struct { } func (c *Command) Run() error { - var formatter yamlfmt.Formatter - if c.Config.FormatterConfig == nil { - factory, err := c.Registry.GetDefaultFactory() - if err != nil { - return err - } - formatter, err = factory.NewFormatter(nil) - if err != nil { - return err - } - } else { - var ( - factory yamlfmt.Factory - err error - ) - if c.Config.FormatterConfig.Type == "" { - factory, err = c.Registry.GetDefaultFactory() - } else { - factory, err = c.Registry.GetFactory(c.Config.FormatterConfig.Type) - } - if err != nil { - return err - } - - c.Config.FormatterConfig.FormatterSettings["line_ending"] = c.Config.LineEnding - formatter, err = factory.NewFormatter(c.Config.FormatterConfig.FormatterSettings) - if err != nil { - return err - } + formatter, err := c.getFormatter() + if err != nil { + return err } lineSepChar, err := c.Config.LineEnding.Separator() @@ -191,6 +165,30 @@ func (c *Command) Run() error { return nil } +func (c *Command) getFormatter() (yamlfmt.Formatter, error) { + var factoryType string + + // In the existing codepaths, this value is always set. But + // it's a habit of mine to check anything that can possibly be nil + // if I remember that to be the case. :) + if c.Config.FormatterConfig != nil { + factoryType = c.Config.FormatterConfig.Type + + // The line ending set within the formatter settings takes precedence over setting + // it from the top level config. If it's not set in formatter settings, then + // we use the value from the top level. + if _, ok := c.Config.FormatterConfig.FormatterSettings["line_ending"]; !ok { + c.Config.FormatterConfig.FormatterSettings["line_ending"] = c.Config.LineEnding + } + } + + factory, err := c.Registry.GetFactory(factoryType) + if err != nil { + return nil, err + } + return factory.NewFormatter(c.Config.FormatterConfig.FormatterSettings) +} + func (c *Command) collectPaths() ([]string, error) { collector := c.makePathCollector() return collector.CollectPaths() diff --git a/command/command_test.go b/command/command_test.go new file mode 100644 index 0000000..6faf6bb --- /dev/null +++ b/command/command_test.go @@ -0,0 +1,46 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package command + +import ( + "testing" + + "github.com/google/yamlfmt" + "github.com/google/yamlfmt/formatters/basic" + "github.com/google/yamlfmt/internal/assert" +) + +// This test asserts the proper behaviour for `line_ending` settings specified +// in formatter settings overriding the global configuration. +func TestLineEndingFormatterVsGloabl(t *testing.T) { + c := &Command{ + Config: &Config{ + LineEnding: "lf", + FormatterConfig: &FormatterConfig{ + FormatterSettings: map[string]any{ + "line_ending": yamlfmt.LineBreakStyleLF, + }, + }, + }, + Registry: yamlfmt.NewFormatterRegistry(&basic.BasicFormatterFactory{}), + } + + f, err := c.getFormatter() + assert.NilErr(t, err) + configMap, err := f.ConfigMap() + assert.NilErr(t, err) + formatterLineEnding := configMap["line_ending"].(yamlfmt.LineBreakStyle) + assert.Assert(t, formatterLineEnding == yamlfmt.LineBreakStyleLF, "expected formatter's line ending to be lf") +} diff --git a/formatter.go b/formatter.go index 2a17475..ce432ee 100644 --- a/formatter.go +++ b/formatter.go @@ -46,6 +46,9 @@ func (r *Registry) Add(f Factory) { } func (r *Registry) GetFactory(fType string) (Factory, error) { + if fType == "" { + return r.GetDefaultFactory() + } factory, ok := r.registry[fType] if !ok { return nil, fmt.Errorf("no formatter registered with type \"%s\"", fType)