From 962779ce401dcaa0def5d32e6206291efe993d0d Mon Sep 17 00:00:00 2001 From: sunboyy Date: Fri, 26 Feb 2021 10:27:00 +0700 Subject: [PATCH] Add functionality to sort in find operation --- README.md | 2 +- internal/generator/generator_test.go | 42 +++++-- internal/mongo/generator.go | 24 ++++ internal/mongo/generator_test.go | 173 ++++++++++++++++++++++++--- internal/mongo/templates.go | 23 +++- internal/spec/errors.go | 13 ++ internal/spec/errors_test.go | 5 + internal/spec/models.go | 16 +++ internal/spec/parser.go | 67 ++++++++++- internal/spec/parser_test.go | 126 +++++++++++++++++++ 10 files changed, 460 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 775f854..ea6d240 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ FindByCity(ctx context.Context, city string) ([]*Model, error) FindAll(ctx context.Context) ([]*Model, error) ``` -Repogen determines a single-entity operation or a multiple-entity by checking the first return value. If it is a pointer of a model, the method will be single-entity operation. If it is a slice of pointers of a model, the method will be multiple-entity operation. +Repogen determines a single-entity or a multiple-entity operation by checking the first return value. If it is a pointer of a model, the method will be single-entity operation. If it is a slice of pointers of a model, the method will be multiple-entity operation. The requirement of the `Find` operation method is that there must be only two return values, the second return value must be of type `error` and the first method parameter must be of type `context.Context`. The requirement of number of method parameters depends on the query which will be described in the query specification section. diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index ea6f2b7..a69a0c7 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -77,7 +77,7 @@ func TestGenerateMongoRepository(t *testing.T) { }, }, { - Name: "FindByAgeLessThanEqual", + Name: "FindByAgeLessThanEqualOrderByAge", Params: []code.Param{ {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "age", Type: code.SimpleType("int")}, @@ -93,10 +93,13 @@ func TestGenerateMongoRepository(t *testing.T) { {Field: "Age", Comparator: spec.ComparatorLessThanEqual, ParamIndex: 1}, }, }, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingAscending}, + }, }, }, { - Name: "FindByAgeGreaterThan", + Name: "FindByAgeGreaterThanOrderByAgeAsc", Params: []code.Param{ {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "age", Type: code.SimpleType("int")}, @@ -112,10 +115,13 @@ func TestGenerateMongoRepository(t *testing.T) { {Field: "Age", Comparator: spec.ComparatorGreaterThan, ParamIndex: 1}, }, }, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingAscending}, + }, }, }, { - Name: "FindByAgeGreaterThanEqual", + Name: "FindByAgeGreaterThanEqualOrderByAgeDesc", Params: []code.Param{ {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "age", Type: code.SimpleType("int")}, @@ -131,6 +137,9 @@ func TestGenerateMongoRepository(t *testing.T) { {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual, ParamIndex: 1}, }, }, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingDescending}, + }, }, }, { @@ -191,6 +200,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) func NewUserRepository(collection *mongo.Collection) UserRepository { @@ -207,7 +217,7 @@ func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.Obje var entity UserModel if err := r.collection.FindOne(arg0, bson.M{ "_id": arg1, - }).Decode(&entity); err != nil { + }, options.FindOne().SetSort(bson.M{})).Decode(&entity); err != nil { return nil, err } return &entity, nil @@ -219,7 +229,7 @@ func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context {"gender": bson.M{"$ne": arg1}}, {"age": bson.M{"$lt": arg2}}, }, - }) + }, options.Find().SetSort(bson.M{})) if err != nil { return nil, err } @@ -230,10 +240,12 @@ func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context return entities, nil } -func (r *UserRepositoryMongo) FindByAgeLessThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$lte": arg1}, - }) + }, options.Find().SetSort(bson.M{ + "age": 1, + })) if err != nil { return nil, err } @@ -244,10 +256,12 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqual(arg0 context.Context, arg1 return entities, nil } -func (r *UserRepositoryMongo) FindByAgeGreaterThan(arg0 context.Context, arg1 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$gt": arg1}, - }) + }, options.Find().SetSort(bson.M{ + "age": 1, + })) if err != nil { return nil, err } @@ -258,10 +272,12 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThan(arg0 context.Context, arg1 in return entities, nil } -func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$gte": arg1}, - }) + }, options.Find().SetSort(bson.M{ + "age": -1, + })) if err != nil { return nil, err } @@ -275,7 +291,7 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(arg0 context.Context, ar func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$gte": arg1, "$lte": arg2}, - }) + }, options.Find().SetSort(bson.M{})) if err != nil { return nil, err } @@ -292,7 +308,7 @@ func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gende {"gender": arg1}, {"age": arg2}, }, - }) + }, options.Find().SetSort(bson.M{})) if err != nil { return nil, err } diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index a70b871..13343e2 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -105,9 +105,15 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera return "", err } + sorts, err := g.mongoSorts(operation.Sorts) + if err != nil { + return "", err + } + tmplData := mongoFindTemplateData{ EntityType: g.StructModel.Name, QuerySpec: querySpec, + Sorts: sorts, } if operation.Mode == spec.QueryModeOne { @@ -116,6 +122,24 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData) } +func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]sort, error) { + var sorts []sort + + for _, s := range sortSpec { + bsonTag, err := g.bsonTagFromFieldName(s.FieldName) + if err != nil { + return nil, err + } + + sorts = append(sorts, sort{ + BsonTag: bsonTag, + Ordering: s.Ordering, + }) + } + + return sorts, nil +} + func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) { update, err := g.getMongoUpdate(operation.Update) if err != nil { diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index fe5ae05..4c8f35b 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -52,6 +52,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) func NewUserRepository(collection *mongo.Collection) UserRepository { @@ -192,7 +193,8 @@ func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.Obje var entity UserModel if err := r.collection.FindOne(arg0, bson.M{ "_id": arg1, - }).Decode(&entity); err != nil { + }, options.FindOne().SetSort(bson.M{ + })).Decode(&entity); err != nil { return nil, err } return &entity, nil @@ -224,7 +226,8 @@ func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.Obje func (r *UserRepositoryMongo) FindByGender(arg0 context.Context, arg1 Gender) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "gender": arg1, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -267,7 +270,8 @@ func (r *UserRepositoryMongo) FindByGenderAndAge(arg0 context.Context, arg1 Gend {"gender": arg1}, {"age": arg2}, }, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -310,7 +314,8 @@ func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gende {"gender": arg1}, {"age": arg2}, }, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -347,7 +352,8 @@ func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gende func (r *UserRepositoryMongo) FindByGenderNot(arg0 context.Context, arg1 Gender) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "gender": bson.M{"$ne": arg1}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -384,7 +390,8 @@ func (r *UserRepositoryMongo) FindByGenderNot(arg0 context.Context, arg1 Gender) func (r *UserRepositoryMongo) FindByAgeLessThan(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$lt": arg1}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -421,7 +428,8 @@ func (r *UserRepositoryMongo) FindByAgeLessThan(arg0 context.Context, arg1 int) func (r *UserRepositoryMongo) FindByAgeLessThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$lte": arg1}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -458,7 +466,8 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqual(arg0 context.Context, arg1 func (r *UserRepositoryMongo) FindByAgeGreaterThan(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$gt": arg1}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -495,7 +504,8 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThan(arg0 context.Context, arg1 in func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$gte": arg1}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -533,7 +543,8 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(arg0 context.Context, ar func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{"$gte": arg1, "$lte": arg2}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -570,7 +581,8 @@ func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, a func (r *UserRepositoryMongo) FindByGenderIn(arg0 context.Context, arg1 []Gender) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "gender": bson.M{"$in": arg1}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -607,7 +619,8 @@ func (r *UserRepositoryMongo) FindByGenderIn(arg0 context.Context, arg1 []Gender func (r *UserRepositoryMongo) FindByGenderNotIn(arg0 context.Context, arg1 []Gender) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "gender": bson.M{"$nin": arg1}, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -643,7 +656,8 @@ func (r *UserRepositoryMongo) FindByGenderNotIn(arg0 context.Context, arg1 []Gen func (r *UserRepositoryMongo) FindByEnabledTrue(arg0 context.Context) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "enabled": true, - }) + }, options.Find().SetSort(bson.M{ + })) if err != nil { return nil, err } @@ -679,7 +693,118 @@ func (r *UserRepositoryMongo) FindByEnabledTrue(arg0 context.Context) ([]*UserMo func (r *UserRepositoryMongo) FindByEnabledFalse(arg0 context.Context) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "enabled": false, - }) + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil +} +`, + }, + { + Name: "find with sort ascending", + MethodSpec: spec.MethodSpec{ + Name: "FindAllOrderByAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingAscending}, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindAllOrderByAge(arg0 context.Context) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + + }, options.Find().SetSort(bson.M{ + "age": 1, + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil +} +`, + }, + { + Name: "find with sort descending", + MethodSpec: spec.MethodSpec{ + Name: "FindAllOrderByAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingDescending}, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindAllOrderByAge(arg0 context.Context) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + + }, options.Find().SetSort(bson.M{ + "age": -1, + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil +} +`, + }, + { + Name: "find with multiple sorts", + MethodSpec: spec.MethodSpec{ + Name: "FindAllOrderByGenderAndAgeDesc", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + {FieldName: "Gender", Ordering: spec.OrderingAscending}, + {FieldName: "Age", Ordering: spec.OrderingDescending}, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindAllOrderByGenderAndAgeDesc(arg0 context.Context) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + + }, options.Find().SetSort(bson.M{ + "gender": 1, + "age": -1, + })) if err != nil { return nil, err } @@ -1652,6 +1777,26 @@ func TestGenerateMethod_Invalid(t *testing.T) { }, ExpectedError: mongo.NewBsonTagNotFoundError("AccessToken"), }, + { + Name: "bson tag not found in sort", + Method: spec.MethodSpec{ + Name: "FindAllOrderByAccessToken", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeOne, + Sorts: []spec.Sort{ + {FieldName: "AccessToken", Ordering: spec.OrderingAscending}, + }, + }, + }, + ExpectedError: mongo.NewBsonTagNotFoundError("AccessToken"), + }, { Name: "bson tag not found in update field", Method: spec.MethodSpec{ diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go index a9acff0..a00feb8 100644 --- a/internal/mongo/templates.go +++ b/internal/mongo/templates.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/spec" ) const constructorTemplate = ` @@ -14,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} { @@ -89,19 +91,36 @@ const insertManyTemplate = ` var entities []interface{} type mongoFindTemplateData struct { EntityType string QuerySpec querySpec + Sorts []sort +} + +type sort struct { + BsonTag string + Ordering spec.Ordering +} + +func (s sort) OrderNum() int { + if s.Ordering == spec.OrderingAscending { + return 1 + } + return -1 } const findOneTemplate = ` var entity {{.EntityType}} if err := r.collection.FindOne(arg0, bson.M{ {{.QuerySpec.Code}} - }).Decode(&entity); err != nil { + }, options.FindOne().SetSort(bson.M{ +{{range $index, $element := .Sorts}} "{{$element.BsonTag}}": {{$element.OrderNum}}, +{{end}} })).Decode(&entity); err != nil { return nil, err } return &entity, nil` const findManyTemplate = ` cursor, err := r.collection.Find(arg0, bson.M{ {{.QuerySpec.Code}} - }) + }, options.Find().SetSort(bson.M{ +{{range $index, $element := .Sorts}} "{{$element.BsonTag}}": {{$element.OrderNum}}, +{{end}} })) if err != nil { return nil, err } diff --git a/internal/spec/errors.go b/internal/spec/errors.go index a25f07a..c7100f0 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -48,6 +48,19 @@ func (err invalidQueryError) Error() string { return fmt.Sprintf("invalid query '%s'", err.QueryString) } +// NewInvalidSortError creates invalidSortError +func NewInvalidSortError(sortTokens []string) error { + return invalidSortError{SortString: strings.Join(sortTokens, "")} +} + +type invalidSortError struct { + SortString string +} + +func (err invalidSortError) Error() string { + return fmt.Sprintf("invalid sort '%s'", err.SortString) +} + // NewUnknownOperationError creates unknownOperationError func NewUnknownOperationError(operationName string) error { return unknownOperationError{OperationName: operationName} diff --git a/internal/spec/errors_test.go b/internal/spec/errors_test.go index 28f4b6c..8b9fdca 100644 --- a/internal/spec/errors_test.go +++ b/internal/spec/errors_test.go @@ -38,6 +38,11 @@ func TestError(t *testing.T) { }), ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'", }, + { + Name: "InvalidSortError", + Error: spec.NewInvalidSortError([]string{"Order", "By"}), + ExpectedString: "invalid sort 'OrderBy'", + }, } for _, testCase := range testTable { diff --git a/internal/spec/models.go b/internal/spec/models.go index f8e1720..cac7e4c 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -40,6 +40,7 @@ func (o InsertOperation) Name() string { type FindOperation struct { Mode QueryMode Query QuerySpec + Sorts []Sort } // Name returns "Find" operation name @@ -47,6 +48,21 @@ func (o FindOperation) Name() string { return "Find" } +// Sort is a detail of sorting find result +type Sort struct { + FieldName string + Ordering Ordering +} + +// Ordering is a sort order +type Ordering string + +// Ordering constants +const ( + OrderingAscending = "ASC" + OrderingDescending = "DESC" +) + // UpdateOperation is a method specification for update operations type UpdateOperation struct { Update Update diff --git a/internal/spec/parser.go b/internal/spec/parser.go index f7f57a8..105b26f 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -1,6 +1,8 @@ package spec import ( + "strings" + "github.com/fatih/camelcase" "github.com/sunboyy/repogen/internal/code" ) @@ -111,7 +113,14 @@ func (p interfaceMethodParser) parseFindOperation(tokens []string) (Operation, e return nil, err } - querySpec, err := parseQuery(tokens, 1) + queryTokens, sortTokens := p.splitQueryAndSortTokens(tokens) + + querySpec, err := parseQuery(queryTokens, 1) + if err != nil { + return nil, err + } + + sorts, err := p.parseSort(sortTokens) if err != nil { return nil, err } @@ -127,9 +136,65 @@ func (p interfaceMethodParser) parseFindOperation(tokens []string) (Operation, e return FindOperation{ Mode: mode, Query: querySpec, + Sorts: sorts, }, nil } +func (p interfaceMethodParser) parseSort(rawTokens []string) ([]Sort, error) { + if len(rawTokens) == 0 { + return nil, nil + } + + sortTokens := rawTokens[2:] + + var sorts []Sort + var aggregatedToken sortToken + for _, token := range sortTokens { + if token != "And" { + aggregatedToken = append(aggregatedToken, token) + } else if len(aggregatedToken) == 0 { + return nil, NewInvalidSortError(rawTokens) + } else { + sorts = append(sorts, aggregatedToken.ToSort()) + aggregatedToken = sortToken{} + } + } + if len(aggregatedToken) == 0 { + return nil, NewInvalidSortError(rawTokens) + } + sorts = append(sorts, aggregatedToken.ToSort()) + + return sorts, nil +} + +type sortToken []string + +func (t sortToken) ToSort() Sort { + if len(t) > 1 && t[len(t)-1] == "Asc" { + return Sort{FieldName: strings.Join(t[:len(t)-1], ""), Ordering: OrderingAscending} + } + if len(t) > 1 && t[len(t)-1] == "Desc" { + return Sort{FieldName: strings.Join(t[:len(t)-1], ""), Ordering: OrderingDescending} + } + return Sort{FieldName: strings.Join(t, ""), Ordering: OrderingAscending} +} + +func (p interfaceMethodParser) splitQueryAndSortTokens(tokens []string) ([]string, []string) { + var queryTokens []string + var sortTokens []string + + for i, token := range tokens { + if len(tokens) > i && token == "Order" && tokens[i+1] == "By" { + sortTokens = tokens[i:] + break + } else { + queryTokens = append(queryTokens, token) + } + } + + return queryTokens, sortTokens +} + func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (QueryMode, error) { if len(returns) != 2 { return "", UnsupportedReturnError diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 48d53fe..7a840bb 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -429,6 +429,99 @@ func TestParseInterfaceMethod_Find(t *testing.T) { }}, }, }, + { + Name: "FindByArgOrderByArg method", + Method: code.Method{ + Name: "FindByCityOrderByAge", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingAscending}, + }, + }, + }, + { + Name: "FindByArgOrderByArgAsc method", + Method: code.Method{ + Name: "FindByCityOrderByAgeAsc", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingAscending}, + }, + }, + }, + { + Name: "FindByArgOrderByArgDesc method", + Method: code.Method{ + Name: "FindByCityOrderByAgeDesc", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, + Sorts: []spec.Sort{ + {FieldName: "Age", Ordering: spec.OrderingDescending}, + }, + }, + }, + { + Name: "FindByArgOrderByArgAndArg method", + Method: code.Method{ + Name: "FindByCityOrderByCityAndAgeDesc", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, + Sorts: []spec.Sort{ + {FieldName: "City", Ordering: spec.OrderingAscending}, + {FieldName: "Age", Ordering: spec.OrderingDescending}, + }, + }, + }, } for _, testCase := range testTable { @@ -1248,6 +1341,39 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { }, ExpectedError: spec.InvalidParamError, }, + { + Name: "misplaced operator token (leftmost)", + Method: code.Method{ + Name: "FindAllOrderByAndAge", + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.NewInvalidSortError([]string{"Order", "By", "And", "Age"}), + }, + { + Name: "misplaced operator token (rightmost)", + Method: code.Method{ + Name: "FindAllOrderByAgeAnd", + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.NewInvalidSortError([]string{"Order", "By", "Age", "And"}), + }, + { + Name: "misplaced operator token (double operator)", + Method: code.Method{ + Name: "FindAllOrderByAgeAndAndGender", + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.NewInvalidSortError([]string{"Order", "By", "Age", "And", "And", "Gender"}), + }, } for _, testCase := range testTable {