diff --git a/internal/mongo/common.go b/internal/mongo/common.go index bc8992a..b47d5a8 100644 --- a/internal/mongo/common.go +++ b/internal/mongo/common.go @@ -8,6 +8,8 @@ import ( "github.com/sunboyy/repogen/internal/spec" ) +var errOccurred = codegen.RawStatement("err != nil") + var returnNilErr = codegen.ReturnStatement{ codegen.Identifier("nil"), codegen.Identifier("err"), @@ -15,7 +17,7 @@ var returnNilErr = codegen.ReturnStatement{ var ifErrReturnNilErr = codegen.IfBlock{ Condition: []codegen.Statement{ - codegen.RawStatement("err != nil"), + errOccurred, }, Statements: []codegen.Statement{ returnNilErr, @@ -24,7 +26,7 @@ var ifErrReturnNilErr = codegen.IfBlock{ var ifErrReturn0Err = codegen.IfBlock{ Condition: []codegen.Statement{ - codegen.RawStatement("err != nil"), + errOccurred, }, Statements: []codegen.Statement{ codegen.ReturnStatement{ @@ -36,7 +38,7 @@ var ifErrReturn0Err = codegen.IfBlock{ var ifErrReturnFalseErr = codegen.IfBlock{ Condition: []codegen.Statement{ - codegen.RawStatement("err != nil"), + errOccurred, }, Statements: []codegen.Statement{ codegen.ReturnStatement{ diff --git a/internal/mongo/find.go b/internal/mongo/find.go index ed6f155..7df4f56 100644 --- a/internal/mongo/find.go +++ b/internal/mongo/find.go @@ -1,6 +1,8 @@ package mongo import ( + "strconv" + "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/spec" @@ -66,7 +68,7 @@ func (g findBodyGenerator) generateFindOneBody(querySpec querySpec, ).Build(), }, }, - codegen.RawStatement("err != nil"), + errOccurred, }, Statements: []codegen.Statement{ returnNilErr, @@ -91,10 +93,7 @@ func (g findBodyGenerator) generateFindManyBody(querySpec querySpec, Call("Find", codegen.Identifier("arg0"), querySpec.Code(), - codegen.NewChainBuilder("options"). - Call("Find"). - Call("SetSort", sortsCode). - Build(), + g.findManyOptions(sortsCode), ).Build(), }, }, @@ -119,7 +118,7 @@ func (g findBodyGenerator) generateFindManyBody(querySpec querySpec, ).Build(), }, }, - codegen.RawStatement("err != nil"), + errOccurred, }, Statements: []codegen.Statement{ returnNilErr, @@ -132,6 +131,21 @@ func (g findBodyGenerator) generateFindManyBody(querySpec querySpec, } } +func (g findBodyGenerator) findManyOptions( + sortsCode codegen.MapStatement) codegen.Statement { + + optionsBuilder := codegen.NewChainBuilder("options"). + Call("Find"). + Call("SetSort", sortsCode) + if g.operation.Limit > 0 { + optionsBuilder = optionsBuilder.Call("SetLimit", + codegen.Identifier(strconv.Itoa(g.operation.Limit)), + ) + } + + return optionsBuilder.Build() +} + func (g findBodyGenerator) generateSortMap() ( codegen.MapStatement, error) { diff --git a/internal/mongo/find_test.go b/internal/mongo/find_test.go index de1299d..d02e657 100644 --- a/internal/mongo/find_test.go +++ b/internal/mongo/find_test.go @@ -825,6 +825,38 @@ func TestGenerateMethod_Find(t *testing.T) { if err := cursor.All(arg0, &entities); err != nil { return nil, err } + return entities, nil`, + }, + { + Name: "find with limit", + MethodSpec: spec.MethodSpec{ + Name: "FindTop5AllOrderByAgeDesc", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + }, + Limit: 5, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + }, options.Find().SetSort(bson.M{ + "age": -1, + }).SetLimit(5)) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } return entities, nil`, }, } diff --git a/internal/spec/errors.go b/internal/spec/errors.go index 11189b9..37fb336 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -14,6 +14,9 @@ var ( ErrInvalidParam = errors.New("spec: parameters do not match the query") ErrInvalidUpdateFields = errors.New("spec: update fields are invalid") ErrContextParamRequired = errors.New("spec: context parameter is required") + ErrLimitAmountRequired = errors.New("spec: limit amount is required") + ErrLimitNonPositive = errors.New("spec: limit value must be positive") + ErrLimitOnFindOne = errors.New("spec: cannot specify limit on find one") ) // NewUnsupportedReturnError creates unsupportedReturnError diff --git a/internal/spec/models.go b/internal/spec/models.go index 7bff286..5c856cf 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -41,6 +41,7 @@ type FindOperation struct { Mode QueryMode Query QuerySpec Sorts []Sort + Limit int } // Name returns "Find" operation name diff --git a/internal/spec/parser.go b/internal/spec/parser.go index 1bfa9eb..aa5d9c6 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -1,6 +1,8 @@ package spec import ( + "strconv" + "github.com/fatih/camelcase" "github.com/sunboyy/repogen/internal/code" ) @@ -114,6 +116,14 @@ func (p interfaceMethodParser) parseFindOperation(tokens []string) (Operation, e return nil, err } + limit, tokens, err := p.parseFindTop(tokens) + if err != nil { + return nil, err + } + if mode == QueryModeOne && limit != 0 { + return nil, ErrLimitOnFindOne + } + queryTokens, sortTokens := p.splitQueryAndSortTokens(tokens) querySpec, err := p.parseQuery(queryTokens, 1) @@ -138,9 +148,32 @@ func (p interfaceMethodParser) parseFindOperation(tokens []string) (Operation, e Mode: mode, Query: querySpec, Sorts: sorts, + Limit: limit, }, nil } +func (p interfaceMethodParser) parseFindTop(tokens []string) (int, []string, + error) { + + if len(tokens) >= 1 && tokens[0] == "Top" { + if len(tokens) < 2 { + return 0, nil, ErrLimitAmountRequired + } + + limit, err := strconv.Atoi(tokens[1]) + if err != nil { + return 0, nil, ErrLimitAmountRequired + } + + if limit <= 0 { + return 0, nil, ErrLimitNonPositive + } + return limit, tokens[2:], nil + } + + return 0, tokens, nil +} + func (p interfaceMethodParser) parseSort(rawTokens []string) ([]Sort, error) { if len(rawTokens) == 0 { return nil, nil diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 7e4e785..916e1d2 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -745,6 +745,30 @@ func TestParseInterfaceMethod_Find(t *testing.T) { }, }, }, + { + Name: "FindTopNByArg method", + Method: code.Method{ + Name: "FindTop5ByGenderOrderByAgeDesc", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{genderField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, + Sorts: []spec.Sort{ + {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + }, + Limit: 5, + }, + }, } for _, testCase := range testTable { @@ -1603,6 +1627,50 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { }, ExpectedError: spec.NewUnsupportedReturnError(code.TypeInt, 1), }, + { + Name: "find method with Top keyword but no number and query", + Method: code.Method{ + Name: "FindTop", + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + }, + ExpectedError: spec.ErrLimitAmountRequired, + }, + { + Name: "find method with Top keyword but no number", + Method: code.Method{ + Name: "FindTopAll", + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + }, + ExpectedError: spec.ErrLimitAmountRequired, + }, + { + Name: "find method with TopN keyword where N is not positive", + Method: code.Method{ + Name: "FindTop0All", + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + }, + ExpectedError: spec.ErrLimitNonPositive, + }, + { + Name: "find one method with TopN keyword", + Method: code.Method{ + Name: "FindTop5All", + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.TypeError, + }, + }, + ExpectedError: spec.ErrLimitOnFindOne, + }, { Name: "find method without query", Method: code.Method{