commit 5f5970d56ed2feefd81fef33a377f69a451624b3
parent 9cdfdb752abd91dd648f607587aa2bc3d9e8b691
Author: Sean Enck <sean@ttypty.com>
Date: Fri, 16 Aug 2024 20:50:21 -0400
use an iterable sequence for query callback results
Diffstat:
7 files changed, 97 insertions(+), 63 deletions(-)
diff --git a/go.mod b/go.mod
@@ -1,8 +1,6 @@
module github.com/seanenck/lockbox
-go 1.22.0
-
-toolchain go1.22.2
+go 1.23.0
require (
github.com/aymanbagabas/go-osc52 v1.2.2
diff --git a/internal/app/conv.go b/internal/app/conv.go
@@ -39,7 +39,10 @@ func serialize(w io.Writer, tx *backend.Transaction, isJSON bool, filter string)
}
hasFilter := len(filter) > 0
printed := false
- for _, item := range e {
+ for item := range e {
+ if item.Error != nil {
+ return item.Error
+ }
if hasFilter {
if !strings.Contains(item.Path, filter) {
continue
diff --git a/internal/app/list.go b/internal/app/list.go
@@ -20,7 +20,10 @@ func List(cmd CommandOptions) error {
return err
}
w := cmd.Writer()
- for _, f := range e {
+ for f := range e {
+ if f.Error != nil {
+ return f.Error
+ }
fmt.Fprintf(w, "%s\n", f.Path)
}
return nil
diff --git a/internal/app/totp.go b/internal/app/totp.go
@@ -238,7 +238,10 @@ func (args *TOTPArguments) Do(opts TOTPOptions) error {
return err
}
writer := opts.app.Writer()
- for _, entry := range e {
+ for entry := range e {
+ if entry.Error != nil {
+ return entry.Error
+ }
fmt.Fprintf(writer, "%s\n", entry.Directory())
}
return nil
diff --git a/internal/backend/core.go b/internal/backend/core.go
@@ -43,6 +43,10 @@ type (
Value string
backing gokeepasslib.Entry
}
+ QuerySeq struct {
+ QueryEntity
+ Error error
+ }
// TransactionEntity objects alter the database entities
TransactionEntity struct {
path string
diff --git a/internal/backend/query.go b/internal/backend/query.go
@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "iter"
"sort"
"strings"
@@ -70,7 +71,7 @@ func (t *Transaction) MatchPath(path string) ([]QueryEntity, error) {
if strings.HasSuffix(prefix, pathSep) {
return nil, errors.New("invalid match criteria, too many path separators")
}
- return t.QueryCallback(QueryOptions{Mode: PrefixMode, Criteria: prefix + pathSep, Values: BlankValue})
+ return t.queryCollect(QueryOptions{Mode: PrefixMode, Criteria: prefix + pathSep, Values: BlankValue})
}
// Get will request a singular entity
@@ -79,7 +80,7 @@ func (t *Transaction) Get(path string, mode ValueMode) (*QueryEntity, error) {
if err != nil {
return nil, err
}
- e, err := t.QueryCallback(QueryOptions{Mode: ExactMode, Criteria: path, Values: mode})
+ e, err := t.queryCollect(QueryOptions{Mode: ExactMode, Criteria: path, Values: mode})
if err != nil {
return nil, err
}
@@ -108,8 +109,23 @@ func forEach(offset string, groups []gokeepasslib.Group, entries []gokeepasslib.
}
}
+func (t *Transaction) queryCollect(args QueryOptions) ([]QueryEntity, error) {
+ e, err := t.QueryCallback(args)
+ if err != nil {
+ return nil, err
+ }
+ var entities []QueryEntity
+ for entity := range e {
+ if entity.Error != nil {
+ return nil, entity.Error
+ }
+ entities = append(entities, entity.QueryEntity)
+ }
+ return entities, nil
+}
+
// QueryCallback will retrieve a query based on setting
-func (t *Transaction) QueryCallback(args QueryOptions) ([]QueryEntity, error) {
+func (t *Transaction) QueryCallback(args QueryOptions) (iter.Seq[QuerySeq], error) {
if args.Mode == noneMode {
return nil, errors.New("no query mode specified")
}
@@ -174,42 +190,47 @@ func (t *Transaction) QueryCallback(args QueryOptions) ([]QueryEntity, error) {
return nil, err
}
}
- var results []QueryEntity
- for _, k := range keys {
- entity := QueryEntity{Path: k}
- if args.Values != BlankValue {
- e, ok := entities[k]
- if !ok {
- return nil, errors.New("failed to read entity back from map")
- }
- val := getValue(e.backing, notesKey)
- if strings.TrimSpace(val) == "" {
- val = e.backing.GetPassword()
- }
- switch args.Values {
- case JSONValue:
- data := ""
- switch jsonMode {
- case config.JSONOutputs.Raw:
- data = val
- case config.JSONOutputs.Hash:
- data = fmt.Sprintf("%x", sha512.Sum512([]byte(val)))
- if hashLength > 0 && len(data) > hashLength {
- data = data[0:hashLength]
+ return func(yield func(QuerySeq) bool) {
+ for _, k := range keys {
+ entity := QuerySeq{}
+ entity.Path = k
+ if args.Values != BlankValue {
+ e, ok := entities[k]
+ if ok {
+ val := getValue(e.backing, notesKey)
+ if strings.TrimSpace(val) == "" {
+ val = e.backing.GetPassword()
}
+ switch args.Values {
+ case JSONValue:
+ data := ""
+ switch jsonMode {
+ case config.JSONOutputs.Raw:
+ data = val
+ case config.JSONOutputs.Hash:
+ data = fmt.Sprintf("%x", sha512.Sum512([]byte(val)))
+ if hashLength > 0 && len(data) > hashLength {
+ data = data[0:hashLength]
+ }
+ }
+ t := getValue(e.backing, modTimeKey)
+ s := JSON{ModTime: t, Data: data}
+ m, err := json.Marshal(s)
+ if err == nil {
+ entity.Value = string(m)
+ } else {
+ entity.Error = err
+ }
+ case SecretValue:
+ entity.Value = val
+ }
+ } else {
+ entity.Error = errors.New("failed to read entity back from map")
}
- t := getValue(e.backing, modTimeKey)
- s := JSON{ModTime: t, Data: data}
- m, err := json.Marshal(s)
- if err != nil {
- return nil, err
- }
- entity.Value = string(m)
- case SecretValue:
- entity.Value = val
+ }
+ if !yield(entity) {
+ return
}
}
- results = append(results, entity)
- }
- return results, nil
+ }, nil
}
diff --git a/internal/backend/query_test.go b/internal/backend/query_test.go
@@ -2,6 +2,7 @@ package backend_test
import (
"encoding/json"
+ "iter"
"os"
"strings"
"testing"
@@ -156,58 +157,59 @@ func TestValueModes(t *testing.T) {
}
}
+func testCollect(t *testing.T, count int, seq iter.Seq[backend.QuerySeq]) []backend.QueryEntity {
+ var collected []backend.QueryEntity
+ for item := range seq {
+ if item.Error != nil {
+ t.Errorf("unexpected error: %v", item.Error)
+ }
+ collected = append(collected, item.QueryEntity)
+ }
+ return collected
+}
+
func TestQueryCallback(t *testing.T) {
setupInserts(t)
if _, err := fullSetup(t, true).QueryCallback(backend.QueryOptions{}); err.Error() != "no query mode specified" {
t.Errorf("wrong error: %v", err)
}
- res, err := fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ListMode})
+ seq, err := fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ListMode})
if err != nil {
t.Errorf("no error: %v", err)
}
- if len(res) != 4 {
- t.Error("invalid results: not enough")
- }
+ res := testCollect(t, 4, seq)
if res[0].Path != "test/test/ab11c" || res[1].Path != "test/test/abc" || res[2].Path != "test/test/abc1ak" || res[3].Path != "test/test/abcx" {
t.Errorf("invalid results: %v", res)
}
- res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.FindMode, Criteria: "1"})
+ seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.FindMode, Criteria: "1"})
if err != nil {
t.Errorf("no error: %v", err)
}
- if len(res) != 2 {
- t.Error("invalid results: not enough")
- }
+ res = testCollect(t, 2, seq)
if res[0].Path != "test/test/ab11c" || res[1].Path != "test/test/abc1ak" {
t.Errorf("invalid results: %v", res)
}
- res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.SuffixMode, Criteria: "c"})
+ seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.SuffixMode, Criteria: "c"})
if err != nil {
t.Errorf("no error: %v", err)
}
- if len(res) != 2 {
- t.Error("invalid results: not enough")
- }
+ res = testCollect(t, 2, seq)
if res[0].Path != "test/test/ab11c" || res[1].Path != "test/test/abc" {
t.Errorf("invalid results: %v", res)
}
- res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "test/test/abc"})
+ seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "test/test/abc"})
if err != nil {
t.Errorf("no error: %v", err)
}
- if len(res) != 1 {
- t.Error("invalid results: not enough")
- }
+ res = testCollect(t, 1, seq)
if res[0].Path != "test/test/abc" {
t.Errorf("invalid results: %v", res)
}
- res, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "abczzz"})
+ seq, err = fullSetup(t, true).QueryCallback(backend.QueryOptions{Mode: backend.ExactMode, Criteria: "abczzz"})
if err != nil {
t.Errorf("no error: %v", err)
}
- if len(res) != 0 {
- t.Error("invalid results: should be empty")
- }
+ testCollect(t, 0, seq)
}
func TestSetModTime(t *testing.T) {