diff --git a/README.md b/README.md index d799e82..c6302cf 100755 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ Zoom ==== -[![Version](https://img.shields.io/badge/version-0.16.0-5272B4.svg)](https://github.com/albrow/zoom/releases) +[![Version](https://img.shields.io/badge/version-0.17.0-5272B4.svg)](https://github.com/albrow/zoom/releases) [![Circle CI](https://img.shields.io/circleci/project/albrow/zoom/master.svg)](https://circleci.com/gh/albrow/zoom/tree/master) [![GoDoc](https://godoc.org/github.com/albrow/zoom?status.svg)](https://godoc.org/github.com/albrow/zoom) @@ -122,14 +122,6 @@ To install Zoom itself, run `go get -u github.com/albrow/zoom` to pull down the current master branch, or install with the dependency manager of your choice to lock in a specific version. -Zoom supports the -[Go 1.5 vendor experiment](https://docs.google.com/document/d/1Bz5-UB7g2uPBdOx-rw5t9MxJwkfpx90cqG9AFL0JAYo/edit) -and all dependencies are installed into the vendor folder, which is checked into -version control. To use Zoom, you must use Go version >= 1.5 and set -`GO15VENDOREXPERIMENT=1`. (Internally, Zoom uses -[Glide](https://github.com/Masterminds/glide) to manage dependencies, -but you do not need to install Glide to use Zoom). - Initialization -------------- @@ -143,8 +135,9 @@ import ( ) ``` -Then, you must create a new pool with `zoom.NewPool`. A pool represents a pool -of connections to the database. Since you may need access to the pool in +Then, you must create a new pool with +[`NewPool`](http://godoc.org/github.com/albrow/zoom/#NewPool). A pool represents +a pool of connections to the database. Since you may need access to the pool in different parts of your application, it is sometimes a good idea to declare a top-level variable and then initialize it in the `main` or `init` function. You must also call `pool.Close` when your application exits, so it's a good idea to @@ -154,7 +147,7 @@ use defer. var pool *zoom.Pool func main() { - pool = zoom.NewPool(nil) + pool = zoom.NewPool("localhost:6379") defer func() { if err := pool.Close(); err != nil { // handle error @@ -164,28 +157,27 @@ func main() { } ``` -The `NewPool` function takes a `zoom.PoolOptions` as an argument. Here's a list of options and their -defaults: +The `NewPool` function accepts an address which will be used to connect to +Redis, and it will use all the +[default values](http://godoc.org/github.com/albrow/zoom/#DefaultPoolOptions) +for the other options. If you need to specify different options, you can use the +[`NewPoolWithOptions`](http://godoc.org/github.com/albrow/zoom/#NewPoolWithOptions) +function. + +For convenience, the +[`PoolOptions`](http://godoc.org/github.com/albrow/zoom/#PoolOptions) type has +chainable methods for changing each option. Typically you would start with +[`DefaultOptions`](http://godoc.org/github.com/albrow/zoom/#DefaultOptions) and +call `WithX` to change value for option `X`. + +For example, here's how you could initialize a Pool that connects to Redis using +a unix socket connection on `/tmp/unix.sock`: ``` go -type PoolOptions struct { - // Address to connect to. Default: "localhost:6379" - Address string - // Network to use. Default: "tcp" - Network string - // Database id to use (using SELECT). Default: 0 - Database int - // Password for a password-protected redis database. If not empty, - // every connection will use the AUTH command during initialization - // to authenticate with the database. Default: "" - Password string -} +options := zoom.DefaultPoolOptions.WithNetwork("unix").WithAddress("/tmp/unix.sock") +pool = zoom.NewPoolWithOptions(options) ``` -If you pass in `nil` to `NewPool`, Zoom will use all the default values. Any fields in the `PoolOptions` -struct that are empty (e.g., an empty string or 0) will fall back to their default values, so you only need -to provide a `PoolOptions` struct with the fields you want to change. - Models ------ @@ -262,56 +254,47 @@ pools, you will need to create a collection for each pool. ``` go // Create a new collection for the Person type. -People, err := pool.NewCollection(&Person{}, nil) +People, err := pool.NewCollection(&Person{}) if err != nil { // handle error } ``` -The second argument to `NewCollection` is a + +The convention is to name the `Collection` the plural of the corresponding +model type (e.g. "People"), but it's just a variable so you can name it +whatever you want. + +`NewCollection` will use all the +[default options](http://godoc.org/github.com/albrow/zoom/#DefaultCollectionOptions) +for the collection. + +If you need to specify other options, use the +[`NewCollectionWithOptions`](http://godoc.org/github.com/albrow/zoom/#NewCollectionWithOptions) +function. The second argument to `NewCollectionWithOptions` is a [`CollectionOptions`](http://godoc.org/github.com/albrow/zoom#CollectionOptions). -It works similarly to `PoolOptions`. You can just pass nil to use all the -default options. Additionally, any zero-valued fields in the struct indicate -that the default value should be used for that field. +It works similarly to `PoolOptions`, so you can start with +[`DefaultCollectionOptions`](http://godoc.org/github.com/albrow/zoom/#DefaultCollectionOptions) +and use the chainable `WithX` methods to specify a new value for option `X`. +Here's an example of how to create a new `Collection` which is indexed, allowing +you to use Queries and methods like `FindAll` which rely on collection indexing: ``` go -type CollectionOptions struct { - // FallbackMarshalerUnmarshaler is used to marshal/unmarshal any type - // into a slice of bytes which is suitable for storing in the database. If - // Zoom does not know how to directly encode a certain type into bytes, it - // will use the FallbackMarshalerUnmarshaler. By default, the value is - // GobMarshalerUnmarshaler which uses the builtin gob package. Zoom also - // provides JSONMarshalerUnmarshaler to support json encoding out of the box. - // Default: GobMarshalerUnmarshaler. - FallbackMarshalerUnmarshaler MarshalerUnmarshaler - // If Index is true, any model in the collection that is saved will be added - // to a set in Redis which acts as an index. The default value is false. The - // key for the set is exposed via the IndexKey method. Queries and the - // FindAll, Count, and DeleteAll methods will not work for unindexed - // collections. This may change in future versions. Default: false. - Index bool - // Name is a unique string identifier to use for the collection in Redis. All - // models in this collection that are saved in the database will use the - // collection name as a prefix. If not provided, the default name will be the - // name of the model type without the package prefix or pointer declarations. - // So for example, the default name corresponding to *models.User would be - // "User". If a custom name is provided, it cannot contain a colon. - // Default: The name of the model type, excluding package prefix and pointer - // declarations. - Name string +options := zoom.DefaultCollectionOptions.WithIndex(true) +People, err = pool.NewCollection(&Person{}, options) +if err != nil { + // handle error } ``` There are a few important points to emphasize concerning collections: 1. The collection name cannot contain a colon. -2. Queries, as well as the FindAll, DeleteAll, and Count methods will not work - if Index is false. This may change in future versions. +2. Queries, as well as the `FindAll`, `DeleteAll`, and `Count` methods will not + work if `Index` is `false`. This may change in future versions. -Convention is to name the `Collection` the plural of the corresponding -model type (e.g. "People"), but it's just a variable so you can name it -whatever you want. If you need to access a `Collection` in different parts of +If you need to access a `Collection` in different parts of your application, it is sometimes a good idea to declare a top-level variable and then initialize it in the `init` function: @@ -323,13 +306,14 @@ var ( func init() { var err error // Assuming pool and Person are already defined. - People, err = pool.NewCollection(&Person{}, nil) + People, err = pool.NewCollection(&Person{}) if err != nil { // handle error } } ``` + ### Saving Models Continuing from the previous example, to persistently save a `Person` model to diff --git a/collection.go b/collection.go index 0cc8bf3..f84cf79 100644 --- a/collection.go +++ b/collection.go @@ -9,6 +9,7 @@ package zoom import ( + "container/list" "fmt" "reflect" "strings" @@ -16,6 +17,8 @@ import ( "github.com/garyburd/redigo/redis" ) +var collections = list.New() + // Collection represents a specific registered type of model. It has methods // for saving, finding, and deleting models of a specific type. Use the // NewCollection method to create a new collection. @@ -27,49 +30,86 @@ type Collection struct { // CollectionOptions contains various options for a pool. type CollectionOptions struct { - // FallbackMarshalerUnmarshaler is used to marshal/unmarshal any type - // into a slice of bytes which is suitable for storing in the database. If - // Zoom does not know how to directly encode a certain type into bytes, it - // will use the FallbackMarshalerUnmarshaler. By default, the value is - // GobMarshalerUnmarshaler which uses the builtin gob package. Zoom also - // provides JSONMarshalerUnmarshaler to support json encoding out of the box. - // Default: GobMarshalerUnmarshaler. + // FallbackMarshalerUnmarshaler is used to marshal/unmarshal any type into a + // slice of bytes which is suitable for storing in the database. If Zoom does + // not know how to directly encode a certain type into bytes, it will use the + // FallbackMarshalerUnmarshaler. Zoom provides GobMarshalerUnmarshaler and + // JSONMarshalerUnmarshaler out of the box. You are also free to write your + // own implementation. FallbackMarshalerUnmarshaler MarshalerUnmarshaler - // Iff Index is true, any model in the collection that is saved will be added - // to a set in redis which acts as an index. The default value is false. The - // key for the set is exposed via the IndexKey method. Queries and the + // If Index is true, any model in the collection that is saved will be added + // to a set in Redis which acts as an index on all models in the collection. + // The key for the set is exposed via the IndexKey method. Queries and the // FindAll, Count, and DeleteAll methods will not work for unindexed - // collections. This may change in future versions. Default: false. + // collections. This may change in future versions. Index bool - // Name is a unique string identifier to use for the collection in redis. All + // Name is a unique string identifier to use for the collection in Redis. All // models in this collection that are saved in the database will use the - // collection name as a prefix. If not provided, the default name will be the - // name of the model type without the package prefix or pointer declarations. - // So for example, the default name corresponding to *models.User would be - // "User". If a custom name is provided, it cannot contain a colon. - // Default: The name of the model type, excluding package prefix and pointer - // declarations. + // collection name as a prefix. If Name is an empty string, Zoom will use the + // name of the concrete model type, excluding package prefix and pointer + // declarations, as the name for the collection. So for example, the default + // name corresponding to *models.User would be "User". If a custom name is + // provided, it cannot contain a colon. Name string } +// DefaultCollectionOptions is the default set of options for a collection. +var DefaultCollectionOptions = CollectionOptions{ + FallbackMarshalerUnmarshaler: GobMarshalerUnmarshaler, + Index: false, + Name: "", +} + +// WithFallbackMarshalerUnmarshaler returns a new copy of the options with the +// FallbackMarshalerUnmarshaler property set to the given value. It does not +// mutate the original options. +func (options CollectionOptions) WithFallbackMarshalerUnmarshaler(fallback MarshalerUnmarshaler) CollectionOptions { + options.FallbackMarshalerUnmarshaler = fallback + return options +} + +// WithIndex returns a new copy of the options with the Index property set to +// the given value. It does not mutate the original options. +func (options CollectionOptions) WithIndex(index bool) CollectionOptions { + options.Index = index + return options +} + +// WithName returns a new copy of the options with the Name property set to the +// given value. It does not mutate the original options. +func (options CollectionOptions) WithName(name string) CollectionOptions { + options.Name = name + return options +} + // NewCollection registers and returns a new collection of the given model type. // You must create a collection for each model type you want to save. The type // of model must be unique, i.e., not already registered, and must be a pointer -// to a struct. To use the default options, pass in nil as the options argument. -func (p *Pool) NewCollection(model Model, options *CollectionOptions) (*Collection, error) { - // Parse the options - fullOptions, err := parseCollectionOptions(model, options) - if err != nil { - return nil, err +// to a struct. NewCollection will use all the default options for the +// collection, which are specified in DefaultCollectionOptions. If you want to +// specify different options, use the NewCollectionWithOptions method. +func (p *Pool) NewCollection(model Model) (*Collection, error) { + return p.NewCollectionWithOptions(model, DefaultCollectionOptions) +} + +// NewCollection registers and returns a new collection of the given model type +// and with the provided options. +func (p *Pool) NewCollectionWithOptions(model Model, options CollectionOptions) (*Collection, error) { + typ := reflect.TypeOf(model) + // If options.Name is empty use the name of the concrete model type (without + // the package prefix). + if options.Name == "" { + options.Name = getDefaultModelSpecName(typ) + } else if strings.Contains(options.Name, ":") { + return nil, fmt.Errorf("zoom: CollectionOptions.Name cannot contain a colon. Got: %s", options.Name) } // Make sure the name and type have not been previously registered - typ := reflect.TypeOf(model) switch { case p.typeIsRegistered(typ): return nil, fmt.Errorf("zoom: Error in NewCollection: The type %T has already been registered", model) - case p.nameIsRegistered(fullOptions.Name): - return nil, fmt.Errorf("zoom: Error in NewCollection: The name %s has already been registered", fullOptions.Name) + case p.nameIsRegistered(options.Name): + return nil, fmt.Errorf("zoom: Error in NewCollection: The name %s has already been registered", options.Name) case !typeIsPointerToStruct(typ): return nil, fmt.Errorf("zoom: NewCollection requires a pointer to a struct as an argument. Got type %T", model) } @@ -79,17 +119,18 @@ func (p *Pool) NewCollection(model Model, options *CollectionOptions) (*Collecti if err != nil { return nil, err } - spec.name = fullOptions.Name - spec.fallback = fullOptions.FallbackMarshalerUnmarshaler + spec.name = options.Name + spec.fallback = options.FallbackMarshalerUnmarshaler p.modelTypeToSpec[typ] = spec - p.modelNameToSpec[fullOptions.Name] = spec + p.modelNameToSpec[options.Name] = spec - // Return the Collection - return &Collection{ + collection := &Collection{ spec: spec, pool: p, - index: fullOptions.Index, - }, nil + index: options.Index, + } + addCollection(collection) + return collection, nil } // Name returns the name for the given collection. The name is a unique string @@ -99,30 +140,31 @@ func (c *Collection) Name() string { return c.spec.name } -// parseCollectionOptions returns a well-formed CollectionOptions struct. If -// passedOptions is nil, it uses all the default options. Else, for each zero -// value field in passedOptions, it uses the default value for that field. -func parseCollectionOptions(model Model, passedOptions *CollectionOptions) (*CollectionOptions, error) { - // If passedOptions is nil, use all the default values - if passedOptions == nil { - return &CollectionOptions{ - FallbackMarshalerUnmarshaler: GobMarshalerUnmarshaler, - Name: getDefaultModelSpecName(reflect.TypeOf(model)), - }, nil - } - // Copy and validate the passedOptions - newOptions := *passedOptions - if newOptions.Name == "" { - newOptions.Name = getDefaultModelSpecName(reflect.TypeOf(model)) - } else if strings.Contains(newOptions.Name, ":") { - return nil, fmt.Errorf("zoom: CollectionOptions.Name cannot contain a colon. Got: %s", newOptions.Name) +// addCollection adds the given spec to the list of collections iff it has not +// already been added. +func addCollection(collection *Collection) { + for e := collections.Front(); e != nil; e = e.Next() { + otherCollection := e.Value.(*Collection) + if collection.spec.typ == otherCollection.spec.typ { + // The Collection was already added to the list. No need to do + // anything. + return + } } - if newOptions.FallbackMarshalerUnmarshaler == nil { - newOptions.FallbackMarshalerUnmarshaler = GobMarshalerUnmarshaler + collections.PushFront(collection) +} + +// getCollectionForModel returns the Collection corresponding to the type of +// model. +func getCollectionForModel(model Model) (*Collection, error) { + typ := reflect.TypeOf(model) + for e := collections.Front(); e != nil; e = e.Next() { + col := e.Value.(*Collection) + if col.spec.typ == typ { + return col, nil + } } - // NOTE: we don't need to modify the Index field because the default value, - // false, is also the zero value. - return &newOptions, nil + return nil, fmt.Errorf("Could not find Collection for type %T", model) } func (p *Pool) typeIsRegistered(typ reflect.Type) bool { @@ -223,8 +265,9 @@ func (t *Transaction) Save(c *Collection, model Model) { } // Create a modelRef and start a transaction mr := &modelRef{ - spec: c.spec, - model: model, + collection: c, + model: model, + spec: c.spec, } // Save indexes // This must happen first, because it relies on reading the old field values @@ -325,49 +368,50 @@ func (t *Transaction) saveStringIndex(mr *modelRef, fs *fieldSpec) { t.Command("ZADD", redis.Args{indexKey, 0, member}, nil) } -// UpdateFields updates only the given fields of the model. UpdateFields uses +// SaveFields saves only the given fields of the model. SaveFields uses // "last write wins" semantics. If another caller updates the the same fields // concurrently, your updates may be overwritten. It will return an error if // the type of model does not match the registered Collection, or if any of // the given fieldNames are not found in the registered Collection. If -// UpdateFields is called on a model that has not yet been saved, it will not +// SaveFields is called on a model that has not yet been saved, it will not // return an error. Instead, only the given fields will be saved in the // database. -func (c *Collection) UpdateFields(fieldNames []string, model Model) error { +func (c *Collection) SaveFields(fieldNames []string, model Model) error { t := c.pool.NewTransaction() - t.UpdateFields(c, fieldNames, model) + t.SaveFields(c, fieldNames, model) if err := t.Exec(); err != nil { return err } return nil } -// UpdateFields updates only the given fields of the model inside an existing -// transaction. UpdateFields will set the err property of the transaction if the +// SaveFields saves only the given fields of the model inside an existing +// transaction. SaveFields will set the err property of the transaction if the // type of model does not match the registered Collection, or if any of the // given fieldNames are not found in the model type. In either case, the -// transaction will return the error when you call Exec. UpdateFields uses "last +// transaction will return the error when you call Exec. SaveFields uses "last // write wins" semantics. If another caller updates the the same fields -// concurrently, your updates may be overwritten. If UpdateFields is called on a +// concurrently, your updates may be overwritten. If SaveFields is called on a // model that has not yet been saved, it will not return an error. Instead, only // the given fields will be saved in the database. -func (t *Transaction) UpdateFields(c *Collection, fieldNames []string, model Model) { +func (t *Transaction) SaveFields(c *Collection, fieldNames []string, model Model) { // Check the model type if err := c.checkModelType(model); err != nil { - t.setError(fmt.Errorf("zoom: Error in UpdateFields or Transaction.UpdateFields: %s", err.Error())) + t.setError(fmt.Errorf("zoom: Error in SaveFields or Transaction.SaveFields: %s", err.Error())) return } // Check the given field names for _, fieldName := range fieldNames { if !stringSliceContains(c.spec.fieldNames(), fieldName) { - t.setError(fmt.Errorf("zoom: Error in UpdateFields or Transaction.UpdateFields: Collection %s does not have field named %s", c.Name(), fieldName)) + t.setError(fmt.Errorf("zoom: Error in SaveFields or Transaction.SaveFields: Collection %s does not have field named %s", c.Name(), fieldName)) return } } // Create a modelRef and start a transaction mr := &modelRef{ - spec: c.spec, - model: model, + collection: c, + model: model, + spec: c.spec, } // Update indexes // This must happen first, because it relies on reading the old field values @@ -386,6 +430,10 @@ func (t *Transaction) UpdateFields(c *Collection, fieldNames []string, model Mod // 1. t.Command("HMSET", hashArgs, nil) } + // Add the model id to the set of all models for this collection + if c.index { + t.Command("SADD", redis.Args{c.IndexKey(), model.ModelId()}, nil) + } } // Find retrieves a model with the given id from redis and scans its values @@ -420,9 +468,12 @@ func (t *Transaction) Find(c *Collection, id string, model Model) { } model.SetModelId(id) mr := &modelRef{ - spec: c.spec, - model: model, + collection: c, + model: model, + spec: c.spec, } + // Check if the model actually exists + t.Command("EXISTS", redis.Args{mr.key()}, newModelExistsHandler(c, id)) // Get the fields from the main hash for this model args := redis.Args{mr.key()} for _, fieldName := range mr.spec.fieldRedisNames() { @@ -456,8 +507,9 @@ func (t *Transaction) FindFields(c *Collection, id string, fieldNames []string, // Set the model id and create a modelRef model.SetModelId(id) mr := &modelRef{ - spec: c.spec, - model: model, + collection: c, + spec: c.spec, + model: model, } // Check the given field names and append the corresponding redis field names // to args. @@ -472,6 +524,8 @@ func (t *Transaction) FindFields(c *Collection, id string, fieldNames []string, // may be customized via struct tags. args = append(args, c.spec.fieldsByName[fieldName].redisName) } + // Check if the model actually exists. + t.Command("EXISTS", redis.Args{mr.key()}, newModelExistsHandler(c, id)) // Get the fields from the main hash for this model t.Command("HMGET", args, newScanModelRefHandler(fieldNames, mr)) } @@ -516,7 +570,7 @@ func (t *Transaction) FindAll(c *Collection, models interface{}) { t.setError(fmt.Errorf("zoom: Error in FindAll or Transaction.FindAll: %s", err.Error())) return } - sortArgs := c.spec.sortArgs(c.spec.indexKey(), c.spec.fieldRedisNames(), 0, 0, ascendingOrder) + sortArgs := c.spec.sortArgs(c.spec.indexKey(), c.spec.fieldRedisNames(), 0, 0, false) fieldNames := append(c.spec.fieldNames(), "-") t.Command("SORT", sortArgs, newScanModelsHandler(c.spec, fieldNames, models)) } diff --git a/collection_test.go b/collection_test.go index 0392c70..3334a8b 100644 --- a/collection_test.go +++ b/collection_test.go @@ -13,7 +13,7 @@ import ( ) // collectionTestModel is a model type that is only used for testing -// the Register and RegisterName functions +// the NewCollection and NewCollectionWithOptions functions type collectionTestModel struct { Int int Bool bool @@ -25,9 +25,9 @@ func TestNewCollection(t *testing.T) { testingSetUp() defer testingTearDown() - col, err := testPool.NewCollection(&collectionTestModel{}, nil) + col, err := testPool.NewCollection(&collectionTestModel{}) if err != nil { - t.Fatalf("Unexpected error in Register: %s", err.Error()) + t.Fatalf("Unexpected error in NewCollection: %s", err.Error()) } expectedName := "collectionTestModel" expectedType := reflect.TypeOf(&collectionTestModel{}) @@ -43,12 +43,10 @@ func TestNewCollectionWithName(t *testing.T) { defer testingTearDown() expectedName := "customName" - col, err := testPool.NewCollection(&collectionTestModel{}, - &CollectionOptions{ - Name: expectedName, - }) + options := DefaultCollectionOptions.WithName(expectedName) + col, err := testPool.NewCollectionWithOptions(&collectionTestModel{}, options) if err != nil { - t.Fatalf("Unexpected error in Register: %s", err.Error()) + t.Fatalf("Unexpected error in NewCollectionWithOptions: %s", err.Error()) } expectedType := reflect.TypeOf(&collectionTestModel{}) testRegisteredCollectionType(t, col, expectedName, expectedType) @@ -136,7 +134,38 @@ func TestSave(t *testing.T) { expectFieldEquals(t, key, "Bool", mu, model.Bool) } -func TestUpdateFields(t *testing.T) { +func TestSaveFields(t *testing.T) { + testingSetUp() + defer testingTearDown() + + // Save the Int and Bool fields, leaving the String field empty. + model := &testModel{ + Int: 43, + Bool: true, + } + if err := testModels.SaveFields([]string{"Int", "Bool"}, model); err != nil { + t.Errorf("Unexpected error in testModels.SaveFields: %s", err.Error()) + } + + // Make sure the model was saved correctly + expectModelExists(t, testModels, model) + key := testModels.ModelKey(model.ModelId()) + mu := testModels.spec.fallback + expectFieldEquals(t, key, "Int", mu, model.Int) + expectFieldEquals(t, key, "String", mu, nil) + expectFieldEquals(t, key, "Bool", mu, model.Bool) + + // Make sure the model can be found. + gotModel := &testModel{} + if err := testModels.Find(model.Id, gotModel); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(model, gotModel) { + t.Errorf("Expected: %+v\nBut got: %+v", model, gotModel) + } +} + +func TestSaveFieldsOverwrite(t *testing.T) { testingSetUp() defer testingTearDown() @@ -152,8 +181,8 @@ func TestUpdateFields(t *testing.T) { originalString := model.String model.String = "new" + model.String model.Bool = !model.Bool - if err := testModels.UpdateFields([]string{"Int", "Bool"}, model); err != nil { - t.Errorf("Unexpected error in testModels.UpdateFields: %s", err.Error()) + if err := testModels.SaveFields([]string{"Int", "Bool"}, model); err != nil { + t.Errorf("Unexpected error in testModels.SaveFields: %s", err.Error()) } // Make sure the model was saved correctly @@ -186,6 +215,26 @@ func TestFind(t *testing.T) { } } +func TestFindEmpty(t *testing.T) { + testingSetUp() + defer testingTearDown() + + // Create model which is empty (no fields with values) + model := &testModel{} + if err := testModels.Save(model); err != nil { + t.Fatal(err) + } + + // Find the model in the database and store it in modelCopy + modelCopy := &testModel{} + if err := testModels.Find(model.ModelId(), modelCopy); err != nil { + t.Errorf("Unexpected error in testModels.Find: %s", err.Error()) + } + if !reflect.DeepEqual(model, modelCopy) { + t.Errorf("Found model was incorrect.\n\tExpected: %+v\n\tBut got: %+v", model, modelCopy) + } +} + func TestFindFields(t *testing.T) { testingSetUp() defer testingTearDown() diff --git a/convert.go b/convert.go index 3de0984..e9ac162 100755 --- a/convert.go +++ b/convert.go @@ -24,13 +24,16 @@ import ( // not the redis names which may be custom. func scanModel(fieldNames []string, fieldValues []interface{}, mr *modelRef) error { ms := mr.spec + if fieldValues == nil || len(fieldValues) == 0 { + return newModelNotFoundError(mr) + } for i, reply := range fieldValues { + if reply == nil { + continue + } fieldName := fieldNames[i] replyBytes, err := redis.Bytes(reply, nil) if err != nil { - if err == redis.ErrNil { - return newModelNotFoundError(mr) - } return err } if fieldName == "-" { diff --git a/convert_test.go b/convert_test.go index 0e3e472..5973fe2 100644 --- a/convert_test.go +++ b/convert_test.go @@ -10,6 +10,7 @@ package zoom import ( "reflect" "testing" + "time" ) func TestConvertPrimatives(t *testing.T) { @@ -24,6 +25,24 @@ func TestConvertPointers(t *testing.T) { testConvertType(t, indexedPointersModels, createIndexedPointersModel()) } +func TestTimeDuration(t *testing.T) { + testingSetUp() + defer testingTearDown() + + type durationModel struct { + Duration time.Duration + RandomId + } + durationModels, err := testPool.NewCollection(&durationModel{}) + if err != nil { + t.Errorf("Unexpected error in testPool.NewCollection: %s", err.Error()) + } + model := &durationModel{ + Duration: 43 * time.Second, + } + testConvertType(t, durationModels, model) +} + func TestGobFallback(t *testing.T) { testingSetUp() defer testingTearDown() @@ -38,9 +57,8 @@ func TestGobFallback(t *testing.T) { IntMap map[int]int RandomId } - gobModels, err := testPool.NewCollection(&gobModel{}, &CollectionOptions{ - FallbackMarshalerUnmarshaler: GobMarshalerUnmarshaler, - }) + options := DefaultCollectionOptions.WithFallbackMarshalerUnmarshaler(GobMarshalerUnmarshaler) + gobModels, err := testPool.NewCollectionWithOptions(&gobModel{}, options) if err != nil { t.Errorf("Unexpected error in testPool.NewCollection: %s", err.Error()) } @@ -69,9 +87,8 @@ func TestJSONFallback(t *testing.T) { EmptyInterface interface{} RandomId } - jsonModels, err := testPool.NewCollection(&jsonModel{}, &CollectionOptions{ - FallbackMarshalerUnmarshaler: JSONMarshalerUnmarshaler, - }) + options := DefaultCollectionOptions.WithFallbackMarshalerUnmarshaler(JSONMarshalerUnmarshaler) + jsonModels, err := testPool.NewCollectionWithOptions(&jsonModel{}, options) if err != nil { t.Errorf("Unexpected error in testPool.NewCollection: %s", err.Error()) } @@ -100,7 +117,7 @@ func TestConvertEmbeddedStruct(t *testing.T) { Embeddable RandomId } - embededStructModels, err := testPool.NewCollection(&embeddedStructModel{}, nil) + embededStructModels, err := testPool.NewCollection(&embeddedStructModel{}) if err != nil { t.Errorf("Unexpected error in testPool.NewCollection: %s", err.Error()) } @@ -122,7 +139,7 @@ func TestEmbeddedPointerToStruct(t *testing.T) { *Embeddable RandomId } - embededPointerToStructModels, err := testPool.NewCollection(&embeddedPointerToStructModel{}, nil) + embededPointerToStructModels, err := testPool.NewCollection(&embeddedPointerToStructModel{}) if err != nil { t.Errorf("Unexpected error in testPool.NewCollection: %s", err.Error()) } @@ -164,4 +181,12 @@ func testConvertType(t *testing.T, collection *Collection, model Model) { if err := collection.Save(emptyModel); err != nil { t.Errorf("Unexpected error saving an empty model: %s", err.Error()) } + emptyModelCopy, ok := reflect.New(collection.spec.typ.Elem()).Interface().(Model) + if err := collection.Find(emptyModel.ModelId(), emptyModelCopy); err != nil { + t.Errorf("Unexpected error in Find: %s", err.Error()) + } + // Make sure the copy equals the original + if !reflect.DeepEqual(emptyModel, emptyModelCopy) { + t.Errorf("Model of type %T was not saved/retrieved correctly.\nExpected: %+v\nGot: %+v", emptyModel, emptyModel, emptyModelCopy) + } } diff --git a/doc.go b/doc.go index 1166fc8..c577f5d 100755 --- a/doc.go +++ b/doc.go @@ -8,7 +8,7 @@ // atomic transactions, lua scripts, and running Redis commands // directly if needed. // -// Version 0.16.0 +// Version 0.17.0 // // For installation instructions, examples, and more information // visit https://github.com/albrow/zoom. diff --git a/errors.go b/errors.go index e8ffc78..b939d54 100755 --- a/errors.go +++ b/errors.go @@ -12,7 +12,8 @@ import "fmt" // ModelNotFoundError is returned from Find and Query methods if a model // that fits the given criteria is not found. type ModelNotFoundError struct { - Msg string + Collection *Collection + Msg string } func (e ModelNotFoundError) Error() string { @@ -26,5 +27,8 @@ func newModelNotFoundError(mr *modelRef) error { } else { msg = fmt.Sprintf("Could not find %s with the given criteria", mr.spec.name) } - return ModelNotFoundError{Msg: msg} + return ModelNotFoundError{ + Collection: mr.collection, + Msg: msg, + } } diff --git a/handlers.go b/handlers.go index 1db00e1..a0fda1d 100644 --- a/handlers.go +++ b/handlers.go @@ -20,6 +20,25 @@ func newAlwaysErrorHandler(err error) ReplyHandler { } } +// newModelExistsHandler returns a reply handler which will return a +// ModelNotFound error if the value of reply is false. It is expected to be +// used as the reply handler for an EXISTS command. +func newModelExistsHandler(collection *Collection, modelId string) ReplyHandler { + return func(reply interface{}) error { + exists, err := redis.Bool(reply, nil) + if err != nil { + return err + } + if !exists { + return ModelNotFoundError{ + Collection: collection, + Msg: fmt.Sprintf("Could not find %s with id = %s", collection.spec.name, modelId), + } + } + return nil + } +} + // NewScanIntHandler returns a ReplyHandler which will convert the reply to an // integer and set the value of i to the converted integer. The ReplyHandler // will return an error if there was a problem converting the reply. @@ -132,13 +151,14 @@ func newScanModelRefHandler(fieldNames []string, mr *modelRef) ReplyHandler { // "b1C7B0yETtXFYuKinndqoa" using the model's SetModelId method. func NewScanModelHandler(fieldNames []string, model Model) ReplyHandler { // Create a modelRef that wraps the given model. - spec, err := getModelSpecForModel(model) + collection, err := getCollectionForModel(model) if err != nil { return newAlwaysErrorHandler(err) } mr := &modelRef{ - spec: spec, - model: model, + collection: collection, + model: model, + spec: collection.spec, } // Create and return a reply handler using newScanModelRefHandler return newScanModelRefHandler(fieldNames, mr) @@ -150,17 +170,18 @@ func NewScanModelHandler(fieldNames []string, model Model) ReplyHandler { func newScanModelsHandler(spec *modelSpec, fieldNames []string, models interface{}) ReplyHandler { return func(reply interface{}) error { allFields, err := redis.Values(reply, nil) + modelsVal := reflect.ValueOf(models).Elem() if err != nil { if err == redis.ErrNil { - return ModelNotFoundError{ - Msg: fmt.Sprintf("Could not find %s with the given criteria", spec.name), - } + // This means no models matched the criteria. Set the length of + // models to 0 to indicate this and then return. + modelsVal.SetLen(0) + return nil } return err } numFields := len(fieldNames) numModels := len(allFields) / numFields - modelsVal := reflect.ValueOf(models).Elem() for i := 0; i < numModels; i++ { start := i * numFields stop := i*numFields + numFields @@ -233,3 +254,39 @@ func newScanModelsHandler(spec *modelSpec, fieldNames []string, models interface func NewScanModelsHandler(collection *Collection, fieldNames []string, models interface{}) ReplyHandler { return newScanModelsHandler(collection.spec, fieldNames, models) } + +// newScanOneModelHandler returns a ReplyHandler which will scan reply into the +// given model. It differs from NewScanModelHandler in that it expects reply to +// have an underlying type of [][]byte{}. Specifically, if fieldNames is +// ["Age", "Name", "-"], reply should look like: +// +// 1) "25" +// 2) "Bob" +// 3) "b1C7B0yETtXFYuKinndqoa" +// +// Note that this is similar to the kind of reply expected by +// NewScanModelHandler except that there should only ever be len(fieldNames) +// fields in the reply (i.e. enough fields for exactly one model). If the reply +// is nil or an empty array, the ReplyHandler will return an error. This makes +// newScanOneModelHandler useful in contexts where you expect exactly one model +// to match certain query criteria (e.g. for Query.RunOne). +func newScanOneModelHandler(q *query, spec *modelSpec, fieldNames []string, model Model) ReplyHandler { + return func(reply interface{}) error { + // Use reflection to create a slice which contains only one element, the + // given model. We'll then pass this in to newScanModelsHandler to set the + // value of model. + modelsVal := reflect.New(reflect.SliceOf(reflect.TypeOf(model))) + modelsVal.Elem().Set(reflect.Append(modelsVal.Elem(), reflect.ValueOf(model))) + if err := newScanModelsHandler(spec, fieldNames, modelsVal.Interface())(reply); err != nil { + return err + } + // Return an error if we didn't find any models matching the criteria. + // When you use newScanOneModelHandler, you are explicitly saying that you + // expect exactly one model. + if modelsVal.Elem().Len() == 0 { + msg := fmt.Sprintf("Could not find a model with the given query criteria: %s", q) + return ModelNotFoundError{Msg: msg} + } + return nil + } +} diff --git a/handlers_test.go b/handlers_test.go index 8568274..e150235 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -152,3 +152,38 @@ func TestScanModelsHandler(t *testing.T) { ) } } + +func TestScanOneModelHandler(t *testing.T) { + testingSetUp() + defer testingTearDown() + expected := &testModel{ + Int: 38, + String: "bar", + Bool: false, + RandomId: RandomId{ + Id: "thisIsAnId", + }, + } + fieldNames := []string{"String", "-", "Int"} + got := &testModel{} + handler := newScanOneModelHandler(testModels.NewQuery().query, testModels.spec, fieldNames, got) + if err := handler([]interface{}{ + []byte("bar"), + []byte("thisIsAnId"), + []byte("38"), + }); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(expected, got) { + t.Errorf("\nExpected: %s\nBut got: %s\n", + spew.Sprint(expected), + spew.Sprint(got), + ) + } + // If reply is nil, the ReplyHandler should return a ModelNotFoundError. + if err := handler(nil); err == nil { + t.Error("Expected error but got none") + } else if _, ok := err.(ModelNotFoundError); !ok { + t.Errorf("Expected ModelNotFoundError but got: %T: %v", err, err) + } +} diff --git a/internal_query.go b/internal_query.go new file mode 100644 index 0000000..8c0c90a --- /dev/null +++ b/internal_query.go @@ -0,0 +1,650 @@ +// Copyright 2015 Alex Browne. All rights reserved. +// Use of this source code is governed by the MIT +// license, which can be found in the LICENSE file. + +// File query.go contains code related to the query abstraction. + +package zoom + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/garyburd/redigo/redis" +) + +// query represents a query which will retrieve some models from +// the database. A Query may consist of one or more query modifiers +// (e.g. Filter or Order) and may be executed with a query finisher +// (e.g. Run or Ids). +type query struct { + collection *Collection + pool *Pool + includes []string + excludes []string + order order + limit uint + offset uint + filters []filter + err error +} + +// newQuery creates and returns a new query with the given collection. It will +// add an error to the query if the collection is not indexed. +func newQuery(collection *Collection) *query { + q := &query{ + collection: collection, + pool: collection.pool, + } + // For now, only indexed collections are queryable. This might change in + // future versions. + if !collection.index { + q.setError(fmt.Errorf("zoom: error in NewQuery: Only indexed collections are queryable.")) + return q + } + return q +} + +// String satisfies fmt.Stringer and prints out the query in a format that +// matches the go code used to declare it. +func (q *query) String() string { + result := fmt.Sprintf("%s.NewQuery()", q.collection.Name()) + for _, filter := range q.filters { + result += fmt.Sprintf(".%s", filter) + } + if q.hasOrder() { + result += fmt.Sprintf(".%s", q.order) + } + if q.hasOffset() { + result += fmt.Sprintf(".Offset(%d)", q.offset) + } + if q.hasLimit() { + result += fmt.Sprintf(".Limit(%d)", q.limit) + } + if q.hasIncludes() { + result += fmt.Sprintf(`.Include("%s")`, strings.Join(q.includes, `", "`)) + } else if q.hasExcludes() { + result += fmt.Sprintf(`.Exclude("%s")`, strings.Join(q.excludes, `", "`)) + } + return result +} + +type order struct { + fieldName string + redisName string + kind orderKind +} + +func (o order) String() string { + if o.kind == ascendingOrder { + return fmt.Sprintf(`Order("%s")`, o.fieldName) + } else { + return fmt.Sprintf(`Order("-%s")`, o.fieldName) + } +} + +type orderKind int + +const ( + ascendingOrder orderKind = iota + descendingOrder +) + +func (ok orderKind) String() string { + switch ok { + case ascendingOrder: + return "ascending" + case descendingOrder: + return "descending" + } + return "" +} + +type filter struct { + fieldSpec *fieldSpec + op filterOp + value reflect.Value +} + +func (f filter) String() string { + if f.value.Kind() == reflect.String { + return fmt.Sprintf(`Filter("%s %s", "%s")`, f.fieldSpec.name, f.op, f.value.String()) + } else { + return fmt.Sprintf(`Filter("%s %s", %v)`, f.fieldSpec.name, f.op, f.value.Interface()) + } +} + +type filterOp int + +const ( + equalOp filterOp = iota + notEqualOp + greaterOp + lessOp + greaterOrEqualOp + lessOrEqualOp +) + +func (fk filterOp) String() string { + switch fk { + case equalOp: + return "=" + case notEqualOp: + return "!=" + case greaterOp: + return ">" + case lessOp: + return "<" + case greaterOrEqualOp: + return ">=" + case lessOrEqualOp: + return "<=" + } + return "" +} + +var filterOps = map[string]filterOp{ + "=": equalOp, + "!=": notEqualOp, + ">": greaterOp, + "<": lessOp, + ">=": greaterOrEqualOp, + "<=": lessOrEqualOp, +} + +// setError sets the err property of q only if it has not already been set +func (q *query) setError(e error) { + if !q.hasError() { + q.err = e + } +} + +// Order specifies a field by which to sort the models. fieldName should be +// a field in the struct type corresponding to the Collection used in the query +// constructor. By default, the records are sorted by ascending order by the given +// field. To sort by descending order, put a negative sign before the field name. +// Zoom can only sort by fields which have been indexed, i.e. those which have the +// `zoom:"index"` struct tag. However, in the future this may change. Only one +// order may be specified per query. However in the future, secondary orders may be +// allowed, and will take effect when two or more models have the same value for the +// primary order field. Order will set an error on the query if the fieldName is invalid, +// if another order has already been applied to the query, or if the fieldName specified +// does not correspond to an indexed field. The error, same as any other error +// that occurs during the lifetime of the query, is not returned until the query +// is executed. When the query is executed the first error that occurred during +// the lifetime of the query object (if any) will be returned. +func (q *query) Order(fieldName string) { + if q.hasOrder() { + // TODO: allow secondary sort orders? + q.setError(errors.New("zoom: error in Query.Order: previous order already specified. Only one order per query is allowed.")) + return + } + // Check for the presence of the "-" prefix + var orderKind orderKind + if strings.HasPrefix(fieldName, "-") { + orderKind = descendingOrder + // remove the "-" prefix + fieldName = fieldName[1:] + } else { + orderKind = ascendingOrder + } + // Get the redisName for the given fieldName + fs, found := q.collection.spec.fieldsByName[fieldName] + if !found { + err := fmt.Errorf("zoom: error in Query.Order: could not find field %s in type %s", fieldName, q.collection.spec.typ.String()) + q.setError(err) + return + } + q.order = order{ + fieldName: fs.name, + redisName: fs.redisName, + kind: orderKind, + } +} + +// Limit specifies an upper limit on the number of records to return. If amount +// is 0, no limit will be applied. The default value is 0. +func (q *query) Limit(amount uint) { + q.limit = amount +} + +// Offset specifies a starting index (inclusive) from which to start counting +// records that will be returned. The default value is 0. +func (q *query) Offset(amount uint) { + q.offset = amount +} + +// Include specifies one or more field names which will be read from the +// database and scanned into the resulting models when the query is run. Field +// names which are not specified in Include will not be read or scanned. You can +// only use one of Include or Exclude, not both on the same query. Include will +// set an error if you try to use it with Exclude on the same query. The error, +// same as any other error that occurs during the lifetime of the query, is not +// returned until the query is executed. When the query is executed the first +// error that occurred during the lifetime of the query object (if any) will be +// returned. +func (q *query) Include(fields ...string) { + if q.hasExcludes() { + q.setError(errors.New("zoom: cannot use both Include and Exclude modifiers on a query")) + return + } + q.includes = append(q.includes, fields...) +} + +// Exclude specifies one or more field names which will *not* be read from the +// database and scanned. Any other fields *will* be read and scanned into the +// resulting models when the query is run. You can only use one of Include or +// Exclude, not both on the same query. Exclude will set an error if you try to +// use it with Include on the same query. The error, same as any other error +// that occurs during the lifetime of the query, is not returned until the query +// is executed. When the query is executed the first error that occurred during +// the lifetime of the query object (if any) will be returned. +func (q *query) Exclude(fields ...string) { + if q.hasIncludes() { + q.setError(errors.New("zoom: cannot use both Include and Exclude modifiers on a query")) + return + } + q.excludes = append(q.excludes, fields...) +} + +// Filter applies a filter to the query, which will cause the query to only +// return models with attributes matching the expression. filterString should be +// an expression which includes a fieldName, a space, and an operator in that +// order. Operators must be one of "=", "!=", ">", "<", ">=", or "<=". You can +// only use Filter on fields which are indexed, i.e. those which have the +// `zoom:"index"` struct tag. If multiple filters are applied to the same query, +// the query will only return models which have matches for ALL of the filters. +// I.e. applying multiple filters is logically equivalent to combining them with +// a AND or INTERSECT operator. Filter will set an error on the query if the +// arguments are improperly formated, if the field you are attempting to filter +// is not indexed, or if the type of value does not match the type of the field. +// The error, same as any other error that occurs during the lifetime of the +// query, is not returned until the query is executed. When the query is +// executed the first error that occurred during the lifetime of the query +// object (if any) will be returned. +func (q *query) Filter(filterString string, value interface{}) { + fieldName, operator, err := splitFilterString(filterString) + if err != nil { + q.setError(err) + return + } + // Parse the filter operator + filterOp, found := filterOps[operator] + if !found { + q.setError(errors.New("zoom: invalid Filter operator in fieldStr. should be one of =, !=, >, <, >=, or <=.")) + return + } + // Get the fieldSpec for the given fieldName + fieldSpec, found := q.collection.spec.fieldsByName[fieldName] + if !found { + err := fmt.Errorf("zoom: error in Query.Order: could not find field %s in type %s", fieldName, q.collection.spec.typ.String()) + q.setError(err) + return + } + // Make sure the field is an indexed field + if fieldSpec.indexKind == noIndex { + err := fmt.Errorf("zoom: filters are only allowed on indexed fields. %s.%s is not indexed. You can index it by adding the `zoom:\"index\"` struct tag.", q.collection.spec.typ.String(), fieldName) + q.setError(err) + return + } + filter := filter{ + fieldSpec: fieldSpec, + op: filterOp, + } + // Make sure the given value is the correct type + if err := filter.checkValType(value); err != nil { + q.setError(err) + return + } + filter.value = reflect.ValueOf(value) + q.filters = append(q.filters, filter) + return +} + +func splitFilterString(filterString string) (fieldName string, operator string, err error) { + tokens := strings.Split(filterString, " ") + if len(tokens) != 2 { + return "", "", errors.New("zoom: too many spaces in fieldStr argument. should be a field name, a space, and an operator.") + } + return tokens[0], tokens[1], nil +} + +// checkValType returns an error if the type of value does not correspond to +// filter.fieldSpec. +func (filter filter) checkValType(value interface{}) error { + // Here we iterate through pointer indirections. This is so you can + // just pass in a primitive instead of a pointer to a primitive for + // filtering on fields which have pointer values. + valueType := reflect.TypeOf(value) + valueVal := reflect.ValueOf(value) + for valueType.Kind() == reflect.Ptr { + valueType = valueType.Elem() + valueVal = valueVal.Elem() + if !valueVal.IsValid() { + return errors.New("zoom: invalid value for Filter. Is it a nil pointer?") + } + } + // Also dereference the field type to reach the underlying type. + fieldType := filter.fieldSpec.typ + for fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + if valueType != fieldType { + return fmt.Errorf("zoom: invalid value for Filter on %s. Type of value (%T) does not match type of field (%s).", filter.fieldSpec.name, value, fieldType.String()) + } + return nil +} + +// generateIdsSet will return the key of a set or sorted set that contains all the ids +// which match the query criteria. It may also return some temporary keys which were created +// during the process of creating the set of ids. Note that tmpKeys may contain idsKey itself, +// so the temporary keys should not be deleted until after the ids have been read from idsKey. +func generateIdsSet(q *query, tx *Transaction) (idsKey string, tmpKeys []interface{}, err error) { + idsKey = q.collection.spec.indexKey() + tmpKeys = []interface{}{} + if q.hasOrder() { + fieldIndexKey, err := q.collection.spec.fieldIndexKey(q.order.fieldName) + if err != nil { + return "", nil, err + } + fieldSpec := q.collection.spec.fieldsByName[q.order.fieldName] + if fieldSpec.indexKind == stringIndex { + // If the order is a string field, we need to extract the ids before + // we use ZRANGE. Create a temporary set to store the ordered ids + orderedIdsKey := generateRandomKey("tmp:order:" + q.order.fieldName) + tmpKeys = append(tmpKeys, orderedIdsKey) + idsKey = orderedIdsKey + // TODO: as an optimization, if there is a filter on the same field, + // pass the start and stop parameters to the script. + tx.ExtractIdsFromStringIndex(fieldIndexKey, orderedIdsKey, "-", "+") + } else { + idsKey = fieldIndexKey + } + } + if q.hasFilters() { + filteredIdsKey := generateRandomKey("tmp:filter:all") + tmpKeys = append(tmpKeys, filteredIdsKey) + for i, filter := range q.filters { + if i == 0 { + // The first time, we should intersect with the ids key from above + if err := intersectFilter(q, tx, filter, idsKey, filteredIdsKey); err != nil { + return "", tmpKeys, err + } + } else { + // All other times, we should intersect with the filteredIdsKey itself + if err := intersectFilter(q, tx, filter, filteredIdsKey, filteredIdsKey); err != nil { + return "", tmpKeys, err + } + } + } + idsKey = filteredIdsKey + } + return idsKey, tmpKeys, nil +} + +// intersectFilter adds commands to the query transaction which, when run, will create a +// temporary set which contains all the ids that fit the given filter criteria. Then it will +// intersect them with origKey and stores the result in destKey. The function will automatically +// delete any temporary sets created since, in this case, they are guaranteed to not be needed +// by any other transaction commands. +func intersectFilter(q *query, tx *Transaction, filter filter, origKey string, destKey string) error { + switch filter.fieldSpec.indexKind { + case numericIndex: + return intersectNumericFilter(q, tx, filter, origKey, destKey) + case booleanIndex: + return intersectBoolFilter(q, tx, filter, origKey, destKey) + case stringIndex: + return intersectStringFilter(q, tx, filter, origKey, destKey) + } + return nil +} + +// intersectNumericFilter adds commands to the query transaction which, when run, will +// create a temporary set which contains all the ids of models which match the given +// numeric filter criteria, then intersect those ids with origKey and store the result +// in destKey. +func intersectNumericFilter(q *query, tx *Transaction, filter filter, origKey string, destKey string) error { + fieldIndexKey, err := q.collection.spec.fieldIndexKey(filter.fieldSpec.name) + if err != nil { + return err + } + if filter.op == notEqualOp { + // Special case for not equal. We need to use two separate commands + valueExclusive := fmt.Sprintf("(%v", filter.value.Interface()) + filterKey := generateRandomKey("tmp:filter:" + fieldIndexKey) + // ZADD all ids greater than filter.value + tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, valueExclusive, "+inf") + // ZADD all ids less than filter.value + tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, "-inf", valueExclusive) + // Intersect filterKey with origKey and store result in destKey + tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) + // Delete the temporary key + tx.Command("DEL", redis.Args{filterKey}, nil) + } else { + var min, max interface{} + switch filter.op { + case equalOp: + min, max = filter.value.Interface(), filter.value.Interface() + case lessOp: + min = "-inf" + // use "(" for exclusive + max = fmt.Sprintf("(%v", filter.value.Interface()) + case greaterOp: + min = fmt.Sprintf("(%v", filter.value.Interface()) + max = "+inf" + case lessOrEqualOp: + min = "-inf" + max = filter.value.Interface() + case greaterOrEqualOp: + min = filter.value.Interface() + max = "+inf" + } + // Get all the ids that fit the filter criteria and store them in a temporary key caled filterKey + filterKey := generateRandomKey("tmp:filter:" + fieldIndexKey) + tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, min, max) + // Intersect filterKey with origKey and store result in destKey + tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) + // Delete the temporary key + tx.Command("DEL", redis.Args{filterKey}, nil) + } + return nil +} + +// intersectBoolFilter adds commands to the query transaction which, when run, will +// create a temporary set which contains all the ids of models which match the given +// bool filter criteria, then intersect those ids with origKey and store the result +// in destKey. +func intersectBoolFilter(q *query, tx *Transaction, filter filter, origKey string, destKey string) error { + fieldIndexKey, err := q.collection.spec.fieldIndexKey(filter.fieldSpec.name) + if err != nil { + return err + } + var min, max interface{} + switch filter.op { + case equalOp: + if filter.value.Bool() { + min, max = 1, 1 + } else { + min, max = 0, 0 + } + case lessOp: + if filter.value.Bool() { + // Only false is less than true + min, max = 0, 0 + } else { + // No models are less than false, + // so we should eliminate all models + min, max = -1, -1 + } + case greaterOp: + if filter.value.Bool() { + // No models are greater than true, + // so we should eliminate all models + min, max = -1, -1 + } else { + // Only true is greater than false + min, max = 1, 1 + } + case lessOrEqualOp: + if filter.value.Bool() { + // All models are <= true + min, max = 0, 1 + } else { + // Only false is <= false + min, max = 0, 0 + } + case greaterOrEqualOp: + if filter.value.Bool() { + // Only true is >= true + min, max = 1, 1 + } else { + // All models are >= false + min, max = 0, 1 + } + case notEqualOp: + if filter.value.Bool() { + min, max = 0, 0 + } else { + min, max = 1, 1 + } + } + // Get all the ids that fit the filter criteria and store them in a temporary key caled filterKey + filterKey := generateRandomKey("tmp:filter:" + fieldIndexKey) + tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, min, max) + // Intersect filterKey with origKey and store result in destKey + tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) + // Delete the temporary key + tx.Command("DEL", redis.Args{filterKey}, nil) + return nil +} + +// intersectStringFilter adds commands to the query transaction which, when run, will +// create a temporary set which contains all the ids of models which match the given +// string filter criteria, then intersect those ids with origKey and store the result +// in destKey. +func intersectStringFilter(q *query, tx *Transaction, filter filter, origKey string, destKey string) error { + fieldIndexKey, err := q.collection.spec.fieldIndexKey(filter.fieldSpec.name) + if err != nil { + return err + } + valString := filter.value.String() + if filter.op == notEqualOp { + // Special case for not equal. We need to use two separate commands + filterKey := generateRandomKey("tmp:filter:" + fieldIndexKey) + // ZADD all ids greater than filter.value + min := "(" + valString + nullString + delString + tx.ExtractIdsFromStringIndex(fieldIndexKey, filterKey, min, "+") + // ZADD all ids less than filter.value + max := "(" + valString + tx.ExtractIdsFromStringIndex(fieldIndexKey, filterKey, "-", max) + // Intersect filterKey with origKey and store result in destKey + tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) + // Delete the temporary key + tx.Command("DEL", redis.Args{filterKey}, nil) + } else { + var min, max string + switch filter.op { + case equalOp: + min = "[" + valString + max = "(" + valString + nullString + delString + case lessOp: + min = "-" + max = "(" + valString + case greaterOp: + min = "(" + valString + nullString + delString + max = "+" + case lessOrEqualOp: + min = "-" + max = "(" + valString + nullString + delString + case greaterOrEqualOp: + min = "[" + valString + max = "+" + } + // Get all the ids that fit the filter criteria and store them in a temporary key caled filterKey + filterKey := generateRandomKey("tmp:filter:" + fieldIndexKey) + tx.ExtractIdsFromStringIndex(fieldIndexKey, filterKey, min, max) + // Intersect filterKey with origKey and store result in destKey + tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) + // Delete the temporary key + tx.Command("DEL", redis.Args{filterKey}, nil) + } + return nil +} + +// fieldNames parses the includes and excludes properties to return a list of +// field names which should be included in all find operations. If there are no +// includes or excludes, it returns all the field names. +func (q *query) fieldNames() []string { + switch { + case q.hasIncludes(): + return q.includes + case q.hasExcludes(): + results := q.collection.spec.fieldNames() + for _, name := range q.excludes { + results = removeElementFromStringSlice(results, name) + } + return results + default: + return q.collection.spec.fieldNames() + } +} + +// redisFieldNames parses the includes and excludes properties to return a list of +// redis names for each field which should be included in all find operations. If +// there are no includes or excludes, it returns the redis names for all fields. +func (q *query) redisFieldNames() []string { + fieldNames := q.fieldNames() + redisNames := []string{} + for _, fieldName := range fieldNames { + redisNames = append(redisNames, q.collection.spec.fieldsByName[fieldName].redisName) + } + return redisNames +} + +// converts limit and offset to start and stop values for cases where redis +// requires them. NOTE start cannot be negative, but stop can be +func (q *query) getStartStop() (start int, stop int) { + start = int(q.offset) + stop = -1 + if q.hasLimit() { + stop = int(start) + int(q.limit) - 1 + } + return start, stop +} + +func (q *query) hasFilters() bool { + return len(q.filters) > 0 +} + +func (q *query) hasOrder() bool { + return q.order.fieldName != "" +} + +func (q *query) hasLimit() bool { + return q.limit != 0 +} + +func (q *query) hasOffset() bool { + return q.offset != 0 +} + +func (q *query) hasIncludes() bool { + return len(q.includes) > 0 +} + +func (q *query) hasExcludes() bool { + return len(q.excludes) > 0 +} + +func (q *query) hasError() bool { + return q.err != nil +} + +// generateRandomKey generates a random string that is more or less +// guaranteed to be unique and then prepends the given prefix. It is +// used to generate keys for temporary sorted sets in queries. +func generateRandomKey(prefix string) string { + return prefix + ":" + generateRandomId() +} diff --git a/model.go b/model.go index c2a8c33..1f05d5f 100755 --- a/model.go +++ b/model.go @@ -8,10 +8,10 @@ package zoom import ( - "container/list" "fmt" "reflect" "strings" + "time" "github.com/garyburd/redigo/redis" ) @@ -86,38 +86,6 @@ const ( booleanIndex ) -var modelSpecs = list.New() - -// addModelSpec adds the given spec to the list of modelSpecs iff it has not -// already been added. -func addModelSpec(spec *modelSpec) { - for e := modelSpecs.Front(); e != nil; e = e.Next() { - otherSpec := e.Value.(*modelSpec) - if spec.typ == otherSpec.typ { - // The spec was already added to the list. No need to do anything. - return - } - } - modelSpecs.PushFront(spec) -} - -// getModelSpecForModel returns the modelSpec corresponding to the type of -// model. -func getModelSpecForModel(model Model) (*modelSpec, error) { - typ := reflect.TypeOf(model) - for e := modelSpecs.Front(); e != nil; e = e.Next() { - spec := e.Value.(*modelSpec) - if spec.typ == typ { - return spec, nil - } - } - return nil, fmt.Errorf( - "Could not find modelSpec for type %T. Is there a corresponding"+ - "Collection for this type?", - model, - ) -} - // compilesModelSpec examines typ using reflection, parses its fields, // and returns a modelSpec. func compileModelSpec(typ reflect.Type) (*modelSpec, error) { @@ -198,7 +166,6 @@ func compileModelSpec(typ reflect.Type) (*modelSpec, error) { fs.kind = inconvertibleField } } - addModelSpec(ms) return ms, nil } @@ -269,6 +236,18 @@ func (ms modelSpec) fieldRedisNames() []string { return names } +func (ms modelSpec) redisNamesForFieldNames(fieldNames []string) ([]string, error) { + redisNames := []string{} + for _, fieldName := range fieldNames { + fs, found := ms.fieldsByName[fieldName] + if !found { + return nil, fmt.Errorf("Type %s has no field named %s", ms.typ.Name(), fieldName) + } + redisNames = append(redisNames, fs.redisName) + } + return redisNames, nil +} + // fieldIndexKey returns the key for the sorted set used to index the field identified // by fieldName. It returns an error if fieldName does not identify a field in the spec // or if the field it identifies is not an indexed field. @@ -290,9 +269,9 @@ func (ms *modelSpec) fieldIndexKey(fieldName string) (string, error) { // be the key of a set or a sorted set which consists of model ids. The arguments // use they "BY nosort" option, so if a specific order is required, the setKey should be // a sorted set. -func (ms *modelSpec) sortArgs(setKey string, includeFields []string, limit int, offset uint, orderKind orderKind) redis.Args { - args := redis.Args{setKey, "BY", "nosort"} - for _, fieldName := range includeFields { +func (ms *modelSpec) sortArgs(idsKey string, redisFieldNames []string, limit int, offset uint, reverse bool) redis.Args { + args := redis.Args{idsKey, "BY", "nosort"} + for _, fieldName := range redisFieldNames { args = append(args, "GET", ms.name+":*->"+fieldName) } // We always want to get the id @@ -300,11 +279,10 @@ func (ms *modelSpec) sortArgs(setKey string, includeFields []string, limit int, if !(limit == 0 && offset == 0) { args = append(args, "LIMIT", offset, limit) } - switch orderKind { - case ascendingOrder: - args = append(args, "ASC") - case descendingOrder: + if reverse { args = append(args, "DESC") + } else { + args = append(args, "ASC") } return args } @@ -341,8 +319,9 @@ func (spec *modelSpec) checkModelsType(models interface{}) error { // itself and a pointer to the corresponding spec. This allows us to avoid constant lookups // in the modelTypeToSpec map. type modelRef struct { - model Model - spec *modelSpec + collection *Collection + model Model + spec *modelSpec } // value is an alias for reflect.ValueOf(mr.model) @@ -392,7 +371,14 @@ func (mr *modelRef) mainHashArgsForFields(fieldNames []string) (redis.Args, erro fieldVal := mr.fieldValue(fs.name) switch fs.kind { case primativeField: - args = args.Add(fs.redisName, fieldVal.Interface()) + // Add a special case for time.Duration. By default, the redigo driver + // will fall back to fmt.Sprintf, but we want to save it as an int64 in + // this case. + if fs.typ == reflect.TypeOf(time.Duration(0)) { + args = args.Add(fs.redisName, int64(fieldVal.Interface().(time.Duration))) + } else { + args = args.Add(fs.redisName, fieldVal.Interface()) + } case pointerField: if !fieldVal.IsNil() { args = args.Add(fs.redisName, fieldVal.Elem().Interface()) diff --git a/pool.go b/pool.go index 6fed812..a0b059a 100644 --- a/pool.go +++ b/pool.go @@ -20,7 +20,7 @@ import ( type Pool struct { // options is the fully parsed conifg, with defaults filling in any // blanks from the poolConfig passed into NewPool. - options *PoolOptions + options PoolOptions // redisPool is a redis.Pool redisPool *redis.Pool // modelTypeToSpec maps a registered model type to a modelSpec @@ -29,47 +29,137 @@ type Pool struct { modelNameToSpec map[string]*modelSpec } +// DefaultPoolOptions is the default set of options for a Pool. +var DefaultPoolOptions = PoolOptions{ + Address: "localhost:6379", + Database: 0, + IdleTimeout: 240 * time.Second, + MaxActive: 1000, + MaxIdle: 1000, + Network: "tcp", + Password: "", + Wait: true, +} + // PoolOptions contains various options for a pool. type PoolOptions struct { - // Address to connect to. Default: "localhost:6379" + // Address to use when connecting to Redis. Address string - // Network to use. Default: "tcp" - Network string - // Database id to use (using SELECT). Default: 0 + // Database id to use (using SELECT). Database int + // IdleTimeout is the amount of time to wait before timing out (closing) idle + // connections. + IdleTimeout time.Duration + // MaxActive is the maximum number of active connections the pool will keep. + // A value of 0 means unlimited. + MaxActive int + // MaxIdle is the maximum number of idle connections the pool will keep. A + // value of 0 means unlimited. + MaxIdle int + // Network to use. + Network string // Password for a password-protected redis database. If not empty, // every connection will use the AUTH command during initialization - // to authenticate with the database. Default: "" + // to authenticate with the database. Password string + // Wait indicates whether or not the pool should wait for a free connection + // if the MaxActive limit has been reached. If Wait is false and the + // MaxActive limit is reached, Zoom will return an error indicating that the + // pool is exhausted. + Wait bool +} + +// WithAddress returns a new copy of the options with the Address property set +// to the given value. It does not mutate the original options. +func (options PoolOptions) WithAddress(address string) PoolOptions { + options.Address = address + return options +} + +// WithDatabase returns a new copy of the options with the Database property set +// to the given value. It does not mutate the original options. +func (options PoolOptions) WithDatabase(database int) PoolOptions { + options.Database = database + return options } -// NewPool initializes and returns a pool with the given options. To use all -// the default options, you can pass in nil. -func NewPool(options *PoolOptions) *Pool { - fullOptions := parsePoolOptions(options) +// WithIdleTimeout returns a new copy of the options with the IdleTimeout +// property set to the given value. It does not mutate the original options. +func (options PoolOptions) WithIdleTimeout(timeout time.Duration) PoolOptions { + options.IdleTimeout = timeout + return options +} + +// WithMaxActive returns a new copy of the options with the MaxActive property +// set to the given value. It does not mutate the original options. +func (options PoolOptions) WithMaxActive(maxActive int) PoolOptions { + options.MaxActive = maxActive + return options +} + +// WithMaxIdle returns a new copy of the options with the MaxIdle property set +// to the given value. It does not mutate the original options. +func (options PoolOptions) WithMaxIdle(maxIdle int) PoolOptions { + options.MaxIdle = maxIdle + return options +} + +// WithNetwork returns a new copy of the options with the Network property set +// to the given value. It does not mutate the original options. +func (options PoolOptions) WithNetwork(network string) PoolOptions { + options.Network = network + return options +} + +// WithPassword returns a new copy of the options with the Password property set +// to the given value. It does not mutate the original options. +func (options PoolOptions) WithPassword(password string) PoolOptions { + options.Password = password + return options +} + +// WithWait returns a new copy of the options with the Wait property set to the +// given value. It does not mutate the original options. +func (options PoolOptions) WithWait(wait bool) PoolOptions { + options.Wait = wait + return options +} + +// NewPool creates and returns a new pool using the given address to connect to +// Redis. All the other options will be set to their default values, which can +// be found in DefaultPoolOptions. +func NewPool(address string) *Pool { + return NewPoolWithOptions(DefaultPoolOptions.WithAddress(address)) +} + +// NewPoolWithOptions initializes and returns a pool with the given options. You +// can pass in DefaultOptions to use all the default options. Or cal the WithX +// methods of DefaultOptions to change the options you want to change. +func NewPoolWithOptions(options PoolOptions) *Pool { pool := &Pool{ - options: fullOptions, + options: options, modelTypeToSpec: map[reflect.Type]*modelSpec{}, modelNameToSpec: map[string]*modelSpec{}, } pool.redisPool = &redis.Pool{ - MaxIdle: 10, - MaxActive: 0, - IdleTimeout: 240 * time.Second, + MaxIdle: options.MaxIdle, + MaxActive: options.MaxActive, + IdleTimeout: options.IdleTimeout, + Wait: options.Wait, Dial: func() (redis.Conn, error) { - c, err := redis.Dial(fullOptions.Network, fullOptions.Address) + c, err := redis.Dial(options.Network, options.Address) if err != nil { return nil, err } // If a options.Password was provided, use the AUTH command to authenticate - if fullOptions.Password != "" { - _, err = c.Do("AUTH", fullOptions.Password) + if options.Password != "" { + _, err = c.Do("AUTH", options.Password) if err != nil { return nil, err } } - // Select the database number provided by fullOptions.Database - if _, err := c.Do("Select", fullOptions.Database); err != nil { + // Select the database number provided by options.Database + if _, err := c.Do("Select", options.Database); err != nil { c.Close() return nil, err } @@ -82,7 +172,8 @@ func NewPool(options *PoolOptions) *Pool { // NewConn gets a connection from the pool and returns it. // It can be used for directly interacting with the database. See // http://godoc.org/github.com/garyburd/redigo/redis for full documentation -// on the redis.Conn type. +// on the redis.Conn type. You must call Close on any connections after you are +// done using them. Failure to call Close can cause a resource leak. func (p *Pool) NewConn() redis.Conn { return p.redisPool.Get() } @@ -92,34 +183,3 @@ func (p *Pool) NewConn() redis.Conn { func (p *Pool) Close() error { return p.redisPool.Close() } - -// defaultPoolOptions holds the default values for each config option -// if the zero value is provided in the input configuration, the value -// will fallback to the default value -var defaultPoolOptions = PoolOptions{ - Address: "localhost:6379", - Network: "tcp", - Database: 0, - Password: "", -} - -// parsePoolOptions returns a well-formed PoolOptions struct. -// If the passedOptions is nil, returns defaultPoolOptions. -// Else, for each zero value field in passedOptions, -// use the default value for that field. -func parsePoolOptions(passedOptions *PoolOptions) *PoolOptions { - if passedOptions == nil { - return &defaultPoolOptions - } - // copy the passedOptions - newOptions := *passedOptions - if newOptions.Address == "" { - newOptions.Address = defaultPoolOptions.Address - } - if newOptions.Network == "" { - newOptions.Network = defaultPoolOptions.Network - } - // since the zero value for int is 0, we can skip config.Database - // since the zero value for string is "", we can skip config.Address - return &newOptions -} diff --git a/query.go b/query.go index 99b30da..3a2e16d 100644 --- a/query.go +++ b/query.go @@ -1,226 +1,55 @@ -// Copyright 2015 Alex Browne. All rights reserved. -// Use of this source code is governed by the MIT -// license, which can be found in the LICENSE file. - -// File query.go contains code related to the query abstraction. - package zoom -import ( - "errors" - "fmt" - "reflect" - "strings" - - "github.com/garyburd/redigo/redis" -) - // Query represents a query which will retrieve some models from // the database. A Query may consist of one or more query modifiers // (e.g. Filter or Order) and may be executed with a query finisher // (e.g. Run or Ids). type Query struct { - collection *Collection - pool *Pool - tx *Transaction - includes []string - excludes []string - order order - limit uint - offset uint - filters []filter - err error -} - -// String satisfies fmt.Stringer and prints out the query in a format that -// matches the go code used to declare it. -func (q *Query) String() string { - result := fmt.Sprintf("%s.NewQuery()", q.collection.Name()) - for _, filter := range q.filters { - result += fmt.Sprintf(".%s", filter) - } - if q.hasOrder() { - result += fmt.Sprintf(".%s", q.order) - } - if q.hasOffset() { - result += fmt.Sprintf(".Offset(%d)", q.offset) - } - if q.hasLimit() { - result += fmt.Sprintf(".Limit(%d)", q.limit) - } - if q.hasIncludes() { - result += fmt.Sprintf(`.Include("%s")`, strings.Join(q.includes, `", "`)) - } else if q.hasExcludes() { - result += fmt.Sprintf(`.Exclude("%s")`, strings.Join(q.excludes, `", "`)) - } - return result -} - -type order struct { - fieldName string - redisName string - kind orderKind -} - -func (o order) String() string { - if o.kind == ascendingOrder { - return fmt.Sprintf(`Order("%s")`, o.fieldName) - } else { - return fmt.Sprintf(`Order("-%s")`, o.fieldName) - } -} - -type orderKind int - -const ( - ascendingOrder orderKind = iota - descendingOrder -) - -func (ok orderKind) String() string { - switch ok { - case ascendingOrder: - return "ascending" - case descendingOrder: - return "descending" - } - return "" -} - -type filter struct { - fieldSpec *fieldSpec - op filterOp - value reflect.Value -} - -func (f filter) String() string { - if f.value.Kind() == reflect.String { - return fmt.Sprintf(`Filter("%s %s", "%s")`, f.fieldSpec.name, f.op, f.value.String()) - } else { - return fmt.Sprintf(`Filter("%s %s", %v)`, f.fieldSpec.name, f.op, f.value.Interface()) - } -} - -type filterOp int - -const ( - equalOp filterOp = iota - notEqualOp - greaterOp - lessOp - greaterOrEqualOp - lessOrEqualOp -) - -func (fk filterOp) String() string { - switch fk { - case equalOp: - return "=" - case notEqualOp: - return "!=" - case greaterOp: - return ">" - case lessOp: - return "<" - case greaterOrEqualOp: - return ">=" - case lessOrEqualOp: - return "<=" - } - return "" -} - -var filterOps = map[string]filterOp{ - "=": equalOp, - "!=": notEqualOp, - ">": greaterOp, - "<": lessOp, - ">=": greaterOrEqualOp, - "<=": lessOrEqualOp, + *query } // NewQuery is used to construct a query. The query returned can be chained // together with one or more query modifiers (e.g. Filter or Order), and then // executed using the Run, RunOne, Count, or Ids methods. If no query modifiers -// are used, running the query will return all models of the given type in uspecified -// order. Queries use delated execution, so nothing touches the database until you -// execute it. +// are used, running the query will return all models of the given type in +// unspecified order. Queries use delayed execution, so nothing touches the +// database until you execute them. func (collection *Collection) NewQuery() *Query { - q := &Query{ - collection: collection, - pool: collection.pool, + return &Query{ + query: newQuery(collection), } - // For now, only indexed collections are queryable. This might change in - // future versions. - if !collection.index { - q.setError(fmt.Errorf("zoom: error in NewQuery: Only indexed collections are queryable. To index the collection, pass CollectionOptions to the NewCollection method.")) - return q - } - return q } -// setError sets the err property of q only if it has not already been set -func (q *Query) setError(e error) { - if !q.hasError() { - q.err = e - } -} - -// Order specifies a field by which to sort the models. fieldName should be -// a field in the struct type corresponding to the Collection used in the query -// constructor. By default, the records are sorted by ascending order by the given -// field. To sort by descending order, put a negative sign before the field name. -// Zoom can only sort by fields which have been indexed, i.e. those which have the -// `zoom:"index"` struct tag. However, in the future this may change. Only one -// order may be specified per query. However in the future, secondary orders may be -// allowed, and will take effect when two or more models have the same value for the -// primary order field. Order will set an error on the query if the fieldName is invalid, -// if another order has already been applied to the query, or if the fieldName specified +// Order specifies a field by which to sort the models. fieldName should be a +// field in the struct type corresponding to the Collection used in the query +// constructor. By default, the records are sorted by ascending order by the +// given field. To sort by descending order, put a negative sign before the +// field name. Zoom can only sort by fields which have been indexed, i.e. those +// which have the `zoom:"index"` struct tag. Only one order may be specified per +// Order will set an error on the query if the fieldName is invalid, if another +// order has already been applied to the query, or if the fieldName specified // does not correspond to an indexed field. The error, same as any other error // that occurs during the lifetime of the query, is not returned until the query -// is executed. When the query is executed the first error that occured during -// the lifetime of the query object (if any) will be returned. +// is executed. func (q *Query) Order(fieldName string) *Query { - if q.hasOrder() { - // TODO: allow secondary sort orders? - q.setError(errors.New("zoom: error in Query.Order: previous order already specified. Only one order per query is allowed.")) - return q - } - // Check for the presence of the "-" prefix - var orderKind orderKind - if strings.HasPrefix(fieldName, "-") { - orderKind = descendingOrder - // remove the "-" prefix - fieldName = fieldName[1:] - } else { - orderKind = ascendingOrder - } - // Get the redisName for the given fieldName - fs, found := q.collection.spec.fieldsByName[fieldName] - if !found { - err := fmt.Errorf("zoom: error in Query.Order: could not find field %s in type %s", fieldName, q.collection.spec.typ.String()) - q.setError(err) - return q - } - q.order = order{ - fieldName: fs.name, - redisName: fs.redisName, - kind: orderKind, - } + q.query.Order(fieldName) return q } -// Limit specifies an upper limit on the number of records to return. If amount -// is 0, no limit will be applied. The default value is 0. +// Limit specifies an upper limit on the number of models to return. If amount +// is 0, no limit will be applied and any number of models may be returned. The +// default value is 0. func (q *Query) Limit(amount uint) *Query { - q.limit = amount + q.query.Limit(amount) return q } // Offset specifies a starting index (inclusive) from which to start counting -// records that will be returned. The default value is 0. +// models that will be returned. For example, if offset is 10, the first 10 +// models that the query would otherwise return will be skipped. The default +// value is 0. func (q *Query) Offset(amount uint) *Query { - q.offset = amount + q.query.Offset(amount) return q } @@ -230,15 +59,9 @@ func (q *Query) Offset(amount uint) *Query { // only use one of Include or Exclude, not both on the same query. Include will // set an error if you try to use it with Exclude on the same query. The error, // same as any other error that occurs during the lifetime of the query, is not -// returned until the query is executed. When the query is executed the first -// error that occured during the lifetime of the query object (if any) will be -// returned. +// returned until the query is executed. func (q *Query) Include(fields ...string) *Query { - if q.hasExcludes() { - q.setError(errors.New("zoom: cannot use both Include and Exclude modifiers on a query")) - return q - } - q.includes = append(q.includes, fields...) + q.query.Include(fields...) return q } @@ -248,548 +71,83 @@ func (q *Query) Include(fields ...string) *Query { // Exclude, not both on the same query. Exclude will set an error if you try to // use it with Include on the same query. The error, same as any other error // that occurs during the lifetime of the query, is not returned until the query -// is executed. When the query is executed the first error that occured during -// the lifetime of the query object (if any) will be returned. +// is executed. func (q *Query) Exclude(fields ...string) *Query { - if q.hasIncludes() { - q.setError(errors.New("zoom: cannot use both Include and Exclude modifiers on a query")) - return q - } - q.excludes = append(q.excludes, fields...) + q.query.Exclude(fields...) return q } // Filter applies a filter to the query, which will cause the query to only -// return models with attributes matching the expression. filterString should be -// an expression which includes a fieldName, a space, and an operator in that -// order. Operators must be one of "=", "!=", ">", "<", ">=", or "<=". You can -// only use Filter on fields which are indexed, i.e. those which have the -// `zoom:"index"` struct tag. If multiple filters are applied to the same query, -// the query will only return models which have matches for ALL of the filters. -// I.e. applying multiple filters is logially equivalent to combining them with -// a AND or INTERSECT operator. Filter will set an error on the query if the -// arguments are improperly formated, if the field you are attempting to filter -// is not indexed, or if the type of value does not match the type of the field. -// The error, same as any other error that occurs during the lifetime of the -// query, is not returned until the query is executed. When the query is -// executed the first error that occured during the lifetime of the query object -// (if any) will be returned. +// return models with field values matching the expression. filterString should +// be an expression which includes a fieldName, a space, and an operator in that +// order. For example: Filter("Age >=", 30) would only return models which have +// an Age value greater than or equal to 30. Operators must be one of "=", "!=", +// ">", "<", ">=", or "<=". You can only use Filter on fields which are indexed, +// i.e. those which have the `zoom:"index"` struct tag. If multiple filters are +// applied to the same query, the query will only return models which have +// matches for *all* of the filters. Filter will set an error on the query if +// the arguments are improperly formated, if the field you are attempting to +// filter is not indexed, or if the type of value does not match the type of the +// field. The error, same as any other error that occurs during the lifetime of +// the query, is not returned until the query is executed. func (q *Query) Filter(filterString string, value interface{}) *Query { - fieldName, operator, err := splitFilterString(filterString) - if err != nil { - q.setError(err) - return q - } - // Parse the filter operator - filterOp, found := filterOps[operator] - if !found { - q.setError(errors.New("zoom: invalid Filter operator in fieldStr. should be one of =, !=, >, <, >=, or <=.")) - return q - } - // Get the fieldSpec for the given fieldName - fieldSpec, found := q.collection.spec.fieldsByName[fieldName] - if !found { - err := fmt.Errorf("zoom: error in Query.Order: could not find field %s in type %s", fieldName, q.collection.spec.typ.String()) - q.setError(err) - return q - } - // Make sure the field is an indexed field - if fieldSpec.indexKind == noIndex { - err := fmt.Errorf("zoom: filters are only allowed on indexed fields. %s.%s is not indexed. You can index it by adding the `zoom:\"index\"` struct tag.", q.collection.spec.typ.String(), fieldName) - q.setError(err) - return q - } - filter := filter{ - fieldSpec: fieldSpec, - op: filterOp, - } - // Make sure the given value is the correct type - if err := filter.checkValType(value); err != nil { - q.setError(err) - return q - } - filter.value = reflect.ValueOf(value) - q.filters = append(q.filters, filter) + q.query.Filter(filterString, value) return q } -func splitFilterString(filterString string) (fieldName string, operator string, err error) { - tokens := strings.Split(filterString, " ") - if len(tokens) != 2 { - return "", "", errors.New("zoom: too many spaces in fieldStr argument. should be a field name, a space, and an operator.") - } - return tokens[0], tokens[1], nil -} - -// checkValType returns an error if the type of value does not correspond to -// filter.fieldSpec. -func (filter filter) checkValType(value interface{}) error { - // Here we iterate through pointer inderections. This is so you can - // just pass in a primative instead of a pointer to a primative for - // filtering on fields which have pointer values. - valueType := reflect.TypeOf(value) - valueVal := reflect.ValueOf(value) - for valueType.Kind() == reflect.Ptr { - valueType = valueType.Elem() - valueVal = valueVal.Elem() - if !valueVal.IsValid() { - return errors.New("zoom: invalid value arg for Filter. Is it a nil pointer?") - } - } - // Also dereference the field type to reach the underlying type. - fieldType := filter.fieldSpec.typ - for fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - if valueType != fieldType { - return fmt.Errorf("zoom: invalid value arg for Filter. Type of value (%T) does not match type of field (%s).", value, fieldType.String()) - } - return nil -} - // Run executes the query and scans the results into models. The type of models -// should be a pointer to a slice of pointers to a registered Model. Run will -// return the first error that occured during the lifetime of the query object -// (if any). It will also return an error if models is the wrong type. +// should be a pointer to a slice of Models. If no models fit the criteria, Run +// will set the length of models to 0 but will *not* return an error. Run will +// return the first error that occurred during the lifetime of the query (if +// any), or if models is the wrong type. func (q *Query) Run(models interface{}) error { - if q.hasError() { - return q.err - } - if err := q.collection.spec.checkModelsType(models); err != nil { - return err - } - q.tx = q.pool.NewTransaction() - idsKey, tmpKeys, err := q.generateIdsSet() - if err != nil { - if len(tmpKeys) > 0 { - q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) - } - return err - } - limit := int(q.limit) - if limit == 0 { - // In our query syntax, a limit of 0 means unlimited - // But in redis, -1 means unlimited - limit = -1 - } - sortArgs := q.collection.spec.sortArgs(idsKey, q.redisFieldNames(), limit, q.offset, q.order.kind) - q.tx.Command("SORT", sortArgs, newScanModelsHandler(q.collection.spec, append(q.fieldNames(), "-"), models)) - if len(tmpKeys) > 0 { - q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) - } - if err := q.tx.Exec(); err != nil { - return err - } - return nil + tx := q.pool.NewTransaction() + newTransactionalQuery(q.query, tx).Run(models) + return tx.Exec() } -// RunOne is exactly like Run but finds only the first model that fits the -// query criteria and scans the values into model. If no model fits the criteria, -// an error will be returned. +// RunOne is exactly like Run but finds only the first model that fits the query +// criteria and scans the values into model. If no model fits the criteria, +// RunOne *will* return a ModelNotFoundError. func (q *Query) RunOne(model Model) error { - if q.hasError() { - return q.err - } - if err := q.collection.spec.checkModelType(model); err != nil { - return err - } - models := reflect.New(reflect.SliceOf(reflect.TypeOf(model))) - if err := q.Run(models.Interface()); err != nil { - return err - } - if models.Elem().Len() == 0 { - msg := fmt.Sprintf("Could not find a model with the given query criteria: %s", q) - return ModelNotFoundError{Msg: msg} - } else { - modelVal := models.Elem().Index(0) - reflect.ValueOf(model).Elem().Set(modelVal.Elem()) - } - return nil + tx := q.pool.NewTransaction() + newTransactionalQuery(q.query, tx).RunOne(model) + return tx.Exec() } // Count counts the number of models that would be returned by the query without -// actually retreiving the models themselves. Count will also return the first -// error that occured during the lifetime of the query object (if any). -// Otherwise, the second return value will be nil. -func (q *Query) Count() (uint, error) { - if q.hasError() { - return 0, q.err - } - if !q.hasFilters() { - // Just return the number of ids in the all index set - conn := q.pool.NewConn() - defer conn.Close() - count64, err := redis.Uint64(conn.Do("SCARD", q.collection.spec.indexKey())) - if err != nil { - return 0, nil - } - count := uint(count64) - // Apply math to take into account limit and offset - switch { - case !q.hasLimit() && !q.hasOffset(): - return count, nil - default: - if q.hasOffset() { - count = count - q.offset - } - if q.hasLimit() && q.limit < count { - count = q.limit - } - return count, nil - } - } else { - // If the query has filters, it is difficult to do any optimizations. - // Instead we'll just count the number of ids that match the query - // criteria. - ids, err := q.Ids() - if err != nil { - return 0, err - } - return uint(len(ids)), nil +// actually retrieving the models themselves. Count will also return the first +// error that occurred during the lifetime of the query (if any). +func (q *Query) Count() (int, error) { + tx := q.pool.NewTransaction() + var count int + newTransactionalQuery(q.query, tx).Count(&count) + if err := tx.Exec(); err != nil { + return 0, err } + return count, nil } -// Ids returns only the ids of the models without actually retreiving the -// models themselves. Ids will return the first error that occured -// during the lifetime of the query object (if any). +// Ids returns only the ids of the models without actually retrieving the +// models themselves. Ids will return the first error that occurred during the +// lifetime of the query (if any). func (q *Query) Ids() ([]string, error) { - if q.hasError() { - return nil, q.err - } - q.tx = q.pool.NewTransaction() - idsKey, tmpKeys, err := q.generateIdsSet() - if err != nil { - if len(tmpKeys) > 0 { - q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) - } - return nil, err - } - limit := int(q.limit) - if limit == 0 { - // In our query syntax, a limit of 0 means unlimited - // But in redis, -1 means unlimited - limit = -1 - } - sortArgs := q.collection.spec.sortArgs(idsKey, nil, limit, q.offset, q.order.kind) + tx := q.pool.NewTransaction() ids := []string{} - q.tx.Command("SORT", sortArgs, NewScanStringsHandler(&ids)) - if len(tmpKeys) > 0 { - q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) - } - if err := q.tx.Exec(); err != nil { + newTransactionalQuery(q.query, tx).Ids(&ids) + if err := tx.Exec(); err != nil { return nil, err } return ids, nil } -// generateIdsSet will return the key of a set or sorted set that contains all the ids -// which match the query criteria. It may also return some temporary keys which were created -// during the process of creating the set of ids. Note that tmpKeys may contain idsKey itself, -// so the temporary keys should not be deleted until after the ids have been read from idsKey. -func (q *Query) generateIdsSet() (idsKey string, tmpKeys []interface{}, err error) { - idsKey = q.collection.spec.indexKey() - tmpKeys = []interface{}{} - if q.hasOrder() { - fieldIndexKey, err := q.collection.spec.fieldIndexKey(q.order.fieldName) - if err != nil { - return "", nil, err - } - fieldSpec := q.collection.spec.fieldsByName[q.order.fieldName] - if fieldSpec.indexKind == stringIndex { - // If the order is a string field, we need to extract the ids before - // we use ZRANGE. Create a temporary set to store the ordered ids - orderedIdsKey := generateRandomKey("order:" + q.order.fieldName) - tmpKeys = append(tmpKeys, orderedIdsKey) - idsKey = orderedIdsKey - // TODO: as an optimization, if there is a filter on the same field, - // pass the start and stop parameters to the script. - q.tx.ExtractIdsFromStringIndex(fieldIndexKey, orderedIdsKey, "-", "+") - } else { - idsKey = fieldIndexKey - } - } - if q.hasFilters() { - filteredIdsKey := generateRandomKey("filter:all") - tmpKeys = append(tmpKeys, filteredIdsKey) - for i, filter := range q.filters { - if i == 0 { - // The first time, we should intersect with the ids key from above - if err := q.intersectFilter(filter, idsKey, filteredIdsKey); err != nil { - return "", tmpKeys, err - } - } else { - // All other times, we should intersect with the filteredIdsKey itself - if err := q.intersectFilter(filter, filteredIdsKey, filteredIdsKey); err != nil { - return "", tmpKeys, err - } - } - } - idsKey = filteredIdsKey - } - return idsKey, tmpKeys, nil -} - -// intersectFilter adds commands to the query transacation which, when run, will create a -// temporary set which contains all the ids that fit the given filter criteria. Then it will -// intersect them with origKey and stores the result in destKey. The function will automatically -// delete any temporary sets created since, in this case, they are gauranteed to not be needed -// by any other transaction commands. -func (q *Query) intersectFilter(filter filter, origKey string, destKey string) error { - switch filter.fieldSpec.indexKind { - case numericIndex: - return q.intersectNumericFilter(filter, origKey, destKey) - case booleanIndex: - return q.intersectBoolFilter(filter, origKey, destKey) - case stringIndex: - return q.intersectStringFilter(filter, origKey, destKey) - } - return nil -} - -// intersectNumericFilter adds commands to the query transaction which, when run, will -// create a temporary set which contains all the ids of models which match the given -// numeric filter criteria, then intersect those ids with origKey and store the result -// in destKey. -func (q *Query) intersectNumericFilter(filter filter, origKey string, destKey string) error { - fieldIndexKey, err := q.collection.spec.fieldIndexKey(filter.fieldSpec.name) - if err != nil { - return err - } - if filter.op == notEqualOp { - // Special case for not equal. We need to use two separate commands - valueExclusive := fmt.Sprintf("(%v", filter.value.Interface()) - filterKey := generateRandomKey("filter:" + fieldIndexKey) - // ZADD all ids greater than filter.value - q.tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, valueExclusive, "+inf") - // ZADD all ids less than filter.value - q.tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, "-inf", valueExclusive) - // Intersect filterKey with origKey and store result in destKey - q.tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) - // Delete the temporary key - q.tx.Command("DEL", redis.Args{filterKey}, nil) - } else { - var min, max interface{} - switch filter.op { - case equalOp: - min, max = filter.value.Interface(), filter.value.Interface() - case lessOp: - min = "-inf" - // use "(" for exclusive - max = fmt.Sprintf("(%v", filter.value.Interface()) - case greaterOp: - min = fmt.Sprintf("(%v", filter.value.Interface()) - max = "+inf" - case lessOrEqualOp: - min = "-inf" - max = filter.value.Interface() - case greaterOrEqualOp: - min = filter.value.Interface() - max = "+inf" - } - // Get all the ids that fit the filter criteria and store them in a temporary key caled filterKey - filterKey := generateRandomKey("filter:" + fieldIndexKey) - q.tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, min, max) - // Intersect filterKey with origKey and store result in destKey - q.tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) - // Delete the temporary key - q.tx.Command("DEL", redis.Args{filterKey}, nil) - } - return nil -} - -// intersectBoolFilter adds commands to the query transaction which, when run, will -// create a temporary set which contains all the ids of models which match the given -// bool filter criteria, then intersect those ids with origKey and store the result -// in destKey. -func (q *Query) intersectBoolFilter(filter filter, origKey string, destKey string) error { - fieldIndexKey, err := q.collection.spec.fieldIndexKey(filter.fieldSpec.name) - if err != nil { - return err - } - var min, max interface{} - switch filter.op { - case equalOp: - if filter.value.Bool() { - min, max = 1, 1 - } else { - min, max = 0, 0 - } - case lessOp: - if filter.value.Bool() { - // Only false is less than true - min, max = 0, 0 - } else { - // No models are less than false, - // so we should eliminate all models - min, max = -1, -1 - } - case greaterOp: - if filter.value.Bool() { - // No models are greater than true, - // so we should eliminate all models - min, max = -1, -1 - } else { - // Only true is greater than false - min, max = 1, 1 - } - case lessOrEqualOp: - if filter.value.Bool() { - // All models are <= true - min, max = 0, 1 - } else { - // Only false is <= false - min, max = 0, 0 - } - case greaterOrEqualOp: - if filter.value.Bool() { - // Only true is >= true - min, max = 1, 1 - } else { - // All models are >= false - min, max = 0, 1 - } - case notEqualOp: - if filter.value.Bool() { - min, max = 0, 0 - } else { - min, max = 1, 1 - } - } - // Get all the ids that fit the filter criteria and store them in a temporary key caled filterKey - filterKey := generateRandomKey("filter:" + fieldIndexKey) - q.tx.ExtractIdsFromFieldIndex(fieldIndexKey, filterKey, min, max) - // Intersect filterKey with origKey and store result in destKey - q.tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) - // Delete the temporary key - q.tx.Command("DEL", redis.Args{filterKey}, nil) - return nil -} - -// intersectStringFilter adds commands to the query transaction which, when run, will -// create a temporary set which contains all the ids of models which match the given -// string filter criteria, then intersect those ids with origKey and store the result -// in destKey. -func (q *Query) intersectStringFilter(filter filter, origKey string, destKey string) error { - fieldIndexKey, err := q.collection.spec.fieldIndexKey(filter.fieldSpec.name) - if err != nil { - return err - } - valString := filter.value.String() - if filter.op == notEqualOp { - // Special case for not equal. We need to use two separate commands - filterKey := generateRandomKey("filter:" + fieldIndexKey) - // ZADD all ids greater than filter.value - min := "(" + valString + nullString + delString - q.tx.ExtractIdsFromStringIndex(fieldIndexKey, filterKey, min, "+") - // ZADD all ids less than filter.value - max := "(" + valString - q.tx.ExtractIdsFromStringIndex(fieldIndexKey, filterKey, "-", max) - // Intersect filterKey with origKey and store result in destKey - q.tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) - // Delete the temporary key - q.tx.Command("DEL", redis.Args{filterKey}, nil) - } else { - var min, max string - switch filter.op { - case equalOp: - min = "[" + valString - max = "(" + valString + nullString + delString - case lessOp: - min = "-" - max = "(" + valString - case greaterOp: - min = "(" + valString + nullString + delString - max = "+" - case lessOrEqualOp: - min = "-" - max = "(" + valString + nullString + delString - case greaterOrEqualOp: - min = "[" + valString - max = "+" - } - // Get all the ids that fit the filter criteria and store them in a temporary key caled filterKey - filterKey := generateRandomKey("filter:" + fieldIndexKey) - q.tx.ExtractIdsFromStringIndex(fieldIndexKey, filterKey, min, max) - // Intersect filterKey with origKey and store result in destKey - q.tx.Command("ZINTERSTORE", redis.Args{destKey, 2, origKey, filterKey, "WEIGHTS", 1, 0}, nil) - // Delete the temporary key - q.tx.Command("DEL", redis.Args{filterKey}, nil) - } - return nil -} - -// fieldNames parses the includes and excludes properties to return a list of -// field names which should be included in all find operations. If there are no -// includes or excludes, it returns all the field names. -func (q *Query) fieldNames() []string { - switch { - case q.hasIncludes(): - return q.includes - case q.hasExcludes(): - results := q.collection.spec.fieldNames() - for _, name := range q.excludes { - results = removeElementFromStringSlice(results, name) - } - return results - default: - return q.collection.spec.fieldNames() - } -} - -// redisFieldNames parses the includes and excludes properties to return a list of -// redis names for each field which should be included in all find operations. If -// there are no includes or excludes, it returns the redis names for all fields. -func (q *Query) redisFieldNames() []string { - fieldNames := q.fieldNames() - redisNames := []string{} - for _, fieldName := range fieldNames { - redisNames = append(redisNames, q.collection.spec.fieldsByName[fieldName].redisName) - } - return redisNames -} - -// converts limit and offset to start and stop values for cases where redis -// requires them. NOTE start cannot be negative, but stop can be -func (q *Query) getStartStop() (start int, stop int) { - start = int(q.offset) - stop = -1 - if q.hasLimit() { - stop = int(start) + int(q.limit) - 1 - } - return start, stop -} - -func (q *Query) hasFilters() bool { - return len(q.filters) > 0 -} - -func (q *Query) hasOrder() bool { - return q.order.fieldName != "" -} - -func (q *Query) hasLimit() bool { - return q.limit != 0 -} - -func (q *Query) hasOffset() bool { - return q.offset != 0 -} - -func (q *Query) hasIncludes() bool { - return len(q.includes) > 0 -} - -func (q *Query) hasExcludes() bool { - return len(q.excludes) > 0 -} - -func (q *Query) hasError() bool { - return q.err != nil -} - -// generateRandomKey generates a random string that is more or less -// garunteed to be unique and then prepends the given prefix. It is -// used to generate keys for temporary sorted sets in queries. -func generateRandomKey(prefix string) string { - return prefix + ":" + generateRandomId() +// StoreIds executes the query and stores the model ids matching the query +// criteria in a list identified by destKey. The list will be completely +// overwritten, and the model ids stored there will be in the correct order if +// the query includes an Order modifier. StoreIds will return the first error +// that occurred during the lifetime of the query (if any). +func (q *Query) StoreIds(destKey string) error { + tx := q.pool.NewTransaction() + newTransactionalQuery(q.query, tx).StoreIds(destKey) + return tx.Exec() } diff --git a/query_test.go b/query_test.go index da0d121..2738292 100644 --- a/query_test.go +++ b/query_test.go @@ -12,6 +12,8 @@ import ( "sort" "strconv" "testing" + + "github.com/garyburd/redigo/redis" ) func TestQueryAll(t *testing.T) { @@ -289,16 +291,19 @@ func TestQueryRunOne(t *testing.T) { // then the query was correct and the test will pass. models should be an array of all // the models which are being queried against. func testQuery(t *testing.T, q *Query, models []*indexedTestModel) { - expected := expectedResultsForQuery(q, models) + expected := expectedResultsForQuery(q.query, models) testQueryRun(t, q, expected) testQueryIds(t, q, expected) testQueryCount(t, q, expected) + testQueryStoreIds(t, q, expected) + checkForLeakedTmpKeys(t, q.query) } func testQueryRun(t *testing.T, q *Query, expected []*indexedTestModel) { got := []*indexedTestModel{} if err := q.Run(&got); err != nil { t.Errorf("Unexpected error in query.Run: %s", err.Error()) + return } if err := expectModelsToBeEqual(expected, got, q.hasOrder()); err != nil { t.Errorf("testQueryRun failed for query %s\nExpected: %#v\nGot: %#v", q, expected, got) @@ -306,9 +311,10 @@ func testQueryRun(t *testing.T, q *Query, expected []*indexedTestModel) { } func testQueryCount(t *testing.T, q *Query, expectedModels []*indexedTestModel) { - expected := uint(len(expectedModels)) + expected := len(expectedModels) if got, err := q.Count(); err != nil { t.Error(err) + return } else if got != expected { t.Errorf("testQueryCount failed for query %s. Expected %d but got %d.", q, expected, got) } @@ -318,6 +324,7 @@ func testQueryIds(t *testing.T, q *Query, expectedModels []*indexedTestModel) { got, err := q.Ids() if err != nil { t.Errorf("Unexpected error in query.Ids: %s", err.Error()) + return } expected := modelIds(Models(expectedModels)) if q.hasOrder() { @@ -333,11 +340,53 @@ func testQueryIds(t *testing.T, q *Query, expectedModels []*indexedTestModel) { } } +func testQueryStoreIds(t *testing.T, q *Query, expectedModels []*indexedTestModel) { + destKey := "queryDestKey:" + generateRandomId() + if err := q.StoreIds(destKey); err != nil { + t.Errorf("Unexpected error in query.StoreIds: %s", err.Error()) + return + } + expected := modelIds(Models(expectedModels)) + conn := testPool.NewConn() + defer conn.Close() + got, err := redis.Strings(conn.Do("LRANGE", destKey, 0, -1)) + if err != nil { + t.Error(err) + return + } + if q.hasOrder() { + // Order matters + if !reflect.DeepEqual(expected, got) { + t.Errorf("testQueryStoreIds failed for query %s\nExpected: %v\nGot: %v", q, expected, got) + return + } + } else { + // Order does not matter + if equal, msg := compareAsStringSet(expected, got); !equal { + t.Errorf("testQueryStoreIds failed for query %s\n%s\nExpected: %v\nGot: %v", q, msg, expected, got) + return + } + } +} + +func checkForLeakedTmpKeys(t *testing.T, query *query) { + conn := testPool.NewConn() + defer conn.Close() + keys, err := redis.Strings(conn.Do("KEYS", "tmp:*")) + if err != nil { + t.Error(err) + return + } + if len(keys) > 0 { + t.Errorf("Found leaked keys: %v\nFor query: %s", keys, query) + } +} + // expectedResultsForQuery returns the expected results for q on the given set of models. // It computes the models that should be returned in-memory, without touching the database, // and without the same optimizations that database queries have. It can be used to test for // the correctness of database queries. -func expectedResultsForQuery(q *Query, models []*indexedTestModel) []*indexedTestModel { +func expectedResultsForQuery(q *query, models []*indexedTestModel) []*indexedTestModel { expected := make([]*indexedTestModel, len(models)) copy(expected, models) diff --git a/scripts_test.go b/scripts_test.go index 5c5253b..c7e4f57 100644 --- a/scripts_test.go +++ b/scripts_test.go @@ -75,10 +75,8 @@ func TestDeleteStringIndexScript(t *testing.T) { String string `zoom:"index"` RandomId } - stringIndexModels, err := testPool.NewCollection(&stringIndexModel{}, - &CollectionOptions{ - Index: true, - }) + options := DefaultCollectionOptions.WithIndex(true) + stringIndexModels, err := testPool.NewCollectionWithOptions(&stringIndexModel{}, options) if err != nil { t.Errorf("Unexpected error registering stringIndexModel: %s", err.Error()) } diff --git a/struct_tags_test.go b/struct_tags_test.go index 486fc88..deca3c2 100644 --- a/struct_tags_test.go +++ b/struct_tags_test.go @@ -23,7 +23,7 @@ func TestRedisIgnoreOption(t *testing.T) { Attr string `redis:"-"` RandomId } - ignoredFieldModels, err := testPool.NewCollection(&ignoredFieldModel{}, nil) + ignoredFieldModels, err := testPool.NewCollection(&ignoredFieldModel{}) if err != nil { t.Errorf("Unexpected error in Register: %s", err) } @@ -67,7 +67,7 @@ func TestRedisNameOption(t *testing.T) { Attr string `redis:"a"` RandomId } - customFieldModels, err := testPool.NewCollection(&customFieldModel{}, nil) + customFieldModels, err := testPool.NewCollection(&customFieldModel{}) if err != nil { t.Errorf("Unexpected error in Register: %s", err.Error()) } @@ -102,7 +102,7 @@ func TestInvalidOptionThrowsError(t *testing.T) { Attr string `zoom:"index,poop"` RandomId } - if _, err := testPool.NewCollection(&invalid{}, nil); err == nil { + if _, err := testPool.NewCollection(&invalid{}); err == nil { t.Error("Expected error when registering struct with invalid tag") } } @@ -216,7 +216,7 @@ func TestIndexAndCustomName(t *testing.T) { Bool bool `zoom:"index" redis:"boolean"` RandomId } - customIndexModels, err := testPool.NewCollection(&customIndexModel{}, nil) + customIndexModels, err := testPool.NewCollection(&customIndexModel{}) if err != nil { t.Fatalf("Unexpected error in Register: %s", err.Error()) } diff --git a/test_util.go b/test_util.go index a12d1c7..4462315 100755 --- a/test_util.go +++ b/test_util.go @@ -33,11 +33,17 @@ var setUpOnce = sync.Once{} // testingSetUp func testingSetUp() { setUpOnce.Do(func() { - testPool = NewPool(&PoolOptions{ - Address: *address, - Network: *network, - Database: *database, - }) + options := DefaultPoolOptions + if address != nil { + options = options.WithAddress(*address) + } + if network != nil { + options = options.WithNetwork(*network) + } + if database != nil { + options = options.WithDatabase(*database) + } + testPool = NewPoolWithOptions(options) checkDatabaseEmpty() registerTestingTypes() }) @@ -261,9 +267,8 @@ func registerTestingTypes() { }, } for _, m := range testModelTypes { - collection, err := testPool.NewCollection(m.model, &CollectionOptions{ - Index: m.index, - }) + options := DefaultCollectionOptions.WithIndex(true) + collection, err := testPool.NewCollectionWithOptions(m.model, options) if err != nil { panic(err) } @@ -287,15 +292,14 @@ func checkDatabaseEmpty() { } // testingTearDown flushes the database. It should be run at the end -// of each test that toches the database, typically by using defer. +// of each test that touches the database, typically by using defer. func testingTearDown() { // flush and close the database conn := testPool.NewConn() - _, err := conn.Do("flushdb") - if err != nil { + defer conn.Close() + if _, err := conn.Do("flushdb"); err != nil { panic(err) } - conn.Close() } // expectSetContains sets an error via t.Errorf if member is not in the set @@ -333,6 +337,12 @@ func expectFieldEquals(t *testing.T, key string, fieldName string, marshalerUnma if err != nil { t.Errorf("Unexpected error in HGET: %s", err.Error()) } + if reply == nil { + if expected == nil { + return + } + t.Errorf("Field %s was nil. Expected: %v", fieldName, expected) + } srcBytes, ok := reply.([]byte) if !ok { t.Fatalf("Unexpected error: could not convert %v of type %T to []byte.\n", reply, reply) diff --git a/transaction.go b/transaction.go index 7ebe4a3..f494c84 100644 --- a/transaction.go +++ b/transaction.go @@ -7,7 +7,11 @@ package zoom -import "github.com/garyburd/redigo/redis" +import ( + "fmt" + + "github.com/garyburd/redigo/redis" +) // Transaction is an abstraction layer around a redis transaction. // Transactions consist of a set of actions which are either redis @@ -144,6 +148,9 @@ func (t *Transaction) Exec() error { // Iterate through the replies, calling the corresponding handler functions for i, reply := range replies { a := t.actions[i] + if err, ok := reply.(error); ok { + return err + } if a.handler != nil { if err := a.handler(reply); err != nil { return err @@ -198,3 +205,17 @@ func (t *Transaction) ExtractIdsFromFieldIndex(setKey string, destKey string, mi func (t *Transaction) ExtractIdsFromStringIndex(setKey, destKey, min, max string) { t.Script(extractIdsFromStringIndexScript, redis.Args{setKey, destKey, min, max}, nil) } + +func (t *Transaction) FindModelsByIdsKey(collection *Collection, idsKey string, fieldNames []string, limit uint, offset uint, reverse bool, models interface{}) { + if err := collection.checkModelsType(models); err != nil { + t.setError(fmt.Errorf("zoom: error in FindModelsByIdKey: %s", err.Error())) + return + } + redisNames, err := collection.spec.redisNamesForFieldNames(fieldNames) + if err != nil { + t.setError(fmt.Errorf("zoom: error in FindModelsByIdKey: %s", err.Error())) + return + } + sortArgs := collection.spec.sortArgs(idsKey, redisNames, int(limit), offset, reverse) + t.Command("SORT", sortArgs, newScanModelsHandler(collection.spec, append(fieldNames, "-"), models)) +} diff --git a/transaction_query.go b/transaction_query.go new file mode 100644 index 0000000..a9c7533 --- /dev/null +++ b/transaction_query.go @@ -0,0 +1,245 @@ +package zoom + +import "github.com/garyburd/redigo/redis" + +// TransactionalQuery represents a query which will be run inside an existing +// transaction. A TransactionalQuery may consist of one or more query modifiers +// (e.g. Filter or Order) and should always be finished with a query finisher +// (e.g. Run or Ids). Unlike Query, the finisher methods for TransactionalQuery +// always expect pointers as arguments and will set the values when the +// corresponding Transaction is executed. +type TransactionQuery struct { + *query + tx *Transaction +} + +// newTransactionalQuery creates and returns a new TransactionalQuery. It is an +// internal function that allows us to convert a Query to a TransactionalQuery. +// That way, there is only one canonical implementation of the query finisher +// methods (e.g. Run, RunOne, Ids). +func newTransactionalQuery(query *query, tx *Transaction) *TransactionQuery { + return &TransactionQuery{ + query: query, + tx: tx, + } +} + +// Query is used to construct a query in the context of an existing Transaction +// It can be used to run a query atomically along with commands, scripts, or +// other queries in a single round trip. Note that this method returns a +// TransactionalQuery, whereas Collection.NewQuery returns a Query. The two +// types are very similar, but there are differences in how they are eventually +// executed. Like a regular Query, a TransactionalQuery can be chained together +// with one or more query modifiers (e.g. Filter or Order). You also need to +// finish the query with a method such as Run, RunOne, or Count. The major +// difference is that TransactionQueries are not actually run until you call +// Transaction.Exec(). As a consequence, the finisher methods (e.g. Run, RunOne, +// Count, etc) do not return anything. Instead they accept arguments which are +// then mutated after the transaction is executed. +func (tx *Transaction) Query(collection *Collection) *TransactionQuery { + return &TransactionQuery{ + query: newQuery(collection), + tx: tx, + } +} + +// Order works exactly like Query.Order. See the documentation for Query.Order +// for a full description. +func (q *TransactionQuery) Order(fieldName string) *TransactionQuery { + q.query.Order(fieldName) + return q +} + +// Limit works exactly like Query.Limit. See the documentation for Query.Limit +// for more information. +func (q *TransactionQuery) Limit(amount uint) *TransactionQuery { + q.query.Limit(amount) + return q +} + +// Offset works exactly like Query.Offset. See the documentation for +// Query.Offset for more information. +func (q *TransactionQuery) Offset(amount uint) *TransactionQuery { + q.query.Offset(amount) + return q +} + +// Include works exactly like Query.Include. See the documentation for +// Query.Include for more information. +func (q *TransactionQuery) Include(fields ...string) *TransactionQuery { + q.query.Include(fields...) + return q +} + +// Exclude works exactly like Query.Exclude. See the documentation for +// Query.Exclude for more information. +func (q *TransactionQuery) Exclude(fields ...string) *TransactionQuery { + q.query.Exclude(fields...) + return q +} + +// Filter works exactly like Query.Filter. See the documentation for +// Query.Filter for more information. +func (q *TransactionQuery) Filter(filterString string, value interface{}) *TransactionQuery { + q.query.Filter(filterString, value) + return q +} + +// Run will run the query and scan the results into models when the Transaction +// is executed. It works very similarly to Query.Run, so you can check the +// documentation for Query.Run for more information. The first error encountered +// will be saved to the corresponding Transaction (if there is not already an +// error for the Transaction) and returned when you call Transaction.Exec. +func (q *TransactionQuery) Run(models interface{}) { + if q.hasError() { + q.tx.setError(q.err) + return + } + if err := q.collection.spec.checkModelsType(models); err != nil { + q.tx.setError(err) + return + } + idsKey, tmpKeys, err := generateIdsSet(q.query, q.tx) + if err != nil { + q.tx.setError(err) + return + } + limit := int(q.limit) + if limit == 0 { + // In our query syntax, a limit of 0 means unlimited + // But in redis, -1 means unlimited + limit = -1 + } + sortArgs := q.collection.spec.sortArgs(idsKey, q.redisFieldNames(), limit, q.offset, q.order.kind == descendingOrder) + q.tx.Command("SORT", sortArgs, newScanModelsHandler(q.collection.spec, append(q.fieldNames(), "-"), models)) + if len(tmpKeys) > 0 { + q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) + } +} + +// RunOne will run the query and scan the first model which matches the query +// criteria into model. If no model matches the query criteria, it will set a +// ModelNotFoundError on the Transaction. It works very similarly to +// Query.RunOne, so you can check the documentation for Query.RunOne for more +// information. The first error encountered will be saved to the corresponding +// Transaction (if there is not already an error for the Transaction) and +// returned when you call Transaction.Exec. +func (q *TransactionQuery) RunOne(model Model) { + if q.hasError() { + q.tx.setError(q.err) + return + } + if err := q.collection.spec.checkModelType(model); err != nil { + q.tx.setError(err) + return + } + idsKey, tmpKeys, err := generateIdsSet(q.query, q.tx) + if err != nil { + q.tx.setError(err) + return + } + sortArgs := q.collection.spec.sortArgs(idsKey, q.redisFieldNames(), 1, q.offset, q.order.kind == descendingOrder) + q.tx.Command("SORT", sortArgs, newScanOneModelHandler(q.query, q.collection.spec, append(q.fieldNames(), "-"), model)) + if len(tmpKeys) > 0 { + q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) + } +} + +// Count will count the number of models that match the query criteria and set +// the value of count. It works very similarly to Query.Count, so you can check +// the documentation for Query.Count for more information. The first error +// encountered will be saved to the corresponding Transaction (if there is not +// already an error for the Transaction) and returned when you call +// Transaction.Exec. +func (q *TransactionQuery) Count(count *int) { + if q.hasError() { + q.tx.setError(q.err) + return + } + if !q.hasFilters() { + // Start by getting the number of models in the all index set + q.tx.Command("SCARD", redis.Args{q.collection.spec.indexKey()}, func(reply interface{}) error { + gotCount, err := redis.Int(reply, nil) + if err != nil { + return err + } + // Apply math to take into account limit and offset + if q.hasOffset() { + gotCount = gotCount - int(q.offset) + } + if q.hasLimit() && int(q.limit) < gotCount { + gotCount = int(q.limit) + } + // Assign the value of count + (*count) = gotCount + return nil + }) + } else { + // If the query has filters, it is difficult to do any optimizations. + // Instead we'll just count the number of ids that match the query + // criteria. To do in a single transaction, we use the StoreIds method and + // then add a LLEN command. + destKey := generateRandomKey("tmp:countDestKey") + q.StoreIds(destKey) + q.tx.Command("LLEN", redis.Args{destKey}, NewScanIntHandler(count)) + // Delete the temporary destKey when we're done. + q.tx.Command("DEL", redis.Args{destKey}, nil) + } +} + +// Ids will find the ids for models matching the query criteria and set the +// value of ids. It works very similarly to Query.Ids, so you can check the +// documentation for Query.Ids for more information. The first error encountered +// will be saved to the corresponding Transaction (if there is not already an +// error for the Transaction) and returned when you call Transaction.Exec. +func (q *TransactionQuery) Ids(ids *[]string) { + if q.hasError() { + q.tx.setError(q.err) + return + } + idsKey, tmpKeys, err := generateIdsSet(q.query, q.tx) + if err != nil { + q.tx.setError(err) + } + limit := int(q.limit) + if limit == 0 { + // In our query syntax, a limit of 0 means unlimited + // But in redis, -1 means unlimited + limit = -1 + } + sortArgs := q.collection.spec.sortArgs(idsKey, nil, limit, q.offset, q.order.kind == descendingOrder) + q.tx.Command("SORT", sortArgs, NewScanStringsHandler(ids)) + if len(tmpKeys) > 0 { + q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) + } +} + +// StoreIds will store the ids for for models matching the criteria in a list +// identified by destKey. It works very similarly to Query.StoreIds, so you can +// check the documentation for Query.StoreIds for more information. The first +// error encountered will be saved to the corresponding Transaction (if there is +// not already an error for the Transaction) and returned when you call +// Transaction.Exec. +func (q *TransactionQuery) StoreIds(destKey string) { + if q.hasError() { + q.tx.setError(q.err) + return + } + idsKey, tmpKeys, err := generateIdsSet(q.query, q.tx) + if err != nil { + q.tx.setError(err) + } + limit := int(q.limit) + if limit == 0 { + // In our query syntax, a limit of 0 means unlimited + // But in Redis, -1 means unlimited + limit = -1 + } + sortArgs := q.collection.spec.sortArgs(idsKey, nil, limit, q.offset, q.order.kind == descendingOrder) + // Append the STORE argument to cause Redis to store the results in destKey. + sortAndStoreArgs := append(sortArgs, "STORE", destKey) + q.tx.Command("SORT", sortAndStoreArgs, nil) + if len(tmpKeys) > 0 { + q.tx.Command("DEL", (redis.Args{}).Add(tmpKeys...), nil) + } +} diff --git a/transaction_query_test.go b/transaction_query_test.go new file mode 100644 index 0000000..d485149 --- /dev/null +++ b/transaction_query_test.go @@ -0,0 +1,71 @@ +package zoom + +import ( + "testing" +) + +func TestTransactionQueries(t *testing.T) { + testingSetUp() + defer testingTearDown() + + // Create some test models + models, err := createAndSaveIndexedTestModels(10) + if err != nil { + t.Fatal(err) + } + + // Create a transaction and add some queries to it + tx := testPool.NewTransaction() + queries := []*TransactionQuery{ + tx.Query(indexedTestModels), + tx.Query(indexedTestModels).Filter("Int >", 3).Order("-String").Limit(3), + // Note: Offset(11) means no models should be returned. + tx.Query(indexedTestModels).Offset(11), + } + + // Calculate the expected models and got models for each query + gotModels := make([][]*indexedTestModel, len(queries)) + expectedModels := make([][]*indexedTestModel, len(queries)) + for i, query := range queries { + expectedModels[i] = expectedResultsForQuery(query.query, models) + modelsHolder := []*indexedTestModel{} + query.Run(&modelsHolder) + expectedModels[i] = modelsHolder + } + + // Execute the transaction and check the results + if err := tx.Exec(); err != nil { + t.Fatal(err) + } + for i, query := range queries { + if err := expectModelsToBeEqual(expectedModels[i], gotModels[i], true); err != nil { + t.Errorf("Query %d failed: %s\n%v", i, query, err) + } + checkForLeakedTmpKeys(t, query.query) + } +} + +func TestTransactionQueriesError(t *testing.T) { + testingSetUp() + defer testingTearDown() + + // Create some test models + if _, err := createAndSaveIndexedTestModels(10); err != nil { + t.Fatal(err) + } + + // Create a transaction and add a RunOne query to it. We expect this query to + // fail because we used Offset(11) and there are only 10 models. + tx := testPool.NewTransaction() + gotModel := indexedTestModel{} + query := tx.Query(indexedTestModels).Offset(11) + query.RunOne(&gotModel) + + // Execute the transaction and check the results + if err := tx.Exec(); err == nil { + t.Error("Expected an error but got none") + } else if _, ok := err.(ModelNotFoundError); !ok { + t.Errorf("Expected a ModelNotFoundError but got: %T: %v", err, err) + } + checkForLeakedTmpKeys(t, query.query) +}