commit 00cc3b524e29bc27e84582f594166b9414f38628
parent 5f5970d56ed2feefd81fef33a377f69a451624b3
Author: Sean Enck <sean@ttypty.com>
Date: Fri, 16 Aug 2024 21:09:31 -0400
use Seq2 to get the error result instead of a wrapper struct
Diffstat:
7 files changed, 83 insertions(+), 39 deletions(-)
diff --git a/internal/app/conv.go b/internal/app/conv.go
@@ -39,9 +39,9 @@ func serialize(w io.Writer, tx *backend.Transaction, isJSON bool, filter string)
}
hasFilter := len(filter) > 0
printed := false
- for item := range e {
- if item.Error != nil {
- return item.Error
+ for item, err := range e {
+ if err != nil {
+ return err
}
if hasFilter {
if !strings.Contains(item.Path, filter) {
diff --git a/internal/app/list.go b/internal/app/list.go
@@ -20,9 +20,9 @@ func List(cmd CommandOptions) error {
return err
}
w := cmd.Writer()
- for f := range e {
- if f.Error != nil {
- return f.Error
+ for f, err := range e {
+ if err != nil {
+ return err
}
fmt.Fprintf(w, "%s\n", f.Path)
}
diff --git a/internal/app/totp.go b/internal/app/totp.go
@@ -238,9 +238,9 @@ func (args *TOTPArguments) Do(opts TOTPOptions) error {
return err
}
writer := opts.app.Writer()
- for entry := range e {
- if entry.Error != nil {
- return entry.Error
+ for entry, err := range e {
+ if err != nil {
+ return err
}
fmt.Fprintf(writer, "%s\n", entry.Directory())
}
diff --git a/internal/backend/core.go b/internal/backend/core.go
@@ -4,6 +4,7 @@ package backend
import (
"errors"
"fmt"
+ "iter"
"os"
"strings"
@@ -25,6 +26,8 @@ const (
)
type (
+ // QuerySeq2 wraps the iteration for query entities
+ QuerySeq2 iter.Seq2[QueryEntity, error]
// Transaction handles the overall operation of the transaction
Transaction struct {
valid bool
@@ -43,10 +46,6 @@ type (
Value string
backing gokeepasslib.Entry
}
- QuerySeq struct {
- QueryEntity
- Error error
- }
// TransactionEntity objects alter the database entities
TransactionEntity struct {
path string
@@ -218,3 +217,15 @@ func getValue(e gokeepasslib.Entry, key string) string {
func IsDirectory(path string) bool {
return strings.HasSuffix(path, pathSep)
}
+
+// Collect will create a slice from an iterable set of query sequence results
+func (s QuerySeq2) Collect() ([]QueryEntity, error) {
+ var entities []QueryEntity
+ for entity, err := range s {
+ if err != nil {
+ return nil, err
+ }
+ entities = append(entities, entity)
+ }
+ return entities, nil
+}
diff --git a/internal/backend/core_test.go b/internal/backend/core_test.go
@@ -1,6 +1,7 @@
package backend_test
import (
+ "errors"
"fmt"
"testing"
@@ -112,3 +113,44 @@ func TestNewSuffix(t *testing.T) {
t.Error("invalid suffix")
}
}
+
+func generateTestSeq(hasError, extra bool) backend.QuerySeq2 {
+ return func(yield func(backend.QueryEntity, error) bool) {
+ if !yield(backend.QueryEntity{}, nil) {
+ return
+ }
+ if !yield(backend.QueryEntity{}, nil) {
+ return
+ }
+ if hasError {
+ if !yield(backend.QueryEntity{}, errors.New("test collect error")) {
+ return
+ }
+ }
+ if !yield(backend.QueryEntity{}, nil) {
+ return
+ }
+ if extra {
+ if !yield(backend.QueryEntity{}, nil) {
+ return
+ }
+ }
+ }
+}
+
+func TestQuerySeq2Collect(t *testing.T) {
+ seq := generateTestSeq(true, true)
+ if _, err := seq.Collect(); err == nil || err.Error() != "test collect error" {
+ t.Errorf("invalid error: %v", err)
+ }
+ seq = generateTestSeq(false, false)
+ c, err := seq.Collect()
+ if err != nil || len(c) != 3 {
+ t.Errorf("invalid collect: %v %v %d", c, err, len(c))
+ }
+ seq = generateTestSeq(false, true)
+ c, err = seq.Collect()
+ if err != nil || len(c) != 4 {
+ t.Errorf("invalid collect: %v %v %d", c, err, len(c))
+ }
+}
diff --git a/internal/backend/query.go b/internal/backend/query.go
@@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "iter"
"sort"
"strings"
@@ -114,18 +113,11 @@ func (t *Transaction) queryCollect(args QueryOptions) ([]QueryEntity, error) {
if err != nil {
return nil, err
}
- var entities []QueryEntity
- for entity := range e {
- if entity.Error != nil {
- return nil, entity.Error
- }
- entities = append(entities, entity.QueryEntity)
- }
- return entities, nil
+ return e.Collect()
}
// QueryCallback will retrieve a query based on setting
-func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], error) {
+func (t *Transaction) QueryCallback(args QueryOptions) (QuerySeq2, error) {
if args.Mode == noneMode {
return nil, errors.New("no query mode specified")
}
@@ -190,10 +182,10 @@ func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], erro
return nil, err
}
}
- return func(yield func(QuerySeq) bool) {
+ return func(yield func(QueryEntity, error) bool) {
for _, k := range keys {
- entity := QuerySeq{}
- entity.Path = k
+ entity := QueryEntity{Path: k}
+ var err error
if args.Values != BlankValue {
e, ok := entities[k]
if ok {
@@ -215,20 +207,20 @@ func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], erro
}
t := getValue(e.backing, modTimeKey)
s := JSON{ModTime: t, Data: data}
- m, err := json.Marshal(s)
- if err == nil {
+ m, jErr := json.Marshal(s)
+ if jErr == nil {
entity.Value = string(m)
} else {
- entity.Error = err
+ err = jErr
}
case SecretValue:
entity.Value = val
}
} else {
- entity.Error = errors.New("failed to read entity back from map")
+ err = errors.New("failed to read entity back from map")
}
}
- if !yield(entity) {
+ if !yield(entity, err) {
return
}
}
diff --git a/internal/backend/query_test.go b/internal/backend/query_test.go
@@ -2,7 +2,6 @@ package backend_test
import (
"encoding/json"
- "iter"
"os"
"strings"
"testing"
@@ -157,13 +156,13 @@ func TestValueModes(t *testing.T) {
}
}
-func testCollect(t *testing.T, count int, seq iter.Seq[backend.QuerySeq]) []backend.QueryEntity {
- var collected []backend.QueryEntity
- for item := range seq {
- if item.Error != nil {
- t.Errorf("unexpected error: %v", item.Error)
- }
- collected = append(collected, item.QueryEntity)
+func testCollect(t *testing.T, count int, seq backend.QuerySeq2) []backend.QueryEntity {
+ collected, err := seq.Collect()
+ if err != nil {
+ t.Errorf("invalid collect error: %v", err)
+ }
+ if len(collected) != count {
+ t.Errorf("unexpected entity count: %d", count)
}
return collected
}