lockbox

password manager
Log | Files | Refs | README | LICENSE

commit 00cc3b524e29bc27e84582f594166b9414f38628
parent 5f5970d56ed2feefd81fef33a377f69a451624b3
Author: Sean Enck <sean@ttypty.com>
Date:   Fri, 16 Aug 2024 21:09:31 -0400

use Seq2 to get the error result instead of a wrapper struct

Diffstat:
Minternal/app/conv.go | 6+++---
Minternal/app/list.go | 6+++---
Minternal/app/totp.go | 6+++---
Minternal/backend/core.go | 19+++++++++++++++----
Minternal/backend/core_test.go | 42++++++++++++++++++++++++++++++++++++++++++
Minternal/backend/query.go | 28++++++++++------------------
Minternal/backend/query_test.go | 15+++++++--------
7 files changed, 83 insertions(+), 39 deletions(-)

diff --git a/internal/app/conv.go b/internal/app/conv.go @@ -39,9 +39,9 @@ func serialize(w io.Writer, tx *backend.Transaction, isJSON bool, filter string) } hasFilter := len(filter) > 0 printed := false - for item := range e { - if item.Error != nil { - return item.Error + for item, err := range e { + if err != nil { + return err } if hasFilter { if !strings.Contains(item.Path, filter) { diff --git a/internal/app/list.go b/internal/app/list.go @@ -20,9 +20,9 @@ func List(cmd CommandOptions) error { return err } w := cmd.Writer() - for f := range e { - if f.Error != nil { - return f.Error + for f, err := range e { + if err != nil { + return err } fmt.Fprintf(w, "%s\n", f.Path) } diff --git a/internal/app/totp.go b/internal/app/totp.go @@ -238,9 +238,9 @@ func (args *TOTPArguments) Do(opts TOTPOptions) error { return err } writer := opts.app.Writer() - for entry := range e { - if entry.Error != nil { - return entry.Error + for entry, err := range e { + if err != nil { + return err } fmt.Fprintf(writer, "%s\n", entry.Directory()) } diff --git a/internal/backend/core.go b/internal/backend/core.go @@ -4,6 +4,7 @@ package backend import ( "errors" "fmt" + "iter" "os" "strings" @@ -25,6 +26,8 @@ const ( ) type ( + // QuerySeq2 wraps the iteration for query entities + QuerySeq2 iter.Seq2[QueryEntity, error] // Transaction handles the overall operation of the transaction Transaction struct { valid bool @@ -43,10 +46,6 @@ type ( Value string backing gokeepasslib.Entry } - QuerySeq struct { - QueryEntity - Error error - } // TransactionEntity objects alter the database entities TransactionEntity struct { path string @@ -218,3 +217,15 @@ func getValue(e gokeepasslib.Entry, key string) string { func IsDirectory(path string) bool { return strings.HasSuffix(path, pathSep) } + +// Collect will create a slice from an iterable set of query sequence results +func (s QuerySeq2) Collect() ([]QueryEntity, error) { + var entities []QueryEntity + for entity, err := range s { + if err != nil { + return nil, err + } + entities = append(entities, entity) + } + return entities, nil +} diff --git a/internal/backend/core_test.go b/internal/backend/core_test.go @@ -1,6 +1,7 @@ package backend_test import ( + "errors" "fmt" "testing" @@ -112,3 +113,44 @@ func TestNewSuffix(t *testing.T) { t.Error("invalid suffix") } } + +func generateTestSeq(hasError, extra bool) backend.QuerySeq2 { + return func(yield func(backend.QueryEntity, error) bool) { + if !yield(backend.QueryEntity{}, nil) { + return + } + if !yield(backend.QueryEntity{}, nil) { + return + } + if hasError { + if !yield(backend.QueryEntity{}, errors.New("test collect error")) { + return + } + } + if !yield(backend.QueryEntity{}, nil) { + return + } + if extra { + if !yield(backend.QueryEntity{}, nil) { + return + } + } + } +} + +func TestQuerySeq2Collect(t *testing.T) { + seq := generateTestSeq(true, true) + if _, err := seq.Collect(); err == nil || err.Error() != "test collect error" { + t.Errorf("invalid error: %v", err) + } + seq = generateTestSeq(false, false) + c, err := seq.Collect() + if err != nil || len(c) != 3 { + t.Errorf("invalid collect: %v %v %d", c, err, len(c)) + } + seq = generateTestSeq(false, true) + c, err = seq.Collect() + if err != nil || len(c) != 4 { + t.Errorf("invalid collect: %v %v %d", c, err, len(c)) + } +} diff --git a/internal/backend/query.go b/internal/backend/query.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "iter" "sort" "strings" @@ -114,18 +113,11 @@ func (t *Transaction) queryCollect(args QueryOptions) ([]QueryEntity, error) { if err != nil { return nil, err } - var entities []QueryEntity - for entity := range e { - if entity.Error != nil { - return nil, entity.Error - } - entities = append(entities, entity.QueryEntity) - } - return entities, nil + return e.Collect() } // QueryCallback will retrieve a query based on setting -func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], error) { +func (t *Transaction) QueryCallback(args QueryOptions) (QuerySeq2, error) { if args.Mode == noneMode { return nil, errors.New("no query mode specified") } @@ -190,10 +182,10 @@ func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], erro return nil, err } } - return func(yield func(QuerySeq) bool) { + return func(yield func(QueryEntity, error) bool) { for _, k := range keys { - entity := QuerySeq{} - entity.Path = k + entity := QueryEntity{Path: k} + var err error if args.Values != BlankValue { e, ok := entities[k] if ok { @@ -215,20 +207,20 @@ func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], erro } t := getValue(e.backing, modTimeKey) s := JSON{ModTime: t, Data: data} - m, err := json.Marshal(s) - if err == nil { + m, jErr := json.Marshal(s) + if jErr == nil { entity.Value = string(m) } else { - entity.Error = err + err = jErr } case SecretValue: entity.Value = val } } else { - entity.Error = errors.New("failed to read entity back from map") + err = errors.New("failed to read entity back from map") } } - if !yield(entity) { + if !yield(entity, err) { return } } diff --git a/internal/backend/query_test.go b/internal/backend/query_test.go @@ -2,7 +2,6 @@ package backend_test import ( "encoding/json" - "iter" "os" "strings" "testing" @@ -157,13 +156,13 @@ func TestValueModes(t *testing.T) { } } -func testCollect(t *testing.T, count int, seq iter.Seq[backend.QuerySeq]) []backend.QueryEntity { - var collected []backend.QueryEntity - for item := range seq { - if item.Error != nil { - t.Errorf("unexpected error: %v", item.Error) - } - collected = append(collected, item.QueryEntity) +func testCollect(t *testing.T, count int, seq backend.QuerySeq2) []backend.QueryEntity { + collected, err := seq.Collect() + if err != nil { + t.Errorf("invalid collect error: %v", err) + } + if len(collected) != count { + t.Errorf("unexpected entity count: %d", count) } return collected }