diff --git a/internal/code/extractor.go b/internal/code/extractor.go index 313aa8c..93320bc 100644 --- a/internal/code/extractor.go +++ b/internal/code/extractor.go @@ -79,14 +79,19 @@ func ExtractComponents(f *ast.File) File { } for _, param := range funcType.Params.List { - var p Param - for _, name := range param.Names { - p.Name = name.Name - break - } - p.Type = getType(param.Type) + paramType := getType(param.Type) - meth.Params = append(meth.Params, p) + if len(param.Names) == 0 { + meth.Params = append(meth.Params, Param{Type: paramType}) + continue + } + + for _, name := range param.Names { + meth.Params = append(meth.Params, Param{ + Name: name.Name, + Type: paramType, + }) + } } for _, result := range funcType.Results.List { diff --git a/internal/code/extractor_test.go b/internal/code/extractor_test.go index d77af3d..0cea722 100644 --- a/internal/code/extractor_test.go +++ b/internal/code/extractor_test.go @@ -96,6 +96,7 @@ type UserModel struct { type UserRepository interface { FindOneByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error) FindAll(context.Context) ([]*UserModel, error) + FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error) }`, ExpectedOutput: code.File{ PackageName: "user", @@ -124,6 +125,18 @@ type UserRepository interface { code.SimpleType("error"), }, }, + { + Name: "FindByAgeBetween", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "fromAge", Type: code.SimpleType("int")}, + {Name: "toAge", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, }, }, }, diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 00be3d7..4ce0880 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -133,6 +133,26 @@ func TestGenerateMongoRepository(t *testing.T) { }, }, }, + { + Name: "FindByAgeBetween", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "fromAge", Type: code.SimpleType("int")}, + {Name: "toAge", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorBetween}, + }, + }, + }, + }, { Name: "FindByGenderOrAge", Params: []code.Param{ @@ -250,6 +270,20 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg return entities, nil } +func (r *UserRepositoryMongo) FindByAgeBetween(ctx context.Context, arg0 int, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "age": bson.M{"$gte": arg0, "$lte": arg1}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} + func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ "$or": []bson.M{ diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index a581d7e..0579e15 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -411,6 +411,44 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg } return entities, nil } +`, + }, + { + Name: "find with Between comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeBetween", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "fromAge", Type: code.SimpleType("int")}, + {Name: "toAge", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorBetween, Field: "Age"}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindByAgeBetween(ctx context.Context, arg0 int, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "age": bson.M{"$gte": arg0, "$lte": arg1}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} `, }, { diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 5e69e8f..92d2cfd 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -14,8 +14,10 @@ type querySpec struct { func (q querySpec) Code() string { var predicateCodes []string - for i, predicate := range q.Predicates { - predicateCodes = append(predicateCodes, predicate.Code(i)) + var argIndex int + for _, predicate := range q.Predicates { + predicateCodes = append(predicateCodes, predicate.Code(argIndex)) + argIndex += predicate.Comparator.NumberOfArguments() } var lines []string @@ -53,6 +55,8 @@ func (p predicate) Code(argIndex int) string { return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex) case spec.ComparatorGreaterThanEqual: return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex) + case spec.ComparatorBetween: + return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, argIndex, argIndex+1) case spec.ComparatorIn: return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, argIndex) } diff --git a/internal/spec/models.go b/internal/spec/models.go index 0313ef8..a993f6b 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -41,7 +41,11 @@ type QuerySpec struct { // NumberOfArguments returns number of arguments required to perform the query func (q QuerySpec) NumberOfArguments() int { - return len(q.Predicates) + var totalArgs int + for _, predicate := range q.Predicates { + totalArgs += predicate.Comparator.NumberOfArguments() + } + return totalArgs } // Operator is a boolean operator for merging conditions @@ -64,6 +68,7 @@ const ( ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL" ComparatorGreaterThan Comparator = "GREATER_THAN" ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL" + ComparatorBetween Comparator = "BETWEEN" ComparatorIn Comparator = "IN" ) @@ -75,6 +80,14 @@ func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type { return t } +// NumberOfArguments returns the number of arguments required to perform the comparison +func (c Comparator) NumberOfArguments() int { + if c == ComparatorBetween { + return 2 + } + return 1 +} + // Predicate is a criteria for querying a field type Predicate struct { Field string @@ -102,5 +115,8 @@ func (t predicateToken) ToPredicate() Predicate { if len(t) > 1 && t[len(t)-1] == "In" { return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn} } + if len(t) > 1 && t[len(t)-1] == "Between" { + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween} + } return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual} } diff --git a/internal/spec/parser.go b/internal/spec/parser.go index 7a98503..95e91de 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -157,12 +157,13 @@ func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) erro return StructFieldNotFoundError } - if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( - structField.Type) { - return InvalidParamError + for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ { + if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( + structField.Type) { + return InvalidParamError + } + currentParamIndex++ } - - currentParamIndex++ } return nil diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 9a0a869..feffacd 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -390,6 +390,39 @@ func TestParseInterfaceMethod(t *testing.T) { }, }, }, + { + Name: "FindByArgBetween method", + Method: code.Method{ + Name: "FindByAgeBetween", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOutput: spec.MethodSpec{ + Name: "FindByAgeBetween", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorBetween}, + }}, + }, + }, + }, { Name: "FindByArgIn method", Method: code.Method{