commit abd3bf6f657deac9baea40fb286751979fe6321e
parent d70fac41ea582c7e6769bfe6808b24efb1c6bddc
Author: Sean Enck <sean@ttypty.com>
Date: Fri, 6 Dec 2024 21:48:06 -0500
support 10 levels of nested includes
Diffstat:
2 files changed, 16 insertions(+), 14 deletions(-)
diff --git a/internal/config/toml.go b/internal/config/toml.go
@@ -3,7 +3,6 @@ package config
import (
"bytes"
_ "embed"
- "errors"
"fmt"
"io"
"os"
@@ -15,7 +14,10 @@ import (
"github.com/BurntSushi/toml"
)
-const isInclude = "include"
+const (
+ isInclude = "include"
+ maxDepth = 10
+)
type (
// Loader indicates how included files should be sourced
@@ -143,12 +145,14 @@ func DefaultTOML() (string, error) {
return "", err
}
for _, header := range []string{configEnv, "\n", fmt.Sprintf(`
-# include additional configs, can NOT nest, but does allow globs ('*')
+# include additional configs, allowing globs ('*'), nesting
+# depth allowed up to %d include levels
+#
# this field is not configurable via environment variables
# and it is not considered part of the environment either
# it is ONLY used during TOML configuration loading
%s = []
-`, isInclude), "\n"} {
+`, maxDepth, isInclude), "\n"} {
if _, err := builder.WriteString(header); err != nil {
return "", err
}
@@ -196,7 +200,7 @@ func generateDetailText(key string) (string, error) {
// LoadConfig will read the input reader and use the loader to source configuration files
func LoadConfig(r io.Reader, loader Loader) ([]ShellEnv, error) {
m := make(map[string]interface{})
- if err := overlayConfig(r, true, &m, loader); err != nil {
+ if err := overlayConfig(r, 1, &m, loader); err != nil {
return nil, err
}
m = flatten(m, "")
@@ -250,7 +254,10 @@ func LoadConfig(r io.Reader, loader Loader) ([]ShellEnv, error) {
return res, nil
}
-func overlayConfig(r io.Reader, canInclude bool, m *map[string]interface{}, loader Loader) error {
+func overlayConfig(r io.Reader, depth int, m *map[string]interface{}, loader Loader) error {
+ if depth > maxDepth {
+ return fmt.Errorf("too many nested includes (%d > %d)", depth, maxDepth)
+ }
d := toml.NewDecoder(r)
if _, err := d.Decode(m); err != nil {
return err
@@ -264,9 +271,6 @@ func overlayConfig(r io.Reader, canInclude bool, m *map[string]interface{}, load
return err
}
if len(including) > 0 {
- if !canInclude {
- return errors.New("nested includes not allowed")
- }
for _, s := range including {
use := os.Expand(s, os.Getenv)
files := []string{use}
@@ -282,7 +286,7 @@ func overlayConfig(r io.Reader, canInclude bool, m *map[string]interface{}, load
if err != nil {
return err
}
- if err := overlayConfig(reader, false, m, nil); err != nil {
+ if err := overlayConfig(reader, depth+1, m, loader); err != nil {
return err
}
}
diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go
@@ -2,7 +2,6 @@ package config_test
import (
"errors"
- "fmt"
"io"
"os"
"path/filepath"
@@ -20,11 +19,11 @@ func TestLoadIncludes(t *testing.T) {
r := strings.NewReader(data)
if _, err := config.LoadConfig(r, func(p string) (io.Reader, error) {
if p == "xyz/abc" {
- return strings.NewReader("include = [\"aaa\"]"), nil
+ return strings.NewReader("include = [\"$TEST/abc\"]"), nil
} else {
return nil, errors.New("invalid path")
}
- }); err == nil || err.Error() != "nested includes not allowed" {
+ }); err == nil || err.Error() != "too many nested includes (11 > 10)" {
t.Errorf("invalid error: %v", err)
}
data = `include = ["abc"]`
@@ -268,7 +267,6 @@ func TestDefaultTOMLToLoadFile(t *testing.T) {
t.Errorf("invalid error: %v", err)
}
os.WriteFile(file, []byte(loaded), 0o644)
- fmt.Println(loaded)
if err := config.LoadConfigFile(file); err != nil {
t.Errorf("invalid error: %v", err)
}