commit a599f69c8285cd3ac63d08c80f49054a88d99cdd
parent b32a8f8d34741275d3c319f197dd9100c69bb113
Author: Sean Enck <sean@ttypty.com>
Date: Sun, 28 Sep 2025 18:52:02 -0400
enable strict enable/disable feature (mostly for included files only for now)
Diffstat:
5 files changed, 188 insertions(+), 73 deletions(-)
diff --git a/cmd/lb/main.go b/cmd/lb/main.go
@@ -50,7 +50,7 @@ func handleEarly(command string, args []string) (bool, error) {
func run() error {
for _, p := range config.NewConfigFiles() {
if platform.PathExists(p) {
- if err := config.LoadConfigFile(p); err != nil {
+ if err := platform.LoadConfigFile(p); err != nil {
return err
}
break
diff --git a/internal/config/toml.go b/internal/config/toml.go
@@ -1,7 +1,6 @@
package config
import (
- "bytes"
"fmt"
"io"
"maps"
@@ -16,12 +15,14 @@ import (
)
const (
- isInclude = "include"
- maxDepth = 10
- tomlInt = "integer"
- tomlBool = "boolean"
- tomlString = "string"
- tomlArray = "[]string"
+ isStrict = "strict"
+ isInclude = "include"
+ maxDepth = 10
+ tomlInt = "integer"
+ tomlBool = "boolean"
+ tomlString = "string"
+ tomlArray = "[]string"
+ strictDefault = true
)
type (
@@ -77,7 +78,14 @@ func DefaultTOML() (string, error) {
#
# it is ONLY used during TOML configuration loading
%s = []
-`, maxDepth, isInclude), "\n"} {
+
+# strict, when enabled, requires the configuration entries
+# to adhere to all loading rules.
+#
+# it is currently only used to ignore included files that
+# are not found
+%s = %t
+`, maxDepth, isInclude, isStrict, strictDefault), "\n"} {
if _, err := builder.WriteString(header); err != nil {
return "", err
}
@@ -133,8 +141,8 @@ func generateDetailText(data printer) (string, error) {
return strings.Join(text, "\n"), nil
}
-// LoadConfig will read the input reader and use the loader to source configuration files
-func LoadConfig(r io.Reader, loader Loader) error {
+// Load will read the input reader and use the loader to source configuration files
+func Load(r io.Reader, loader Loader) error {
mapped, err := readConfigs(r, 1, loader)
if err != nil {
return err
@@ -160,23 +168,22 @@ func LoadConfig(r io.Reader, loader Loader) error {
case tomlInt:
i, ok := v.(int64)
if !ok {
- return fmt.Errorf("non-int64 found where expected: %v", v)
+ return newTypeError("int64", v)
}
if i < 0 {
return fmt.Errorf("%d is negative (not allowed here)", i)
}
store.SetInt64(export, i)
case tomlBool:
- switch t := v.(type) {
- case bool:
- store.SetBool(export, t)
- default:
- return fmt.Errorf("non-bool found where expected: %v", v)
+ b, err := parseBool(v)
+ if err != nil {
+ return err
}
+ store.SetBool(export, b)
case tomlString:
s, ok := v.(string)
if !ok {
- return fmt.Errorf("non-string found where expected: %v", v)
+ return newTypeError("string", v)
}
if md.canExpand {
s = os.Expand(s, os.Getenv)
@@ -190,6 +197,19 @@ func LoadConfig(r io.Reader, loader Loader) error {
return nil
}
+func newTypeError(t string, v any) error {
+ return fmt.Errorf("non-%s found where %s expected: %v", t, t, v)
+}
+
+func parseBool(v any) (bool, error) {
+ switch t := v.(type) {
+ case bool:
+ return t, nil
+ default:
+ return false, newTypeError("bool", v)
+ }
+}
+
func readConfigs(r io.Reader, depth int, loader Loader) ([]map[string]any, error) {
if depth > maxDepth {
return nil, fmt.Errorf("too many nested includes (%d > %d)", depth, maxDepth)
@@ -200,6 +220,15 @@ func readConfigs(r io.Reader, depth int, loader Loader) ([]map[string]any, error
return nil, err
}
maps := []map[string]any{m}
+ strict := strictDefault
+ if v, ok := m[isStrict]; ok {
+ delete(m, isStrict)
+ b, err := parseBool(v)
+ if err != nil {
+ return nil, err
+ }
+ strict = b
+ }
includes, ok := m[isInclude]
if ok {
delete(m, isInclude)
@@ -222,6 +251,12 @@ func readConfigs(r io.Reader, depth int, loader Loader) ([]map[string]any, error
if err != nil {
return nil, err
}
+ if reader == nil {
+ if strict {
+ return nil, fmt.Errorf("failed to load the included file: %s", file)
+ }
+ continue
+ }
results, err := readConfigs(reader, depth+1, loader)
if err != nil {
return nil, err
@@ -258,7 +293,6 @@ func parseStringArray(value any, expand bool) ([]string, error) {
func flatten(m map[string]any, prefix string) map[string]any {
flattened := make(map[string]any)
-
for k, v := range m {
key := k
if prefix != "" {
@@ -275,21 +309,3 @@ func flatten(m map[string]any, prefix string) map[string]any {
return flattened
}
-
-func configLoader(path string) (io.Reader, error) {
- data, err := os.ReadFile(path)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
-}
-
-// LoadConfigFile will load a path as the configuration
-// it will also set the environment
-func LoadConfigFile(path string) error {
- reader, err := configLoader(path)
- if err != nil {
- return err
- }
- return LoadConfig(reader, configLoader)
-}
diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go
@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"os"
- "path/filepath"
"strings"
"testing"
@@ -23,7 +22,7 @@ func TestLoadIncludes(t *testing.T) {
t.Setenv("TEST", "xyz")
data := `include = ["$TEST/abc"]`
r := strings.NewReader(data)
- if err := config.LoadConfig(r, func(p string) (io.Reader, error) {
+ if err := config.Load(r, func(p string) (io.Reader, error) {
if p == "xyz/abc" {
return strings.NewReader("include = [\"$TEST/abc\"]"), nil
} else {
@@ -34,7 +33,7 @@ func TestLoadIncludes(t *testing.T) {
}
data = `include = ["abc"]`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, func(p string) (io.Reader, error) {
+ if err := config.Load(r, func(p string) (io.Reader, error) {
if p == "xyz/abc" {
return strings.NewReader("include = [\"aaa\"]"), nil
} else {
@@ -45,7 +44,7 @@ func TestLoadIncludes(t *testing.T) {
}
data = `include = 1`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, func(p string) (io.Reader, error) {
+ if err := config.Load(r, func(p string) (io.Reader, error) {
if p == "xyz/abc" {
return strings.NewReader("include = [\"aaa\"]"), nil
} else {
@@ -56,7 +55,7 @@ func TestLoadIncludes(t *testing.T) {
}
data = `include = [1]`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, func(p string) (io.Reader, error) {
+ if err := config.Load(r, func(p string) (io.Reader, error) {
if p == "xyz/abc" {
return strings.NewReader("include = [\"aaa\"]"), nil
} else {
@@ -69,7 +68,7 @@ func TestLoadIncludes(t *testing.T) {
store="xyz"
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, func(p string) (io.Reader, error) {
+ if err := config.Load(r, func(p string) (io.Reader, error) {
if p == "xyz/abc" {
return strings.NewReader("store = 'abc'"), nil
} else {
@@ -96,7 +95,7 @@ func TestArrayLoad(t *testing.T) {
copy = ["'xyz/$TEST'", "s", 1]
`
r := strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err == nil || err.Error() != "value is not string in array: 1" {
+ if err := config.Load(r, emptyRead); err == nil || err.Error() != "value is not string in array: 1" {
t.Errorf("invalid error: %v", err)
}
data = `include = []
@@ -105,7 +104,7 @@ store="xyz"
copy = ["'xyz/$TEST'", "s"]
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err != nil {
+ if err := config.Load(r, emptyRead); err != nil {
t.Errorf("invalid error: %v", err)
}
if len(store.List()) != 2 {
@@ -125,7 +124,7 @@ store="xyz"
copy = ["'xyz/$TEST'", "s"]
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err != nil {
+ if err := config.Load(r, emptyRead); err != nil {
t.Errorf("invalid error: %v", err)
}
if len(store.List()) != 2 {
@@ -148,7 +147,7 @@ func TestReadInt(t *testing.T) {
hash_length = true
`
r := strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err == nil || err.Error() != "non-int64 found where expected: true" {
+ if err := config.Load(r, emptyRead); err == nil || err.Error() != "non-int64 found where int64 expected: true" {
t.Errorf("invalid error: %v", err)
}
data = `include = []
@@ -156,7 +155,7 @@ hash_length = true
hash_length = 1
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err != nil {
+ if err := config.Load(r, emptyRead); err != nil {
t.Errorf("invalid error: %v", err)
}
if len(store.List()) != 1 {
@@ -171,7 +170,7 @@ hash_length = 1
hash_length = -1
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err == nil || err.Error() != "-1 is negative (not allowed here)" {
+ if err := config.Load(r, emptyRead); err == nil || err.Error() != "-1 is negative (not allowed here)" {
t.Errorf("invalid error: %v", err)
}
}
@@ -183,7 +182,7 @@ func TestReadBool(t *testing.T) {
clip = 1
`
r := strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err == nil || err.Error() != "non-bool found where expected: 1" {
+ if err := config.Load(r, emptyRead); err == nil || err.Error() != "non-bool found where bool expected: 1" {
t.Errorf("invalid error: %v", err)
}
data = `include = []
@@ -191,7 +190,7 @@ clip = 1
clip = true
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err != nil {
+ if err := config.Load(r, emptyRead); err != nil {
t.Errorf("invalid error: %v", err)
}
if len(store.List()) != 1 {
@@ -206,7 +205,7 @@ clip = true
clip = false
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err != nil {
+ if err := config.Load(r, emptyRead); err != nil {
t.Errorf("invalid error: %v", err)
}
if len(store.List()) != 1 {
@@ -225,7 +224,7 @@ func TestBadValues(t *testing.T) {
enabled = "false"
`
r := strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err == nil || err.Error() != "unknown key: totsp_enabled (LOCKBOX_TOTSP_ENABLED)" {
+ if err := config.Load(r, emptyRead); err == nil || err.Error() != "unknown key: totsp_enabled (LOCKBOX_TOTSP_ENABLED)" {
t.Errorf("invalid error: %v", err)
}
data = `include = []
@@ -233,29 +232,11 @@ enabled = "false"
otp_format = -1
`
r = strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err == nil || err.Error() != "non-string found where expected: -1" {
+ if err := config.Load(r, emptyRead); err == nil || err.Error() != "non-string found where string expected: -1" {
t.Errorf("invalid error: %v", err)
}
}
-func TestDefaultTOMLToLoadFile(t *testing.T) {
- store.Clear()
- os.Mkdir("testdata", 0o755)
- defer os.RemoveAll("testdata")
- file := filepath.Join("testdata", "config.toml")
- loaded, err := config.DefaultTOML()
- if err != nil {
- t.Errorf("invalid error: %v", err)
- }
- os.WriteFile(file, []byte(loaded), 0o644)
- if err := config.LoadConfigFile(file); err != nil {
- t.Errorf("invalid error: %v", err)
- }
- if len(store.List()) != 16 {
- t.Errorf("invalid environment after load: %d", len(store.List()))
- }
-}
-
func TestExpands(t *testing.T) {
store.Clear()
t.Setenv("TEST", "1")
@@ -266,7 +247,7 @@ clip.copy = ["$TEST", "$TEST"]
otp_format = "$TEST"
`
r := strings.NewReader(data)
- if err := config.LoadConfig(r, emptyRead); err != nil {
+ if err := config.Load(r, emptyRead); err != nil {
t.Errorf("invalid error: %v", err)
}
if len(store.List()) != 3 {
@@ -285,3 +266,43 @@ otp_format = "$TEST"
t.Errorf("invalid object: %v", a)
}
}
+
+func TestLoadIncludesStrictControls(t *testing.T) {
+ store.Clear()
+ defer os.Clearenv()
+ t.Setenv("TEST", "xyz")
+ data := `include = ["$TEST/abc"]
+store="xyz"
+strict = true
+`
+ r := strings.NewReader(data)
+ if err := config.Load(r, func(p string) (io.Reader, error) {
+ if p == "xyz/abc" {
+ return strings.NewReader("include = ['123']\nstrict = 1\nstore = 'abc'"), nil
+ } else {
+ return nil, errors.New("invalid path")
+ }
+ }); err == nil || err.Error() != "non-bool found where bool expected: 1" {
+ t.Errorf("invalid error: %v", err)
+ }
+ data = `include = ["$TEST/abc"]
+store="xyz"
+strict = true
+`
+ r = strings.NewReader(data)
+ if err := config.Load(r, func(_ string) (io.Reader, error) {
+ return nil, nil
+ }); err == nil || err.Error() != "failed to load the included file: xyz/abc" {
+ t.Errorf("invalid error: %v", err)
+ }
+ data = `include = ["$TEST/abc"]
+store="xyz"
+strict = false
+`
+ r = strings.NewReader(data)
+ if err := config.Load(r, func(_ string) (io.Reader, error) {
+ return nil, nil
+ }); err != nil {
+ t.Errorf("invalid error: %v", err)
+ }
+}
diff --git a/internal/platform/os.go b/internal/platform/os.go
@@ -6,9 +6,12 @@ import (
"bytes"
"errors"
"fmt"
+ "io"
"os"
"strings"
"syscall"
+
+ "github.com/enckse/lockbox/internal/config"
)
func termEcho(on bool) {
@@ -133,3 +136,24 @@ func PathExists(file string) bool {
}
return true
}
+
+func configLoader(path string) (io.Reader, error) {
+ if !PathExists(path) {
+ return nil, nil
+ }
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+// LoadConfigFile will load a path as the configuration
+// it will also set the environment
+func LoadConfigFile(path string) error {
+ reader, err := configLoader(path)
+ if err != nil {
+ return err
+ }
+ return config.Load(reader, configLoader)
+}
diff --git a/internal/platform/os_test.go b/internal/platform/os_test.go
@@ -3,8 +3,11 @@ package platform_test
import (
"os"
"path/filepath"
+ "strings"
"testing"
+ "github.com/enckse/lockbox/internal/config"
+ "github.com/enckse/lockbox/internal/config/store"
"github.com/enckse/lockbox/internal/platform"
)
@@ -19,3 +22,54 @@ func TestPathExist(t *testing.T) {
t.Error("test dir SHOULD exist")
}
}
+
+func TestLoadConfigFile(t *testing.T) {
+ store.Clear()
+ os.Mkdir("testdata", 0o755)
+ defer os.RemoveAll("testdata")
+ file := filepath.Join("testdata", "config.toml")
+ loaded, err := config.DefaultTOML()
+ if err != nil {
+ t.Errorf("invalid error: %v", err)
+ }
+ os.WriteFile(file, []byte(loaded), 0o644)
+ if err := platform.LoadConfigFile(file); err != nil {
+ t.Errorf("invalid error: %v", err)
+ }
+ if len(store.List()) != 16 {
+ t.Errorf("invalid environment after load: %d", len(store.List()))
+ }
+}
+
+func TestLoadConfigFileNoFileStrict(t *testing.T) {
+ store.Clear()
+ os.Mkdir("testdata", 0o755)
+ defer os.RemoveAll("testdata")
+ file := filepath.Join("testdata", "config.toml")
+ loaded, err := config.DefaultTOML()
+ if err != nil {
+ t.Errorf("invalid error: %v", err)
+ }
+ loaded = strings.Replace(loaded, "include = []", "include = ['invalid.toml']", 1)
+ os.WriteFile(file, []byte(loaded), 0o644)
+ if err := platform.LoadConfigFile(file); err == nil || err.Error() != "failed to load the included file: invalid.toml" {
+ t.Errorf("invalid error: %v", err)
+ }
+}
+
+func TestLoadConfigFileNoFileNoStrict(t *testing.T) {
+ store.Clear()
+ os.Mkdir("testdata", 0o755)
+ defer os.RemoveAll("testdata")
+ file := filepath.Join("testdata", "config.toml")
+ loaded, err := config.DefaultTOML()
+ if err != nil {
+ t.Errorf("invalid error: %v", err)
+ }
+ loaded = strings.Replace(loaded, "include = []", "include = ['invalid.toml']", 1)
+ loaded = strings.Replace(loaded, "strict = true", "strict = false", 1)
+ os.WriteFile(file, []byte(loaded), 0o644)
+ if err := platform.LoadConfigFile(file); err != nil {
+ t.Errorf("invalid error: %v", err)
+ }
+}