diff --git a/internal/code/models_test.go b/internal/code/models_test.go new file mode 100644 index 0000000..d1e84b9 --- /dev/null +++ b/internal/code/models_test.go @@ -0,0 +1,87 @@ +package code_test + +import ( + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" +) + +func TestStructsByName(t *testing.T) { + userStruct := code.Struct{ + Name: "UserModel", + Fields: code.StructFields{ + {Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + {Name: "Username", Type: code.SimpleType("string")}, + }, + } + structs := code.Structs{userStruct} + + t.Run("struct found", func(t *testing.T) { + structModel, ok := structs.ByName("UserModel") + + if !ok { + t.Fail() + } + if !reflect.DeepEqual(structModel, userStruct) { + t.Errorf("Expected = %v\nReceived = %v", userStruct, structModel) + } + }) + + t.Run("struct not found", func(t *testing.T) { + _, ok := structs.ByName("ProductModel") + + if ok { + t.Fail() + } + }) +} + +func TestStructFieldsByName(t *testing.T) { + idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}} + usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")} + fields := code.StructFields{idField, usernameField} + + t.Run("struct field found", func(t *testing.T) { + field, ok := fields.ByName("Username") + + if !ok { + t.Fail() + } + if !reflect.DeepEqual(field, usernameField) { + t.Errorf("Expected = %v\nReceived = %v", usernameField, field) + } + }) + + t.Run("struct field not found", func(t *testing.T) { + _, ok := fields.ByName("Password") + + if ok { + t.Fail() + } + }) +} + +func TestInterfacesByName(t *testing.T) { + userRepoIntf := code.Interface{Name: "UserRepository"} + interfaces := code.Interfaces{userRepoIntf} + + t.Run("struct field found", func(t *testing.T) { + intf, ok := interfaces.ByName("UserRepository") + + if !ok { + t.Fail() + } + if !reflect.DeepEqual(intf, userRepoIntf) { + t.Errorf("Expected = %v\nReceived = %v", userRepoIntf, intf) + } + }) + + t.Run("struct field not found", func(t *testing.T) { + _, ok := interfaces.ByName("Password") + + if ok { + t.Fail() + } + }) +} diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index 8c38357..51ac3a9 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -146,6 +146,17 @@ func TestGenerateMongoRepository(t *testing.T) { code.SimpleType("error"), }, }, + { + Name: "FindByGenderIn", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, }, } @@ -306,6 +317,20 @@ func (r *UserRepositoryMongo) FindByGenderOrAgeLessThan(ctx context.Context, arg } return entities, nil } + +func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "gender": bson.M{"$in": arg0}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} ` expectedCodeLines := strings.Split(expectedCode, "\n") actualCodeLines := strings.Split(code, "\n") diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 959daba..5e69e8f 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -53,6 +53,8 @@ func (p predicate) Code(argIndex int) string { return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex) case spec.ComparatorGreaterThanEqual: return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex) + case spec.ComparatorIn: + return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, argIndex) } return "" } diff --git a/internal/spec/models.go b/internal/spec/models.go index 472f050..0313ef8 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -64,8 +64,17 @@ const ( ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL" ComparatorGreaterThan Comparator = "GREATER_THAN" ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL" + ComparatorIn Comparator = "IN" ) +// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type +func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type { + if c == ComparatorIn { + return code.ArrayType{ContainedType: t} + } + return t +} + // Predicate is a criteria for querying a field type Predicate struct { Field string @@ -90,5 +99,8 @@ func (t predicateToken) ToPredicate() Predicate { if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual} } + if len(t) > 1 && t[len(t)-1] == "In" { + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn} + } return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual} } diff --git a/internal/spec/parser.go b/internal/spec/parser.go index c9d75b7..7a98503 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -157,7 +157,8 @@ func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) erro return StructFieldNotFoundError } - if structField.Type != p.Method.Params[currentParamIndex].Type { + if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( + structField.Type) { return InvalidParamError } diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index bd60043..9a0a869 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -390,6 +390,37 @@ func TestParseInterfaceMethod(t *testing.T) { }, }, }, + { + Name: "FindByArgIn method", + Method: code.Method{ + Name: "FindByCityIn", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.ArrayType{ContainedType: code.SimpleType("string")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOutput: spec.MethodSpec{ + Name: "FindByCityIn", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.ArrayType{ContainedType: code.SimpleType("string")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorIn}, + }}, + }, + }, + }, } for _, testCase := range testTable { @@ -566,6 +597,21 @@ func TestParseInterfaceMethodInvalid(t *testing.T) { }, ExpectedError: spec.InvalidParamError, }, + { + Name: "mismatched method parameter type for special case", + Method: code.Method{ + Name: "FindByCityIn", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidParamError, + }, } for _, testCase := range testTable {