Add IN comparator

This commit is contained in:
sunboyy 2021-01-22 09:56:30 +07:00
parent 847d5add53
commit ba77c93d64
6 changed files with 174 additions and 1 deletions

View file

@ -0,0 +1,87 @@
package code_test
import (
"reflect"
"testing"
"github.com/sunboyy/repogen/internal/code"
)
func TestStructsByName(t *testing.T) {
userStruct := code.Struct{
Name: "UserModel",
Fields: code.StructFields{
{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
{Name: "Username", Type: code.SimpleType("string")},
},
}
structs := code.Structs{userStruct}
t.Run("struct found", func(t *testing.T) {
structModel, ok := structs.ByName("UserModel")
if !ok {
t.Fail()
}
if !reflect.DeepEqual(structModel, userStruct) {
t.Errorf("Expected = %v\nReceived = %v", userStruct, structModel)
}
})
t.Run("struct not found", func(t *testing.T) {
_, ok := structs.ByName("ProductModel")
if ok {
t.Fail()
}
})
}
func TestStructFieldsByName(t *testing.T) {
idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}
usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")}
fields := code.StructFields{idField, usernameField}
t.Run("struct field found", func(t *testing.T) {
field, ok := fields.ByName("Username")
if !ok {
t.Fail()
}
if !reflect.DeepEqual(field, usernameField) {
t.Errorf("Expected = %v\nReceived = %v", usernameField, field)
}
})
t.Run("struct field not found", func(t *testing.T) {
_, ok := fields.ByName("Password")
if ok {
t.Fail()
}
})
}
func TestInterfacesByName(t *testing.T) {
userRepoIntf := code.Interface{Name: "UserRepository"}
interfaces := code.Interfaces{userRepoIntf}
t.Run("struct field found", func(t *testing.T) {
intf, ok := interfaces.ByName("UserRepository")
if !ok {
t.Fail()
}
if !reflect.DeepEqual(intf, userRepoIntf) {
t.Errorf("Expected = %v\nReceived = %v", userRepoIntf, intf)
}
})
t.Run("struct field not found", func(t *testing.T) {
_, ok := interfaces.ByName("Password")
if ok {
t.Fail()
}
})
}

View file

@ -146,6 +146,17 @@ func TestGenerateMongoRepository(t *testing.T) {
code.SimpleType("error"), code.SimpleType("error"),
}, },
}, },
{
Name: "FindByGenderIn",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
}, },
} }
@ -306,6 +317,20 @@ func (r *UserRepositoryMongo) FindByGenderOrAgeLessThan(ctx context.Context, arg
} }
return entities, nil return entities, nil
} }
func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"gender": bson.M{"$in": arg0},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
` `
expectedCodeLines := strings.Split(expectedCode, "\n") expectedCodeLines := strings.Split(expectedCode, "\n")
actualCodeLines := strings.Split(code, "\n") actualCodeLines := strings.Split(code, "\n")

View file

@ -53,6 +53,8 @@ func (p predicate) Code(argIndex int) string {
return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex) return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex)
case spec.ComparatorGreaterThanEqual: case spec.ComparatorGreaterThanEqual:
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex) return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex)
case spec.ComparatorIn:
return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, argIndex)
} }
return "" return ""
} }

View file

@ -64,8 +64,17 @@ const (
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL" ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN" ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL" ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorIn Comparator = "IN"
) )
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
return code.ArrayType{ContainedType: t}
}
return t
}
// Predicate is a criteria for querying a field // Predicate is a criteria for querying a field
type Predicate struct { type Predicate struct {
Field string Field string
@ -90,5 +99,8 @@ func (t predicateToken) ToPredicate() Predicate {
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual} return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual}
} }
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual} return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual}
} }

View file

@ -157,7 +157,8 @@ func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) erro
return StructFieldNotFoundError return StructFieldNotFoundError
} }
if structField.Type != p.Method.Params[currentParamIndex].Type { if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType(
structField.Type) {
return InvalidParamError return InvalidParamError
} }

View file

@ -390,6 +390,37 @@ func TestParseInterfaceMethod(t *testing.T) {
}, },
}, },
}, },
{
Name: "FindByArgIn method",
Method: code.Method{
Name: "FindByCityIn",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.ArrayType{ContainedType: code.SimpleType("string")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedOutput: spec.MethodSpec{
Name: "FindByCityIn",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.ArrayType{ContainedType: code.SimpleType("string")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{Predicates: []spec.Predicate{
{Field: "City", Comparator: spec.ComparatorIn},
}},
},
},
},
} }
for _, testCase := range testTable { for _, testCase := range testTable {
@ -566,6 +597,21 @@ func TestParseInterfaceMethodInvalid(t *testing.T) {
}, },
ExpectedError: spec.InvalidParamError, ExpectedError: spec.InvalidParamError,
}, },
{
Name: "mismatched method parameter type for special case",
Method: code.Method{
Name: "FindByCityIn",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
} }
for _, testCase := range testTable { for _, testCase := range testTable {