lockbox

password manager
Log | Files | Refs | README | LICENSE

commit 5f5970d56ed2feefd81fef33a377f69a451624b3
parent 9cdfdb752abd91dd648f607587aa2bc3d9e8b691
Author: Sean Enck <sean@ttypty.com>
Date:   Fri, 16 Aug 2024 20:50:21 -0400

use an iterable sequence for query callback results

Diffstat:
Mgo.mod | 4+---
Minternal/app/conv.go | 5++++-
Minternal/app/list.go | 5++++-
Minternal/app/totp.go | 5++++-
Minternal/backend/core.go | 4++++
Minternal/backend/query.go | 95++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
Minternal/backend/query_test.go | 42++++++++++++++++++++++--------------------
7 files changed, 97 insertions(+), 63 deletions(-)

diff --git a/go.mod b/go.mod @@ -1,8 +1,6 @@ module github.com/seanenck/lockbox -go 1.22.0 - -toolchain go1.22.2 +go 1.23.0 require ( github.com/aymanbagabas/go-osc52 v1.2.2 diff --git a/internal/app/conv.go b/internal/app/conv.go @@ -39,7 +39,10 @@ func serialize(w io.Writer, tx *backend.Transaction, isJSON bool, filter string) } hasFilter := len(filter) > 0 printed := false - for _, item := range e { + for item := range e { + if item.Error != nil { + return item.Error + } if hasFilter { if !strings.Contains(item.Path, filter) { continue diff --git a/internal/app/list.go b/internal/app/list.go @@ -20,7 +20,10 @@ func List(cmd CommandOptions) error { return err } w := cmd.Writer() - for _, f := range e { + for f := range e { + if f.Error != nil { + return f.Error + } fmt.Fprintf(w, "%s\n", f.Path) } return nil diff --git a/internal/app/totp.go b/internal/app/totp.go @@ -238,7 +238,10 @@ func (args *TOTPArguments) Do(opts TOTPOptions) error { return err } writer := opts.app.Writer() - for _, entry := range e { + for entry := range e { + if entry.Error != nil { + return entry.Error + } fmt.Fprintf(writer, "%s\n", entry.Directory()) } return nil diff --git a/internal/backend/core.go b/internal/backend/core.go @@ -43,6 +43,10 @@ type ( Value string backing gokeepasslib.Entry } + QuerySeq struct { + QueryEntity + Error error + } // TransactionEntity objects alter the database entities TransactionEntity struct { path string diff --git a/internal/backend/query.go b/internal/backend/query.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "iter" "sort" "strings" @@ -70,7 +71,7 @@ func (t *Transaction) MatchPath(path string) ([]QueryEntity, error) { if strings.HasSuffix(prefix, pathSep) { return nil, errors.New("invalid match criteria, too many path separators") } - return t.QueryCallback(QueryOptions{Mode: PrefixMode, Criteria: prefix + pathSep, Values: BlankValue}) + return t.queryCollect(QueryOptions{Mode: PrefixMode, Criteria: prefix + pathSep, Values: BlankValue}) } // Get will request a singular entity @@ -79,7 +80,7 @@ func (t *Transaction) Get(path string, mode ValueMode) (*QueryEntity, error) { if err != nil { return nil, err } - e, err := t.QueryCallback(QueryOptions{Mode: ExactMode, Criteria: path, Values: mode}) + e, err := t.queryCollect(QueryOptions{Mode: ExactMode, Criteria: path, Values: mode}) if err != nil { return nil, err } @@ -108,8 +109,23 @@ func forEach(offset string, groups []gokeepasslib.Group, entries []gokeepasslib. } } +func (t *Transaction) queryCollect(args QueryOptions) ([]QueryEntity, error) { + e, err := t.QueryCallback(args) + if err != nil { + return nil, err + } + var entities []QueryEntity + for entity := range e { + if entity.Error != nil { + return nil, entity.Error + } + entities = append(entities, entity.QueryEntity) + } + return entities, nil +} + // QueryCallback will retrieve a query based on setting -func (t *Transaction) QueryCallback(args QueryOptions) ([]QueryEntity, error) { +func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], error) { if args.Mode == noneMode { return nil, errors.New("no query mode specified") } @@ -174,42 +190,47 @@ func (t *Transaction) QueryCallback(args QueryOptions) ([]QueryEntity, error) { return nil, err } } - var results []QueryEntity - for _, k := range keys { - entity := QueryEntity{Path: k} - if args.Values != BlankValue { - e, ok := entities[k] - if !ok { - return nil, errors.New("failed to read entity back from map") - } - val := getValue(e.backing, notesKey) - if strings.TrimSpace(val) == "" { - val = e.backing.GetPassword() - } - switch args.Values { - case JSONValue: - data := "" - switch jsonMode { - case config.JSONOutputs.Raw: - data = val - case config.JSONOutputs.Hash: - data = fmt.Sprintf("%x", sha512.Sum512([]byte(val))) - if hashLength > 0 && len(data) > hashLength { - data = data[0:hashLength] + return func(yield func(QuerySeq) bool) { + for _, k := range keys { + entity := QuerySeq{} + entity.Path = k + if args.Values != BlankValue { + e, ok := entities[k] + if ok { + val := getValue(e.backing, notesKey) + if strings.TrimSpace(val) == "" { + val = e.backing.GetPassword() } + switch args.Values { + case JSONValue: + data := "" + switch jsonMode { + case config.JSONOutputs.Raw: + data = val + case config.JSONOutputs.Hash: + data = fmt.Sprintf("%x", sha512.Sum512([]byte(val))) + if hashLength > 0 && len(data) > hashLength { + data = data[0:hashLength] + } + } + t := getValue(e.backing, modTimeKey) + s := JSON{ModTime: t, Data: data} + m, err := json.Marshal(s) + if err == nil { + entity.Value = string(m) + } else { + entity.Error = err + } + case SecretValue: + entity.Value = val + } + } else { + entity.Error = errors.New("failed to read entity back from map") } - t := getValue(e.backing, modTimeKey) - s := JSON{ModTime: t, Data: data} - m, err := json.Marshal(s) - if err != nil { - return nil, err - } - entity.Value = string(m) - case SecretValue: - entity.Value = val + } + if !yield(entity) { + return } } - results = append(results, entity) - } - return results, nil + }, nil } diff --git a/internal/backend/query_test.go b/internal/backend/query_test.go @@ -2,6 +2,7 @@ package backend_test import ( "encoding/json" + "iter" "os" "strings" "testing" @@ -156,58 +157,59 @@ func TestValueModes(t *testing.T) { } } +func testCollect(t *testing.T, count int, seq iter.Seq[backend.QuerySeq]) []backend.QueryEntity { + var collected []backend.QueryEntity + for item := range seq { + if item.Error != nil { + t.Errorf("unexpected error: %v", item.Error) + } + collected = append(collected, item.QueryEntity) + } + return collected +} + func TestQueryCallback(t *testing.T) { setupInserts(t) if _, err := fullSetup(t, true).QueryCallback(backend.QueryOptions{}); err.Error() != "no query mode specified" { t.Errorf("wrong error: %v", err) } - res, err := fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ListMode}) + seq, err := fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ListMode}) if err != nil { t.Errorf("no error: %v", err) } - if len(res) != 4 { - t.Error("invalid results: not enough") - } + res := testCollect(t, 4, seq) if res[0].Path != "test/test/ab11c" || res[1].Path != "test/test/abc" || res[2].Path != "test/test/abc1ak" || res[3].Path != "test/test/abcx" { t.Errorf("invalid results: %v", res) } - res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.FindMode, Criteria: "1"}) + seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.FindMode, Criteria: "1"}) if err != nil { t.Errorf("no error: %v", err) } - if len(res) != 2 { - t.Error("invalid results: not enough") - } + res = testCollect(t, 2, seq) if res[0].Path != "test/test/ab11c" || res[1].Path != "test/test/abc1ak" { t.Errorf("invalid results: %v", res) } - res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.SuffixMode, Criteria: "c"}) + seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.SuffixMode, Criteria: "c"}) if err != nil { t.Errorf("no error: %v", err) } - if len(res) != 2 { - t.Error("invalid results: not enough") - } + res = testCollect(t, 2, seq) if res[0].Path != "test/test/ab11c" || res[1].Path != "test/test/abc" { t.Errorf("invalid results: %v", res) } - res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "test/test/abc"}) + seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "test/test/abc"}) if err != nil { t.Errorf("no error: %v", err) } - if len(res) != 1 { - t.Error("invalid results: not enough") - } + res = testCollect(t, 1, seq) if res[0].Path != "test/test/abc" { t.Errorf("invalid results: %v", res) } - res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "abczzz"}) + seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "abczzz"}) if err != nil { t.Errorf("no error: %v", err) } - if len(res) != 0 { - t.Error("invalid results: should be empty") - } + testCollect(t, 0, seq) } func TestSetModTime(t *testing.T) {