Merge pull request #9 from sunboyy/comparator-between

Add comparator BETWEEN and fix multi-names parameter code extractor
This commit is contained in:
sunboyy 2021-01-26 18:51:26 +07:00 committed by GitHub
commit c5c4542a80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 159 additions and 15 deletions

View file

@ -79,14 +79,19 @@ func ExtractComponents(f *ast.File) File {
} }
for _, param := range funcType.Params.List { for _, param := range funcType.Params.List {
var p Param paramType := getType(param.Type)
for _, name := range param.Names {
p.Name = name.Name
break
}
p.Type = getType(param.Type)
meth.Params = append(meth.Params, p) if len(param.Names) == 0 {
meth.Params = append(meth.Params, Param{Type: paramType})
continue
}
for _, name := range param.Names {
meth.Params = append(meth.Params, Param{
Name: name.Name,
Type: paramType,
})
}
} }
for _, result := range funcType.Results.List { for _, result := range funcType.Results.List {

View file

@ -96,6 +96,7 @@ type UserModel struct {
type UserRepository interface { type UserRepository interface {
FindOneByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error) FindOneByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error)
FindAll(context.Context) ([]*UserModel, error) FindAll(context.Context) ([]*UserModel, error)
FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error)
}`, }`,
ExpectedOutput: code.File{ ExpectedOutput: code.File{
PackageName: "user", PackageName: "user",
@ -124,6 +125,18 @@ type UserRepository interface {
code.SimpleType("error"), code.SimpleType("error"),
}, },
}, },
{
Name: "FindByAgeBetween",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.SimpleType("int")},
{Name: "toAge", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
}, },
}, },
}, },

View file

@ -133,6 +133,26 @@ func TestGenerateMongoRepository(t *testing.T) {
}, },
}, },
}, },
{
Name: "FindByAgeBetween",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.SimpleType("int")},
{Name: "toAge", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorBetween},
},
},
},
},
{ {
Name: "FindByGenderOrAge", Name: "FindByGenderOrAge",
Params: []code.Param{ Params: []code.Param{
@ -250,6 +270,20 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg
return entities, nil return entities, nil
} }
func (r *UserRepositoryMongo) FindByAgeBetween(ctx context.Context, arg0 int, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$gte": arg0, "$lte": arg1},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{ cursor, err := r.collection.Find(ctx, bson.M{
"$or": []bson.M{ "$or": []bson.M{

View file

@ -411,6 +411,44 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg
} }
return entities, nil return entities, nil
} }
`,
},
{
Name: "find with Between comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeBetween",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.SimpleType("int")},
{Name: "toAge", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorBetween, Field: "Age"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByAgeBetween(ctx context.Context, arg0 int, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$gte": arg0, "$lte": arg1},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
`, `,
}, },
{ {

View file

@ -14,8 +14,10 @@ type querySpec struct {
func (q querySpec) Code() string { func (q querySpec) Code() string {
var predicateCodes []string var predicateCodes []string
for i, predicate := range q.Predicates { var argIndex int
predicateCodes = append(predicateCodes, predicate.Code(i)) for _, predicate := range q.Predicates {
predicateCodes = append(predicateCodes, predicate.Code(argIndex))
argIndex += predicate.Comparator.NumberOfArguments()
} }
var lines []string var lines []string
@ -53,6 +55,8 @@ func (p predicate) Code(argIndex int) string {
return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex) return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex)
case spec.ComparatorGreaterThanEqual: case spec.ComparatorGreaterThanEqual:
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex) return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex)
case spec.ComparatorBetween:
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, argIndex, argIndex+1)
case spec.ComparatorIn: case spec.ComparatorIn:
return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, argIndex) return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, argIndex)
} }

View file

@ -41,7 +41,11 @@ type QuerySpec struct {
// NumberOfArguments returns number of arguments required to perform the query // NumberOfArguments returns number of arguments required to perform the query
func (q QuerySpec) NumberOfArguments() int { func (q QuerySpec) NumberOfArguments() int {
return len(q.Predicates) var totalArgs int
for _, predicate := range q.Predicates {
totalArgs += predicate.Comparator.NumberOfArguments()
}
return totalArgs
} }
// Operator is a boolean operator for merging conditions // Operator is a boolean operator for merging conditions
@ -64,6 +68,7 @@ const (
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL" ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN" ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL" ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN" ComparatorIn Comparator = "IN"
) )
@ -75,6 +80,14 @@ func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
return t return t
} }
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
return 2
}
return 1
}
// Predicate is a criteria for querying a field // Predicate is a criteria for querying a field
type Predicate struct { type Predicate struct {
Field string Field string
@ -102,5 +115,8 @@ func (t predicateToken) ToPredicate() Predicate {
if len(t) > 1 && t[len(t)-1] == "In" { if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn} return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn}
} }
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual} return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual}
} }

View file

@ -157,13 +157,14 @@ func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) erro
return StructFieldNotFoundError return StructFieldNotFoundError
} }
for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ {
if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType(
structField.Type) { structField.Type) {
return InvalidParamError return InvalidParamError
} }
currentParamIndex++ currentParamIndex++
} }
}
return nil return nil
} }

View file

@ -390,6 +390,39 @@ func TestParseInterfaceMethod(t *testing.T) {
}, },
}, },
}, },
{
Name: "FindByArgBetween method",
Method: code.Method{
Name: "FindByAgeBetween",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedOutput: spec.MethodSpec{
Name: "FindByAgeBetween",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorBetween},
}},
},
},
},
{ {
Name: "FindByArgIn method", Name: "FindByArgIn method",
Method: code.Method{ Method: code.Method{