Add IN comparator
This commit is contained in:
parent
847d5add53
commit
ba77c93d64
6 changed files with 174 additions and 1 deletions
87
internal/code/models_test.go
Normal file
87
internal/code/models_test.go
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
package code_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sunboyy/repogen/internal/code"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStructsByName(t *testing.T) {
|
||||||
|
userStruct := code.Struct{
|
||||||
|
Name: "UserModel",
|
||||||
|
Fields: code.StructFields{
|
||||||
|
{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||||
|
{Name: "Username", Type: code.SimpleType("string")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
structs := code.Structs{userStruct}
|
||||||
|
|
||||||
|
t.Run("struct found", func(t *testing.T) {
|
||||||
|
structModel, ok := structs.ByName("UserModel")
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(structModel, userStruct) {
|
||||||
|
t.Errorf("Expected = %v\nReceived = %v", userStruct, structModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("struct not found", func(t *testing.T) {
|
||||||
|
_, ok := structs.ByName("ProductModel")
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStructFieldsByName(t *testing.T) {
|
||||||
|
idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}
|
||||||
|
usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")}
|
||||||
|
fields := code.StructFields{idField, usernameField}
|
||||||
|
|
||||||
|
t.Run("struct field found", func(t *testing.T) {
|
||||||
|
field, ok := fields.ByName("Username")
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(field, usernameField) {
|
||||||
|
t.Errorf("Expected = %v\nReceived = %v", usernameField, field)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("struct field not found", func(t *testing.T) {
|
||||||
|
_, ok := fields.ByName("Password")
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterfacesByName(t *testing.T) {
|
||||||
|
userRepoIntf := code.Interface{Name: "UserRepository"}
|
||||||
|
interfaces := code.Interfaces{userRepoIntf}
|
||||||
|
|
||||||
|
t.Run("struct field found", func(t *testing.T) {
|
||||||
|
intf, ok := interfaces.ByName("UserRepository")
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(intf, userRepoIntf) {
|
||||||
|
t.Errorf("Expected = %v\nReceived = %v", userRepoIntf, intf)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("struct field not found", func(t *testing.T) {
|
||||||
|
_, ok := interfaces.ByName("Password")
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -146,6 +146,17 @@ func TestGenerateMongoRepository(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "FindByGenderIn",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -306,6 +317,20 @@ func (r *UserRepositoryMongo) FindByGenderOrAgeLessThan(ctx context.Context, arg
|
||||||
}
|
}
|
||||||
return entities, nil
|
return entities, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) ([]*UserModel, error) {
|
||||||
|
cursor, err := r.collection.Find(ctx, bson.M{
|
||||||
|
"gender": bson.M{"$in": arg0},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var entities []*UserModel
|
||||||
|
if err := cursor.All(ctx, &entities); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return entities, nil
|
||||||
|
}
|
||||||
`
|
`
|
||||||
expectedCodeLines := strings.Split(expectedCode, "\n")
|
expectedCodeLines := strings.Split(expectedCode, "\n")
|
||||||
actualCodeLines := strings.Split(code, "\n")
|
actualCodeLines := strings.Split(code, "\n")
|
||||||
|
|
|
@ -53,6 +53,8 @@ func (p predicate) Code(argIndex int) string {
|
||||||
return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex)
|
return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex)
|
||||||
case spec.ComparatorGreaterThanEqual:
|
case spec.ComparatorGreaterThanEqual:
|
||||||
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex)
|
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex)
|
||||||
|
case spec.ComparatorIn:
|
||||||
|
return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, argIndex)
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,8 +64,17 @@ const (
|
||||||
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
|
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
|
||||||
ComparatorGreaterThan Comparator = "GREATER_THAN"
|
ComparatorGreaterThan Comparator = "GREATER_THAN"
|
||||||
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
|
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
|
||||||
|
ComparatorIn Comparator = "IN"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
|
||||||
|
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
|
||||||
|
if c == ComparatorIn {
|
||||||
|
return code.ArrayType{ContainedType: t}
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
// Predicate is a criteria for querying a field
|
// Predicate is a criteria for querying a field
|
||||||
type Predicate struct {
|
type Predicate struct {
|
||||||
Field string
|
Field string
|
||||||
|
@ -90,5 +99,8 @@ func (t predicateToken) ToPredicate() Predicate {
|
||||||
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
|
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
|
||||||
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual}
|
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual}
|
||||||
}
|
}
|
||||||
|
if len(t) > 1 && t[len(t)-1] == "In" {
|
||||||
|
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn}
|
||||||
|
}
|
||||||
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual}
|
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual}
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,7 +157,8 @@ func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) erro
|
||||||
return StructFieldNotFoundError
|
return StructFieldNotFoundError
|
||||||
}
|
}
|
||||||
|
|
||||||
if structField.Type != p.Method.Params[currentParamIndex].Type {
|
if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType(
|
||||||
|
structField.Type) {
|
||||||
return InvalidParamError
|
return InvalidParamError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -390,6 +390,37 @@ func TestParseInterfaceMethod(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "FindByArgIn method",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "FindByCityIn",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Type: code.ArrayType{ContainedType: code.SimpleType("string")}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedOutput: spec.MethodSpec{
|
||||||
|
Name: "FindByCityIn",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Type: code.ArrayType{ContainedType: code.SimpleType("string")}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
Operation: spec.FindOperation{
|
||||||
|
Mode: spec.QueryModeMany,
|
||||||
|
Query: spec.QuerySpec{Predicates: []spec.Predicate{
|
||||||
|
{Field: "City", Comparator: spec.ComparatorIn},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range testTable {
|
for _, testCase := range testTable {
|
||||||
|
@ -566,6 +597,21 @@ func TestParseInterfaceMethodInvalid(t *testing.T) {
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidParamError,
|
ExpectedError: spec.InvalidParamError,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "mismatched method parameter type for special case",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "FindByCityIn",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Type: code.SimpleType("string")},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedError: spec.InvalidParamError,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range testTable {
|
for _, testCase := range testTable {
|
||||||
|
|
Loading…
Reference in a new issue