diff --git a/README.md b/README.md index 6e6ddee..44aee98 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,12 @@ -Repogen is a code generator for database repository in Golang. (WIP) +Repogen is a code generator for database repository in Golang inspired by Spring Data JPA. (WIP) ## Features Repogen is a library that generates MongoDB repository implementation from repository interface by using method name pattern. + +- Method signature validation +- Supports single-entity and multiple-entity operations +- Supports many comparison operators diff --git a/codecov.yml b/codecov.yml index 080f560..2c28e2f 100644 --- a/codecov.yml +++ b/codecov.yml @@ -6,4 +6,4 @@ coverage: threshold: 4% patch: default: - target: 50% + target: 60% diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 4ce0880..bd5a642 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -48,7 +48,7 @@ func TestGenerateMongoRepository(t *testing.T) { Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Field: "ID", Comparator: spec.ComparatorEqual}, + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 1}, }, }, }, @@ -70,8 +70,8 @@ func TestGenerateMongoRepository(t *testing.T) { Query: spec.QuerySpec{ Operator: spec.OperatorAnd, Predicates: []spec.Predicate{ - {Field: "Gender", Comparator: spec.ComparatorNot}, - {Field: "Age", Comparator: spec.ComparatorLessThan}, + {Field: "Gender", Comparator: spec.ComparatorNot, ParamIndex: 1}, + {Field: "Age", Comparator: spec.ComparatorLessThan, ParamIndex: 2}, }, }, }, @@ -90,7 +90,7 @@ func TestGenerateMongoRepository(t *testing.T) { Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorLessThanEqual}, + {Field: "Age", Comparator: spec.ComparatorLessThanEqual, ParamIndex: 1}, }, }, }, @@ -109,7 +109,7 @@ func TestGenerateMongoRepository(t *testing.T) { Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorGreaterThan}, + {Field: "Age", Comparator: spec.ComparatorGreaterThan, ParamIndex: 1}, }, }, }, @@ -128,7 +128,7 @@ func TestGenerateMongoRepository(t *testing.T) { Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual}, + {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual, ParamIndex: 1}, }, }, }, @@ -148,7 +148,7 @@ func TestGenerateMongoRepository(t *testing.T) { Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorBetween}, + {Field: "Age", Comparator: spec.ComparatorBetween, ParamIndex: 1}, }, }, }, @@ -169,8 +169,8 @@ func TestGenerateMongoRepository(t *testing.T) { Query: spec.QuerySpec{ Operator: spec.OperatorOr, Predicates: []spec.Predicate{ - {Field: "Gender", Comparator: spec.ComparatorEqual}, - {Field: "Age", Comparator: spec.ComparatorEqual}, + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + {Field: "Age", Comparator: spec.ComparatorEqual, ParamIndex: 2}, }, }, }, @@ -203,99 +203,99 @@ type UserRepositoryMongo struct { collection *mongo.Collection } -func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) { +func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.ObjectID) (*UserModel, error) { var entity UserModel - if err := r.collection.FindOne(ctx, bson.M{ - "_id": arg0, + if err := r.collection.FindOne(arg0, bson.M{ + "_id": arg1, }).Decode(&entity); err != nil { return nil, err } return &entity, nil } -func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(ctx context.Context, arg0 Gender, arg1 int) (*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "gender": bson.M{"$ne": arg0}, - "age": bson.M{"$lt": arg1}, +func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context, arg1 Gender, arg2 int) (*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "gender": bson.M{"$ne": arg1}, + "age": bson.M{"$lt": arg2}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$lte": arg0}, +func (r *UserRepositoryMongo) FindByAgeLessThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$lte": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$gt": arg0}, +func (r *UserRepositoryMongo) FindByAgeGreaterThan(arg0 context.Context, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$gt": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$gte": arg0}, +func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$gte": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeBetween(ctx context.Context, arg0 int, arg1 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$gte": arg0, "$lte": arg1}, +func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$gte": arg1, "$lte": arg2}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ +func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ "$or": []bson.M{ - {"gender": arg0}, - {"age": arg1}, + {"gender": arg1}, + {"age": arg2}, }, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index 7c1aade..e421c93 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -56,7 +56,7 @@ func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer i } var paramTypes []code.Type - for _, param := range methodSpec.Params[1:] { + for _, param := range methodSpec.Params { paramTypes = append(paramTypes, param.Type) } @@ -79,6 +79,8 @@ func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.Method switch operation := methodSpec.Operation.(type) { case spec.FindOperation: return g.generateFindImplementation(operation) + case spec.UpdateOperation: + return g.generateUpdateImplementation(operation) case spec.DeleteOperation: return g.generateDeleteImplementation(operation) } @@ -122,6 +124,51 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera return buffer.String(), nil } +func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) { + buffer := new(bytes.Buffer) + + var fields []updateField + for _, field := range operation.Fields { + bsonTag, err := g.bsonTagFromFieldName(field.Name) + if err != nil { + return "", err + } + fields = append(fields, updateField{BsonTag: bsonTag, ParamIndex: field.ParamIndex}) + } + + querySpec, err := g.mongoQuerySpec(operation.Query) + if err != nil { + return "", err + } + + tmplData := mongoUpdateTemplateData{ + UpdateFields: fields, + QuerySpec: querySpec, + } + + if operation.Mode == spec.QueryModeOne { + tmpl, err := template.New("mongo_repository_updateone").Parse(updateOneTemplate) + if err != nil { + return "", err + } + + if err := tmpl.Execute(buffer, tmplData); err != nil { + return "", err + } + } else { + tmpl, err := template.New("mongo_repository_updatemany").Parse(updateManyTemplate) + if err != nil { + return "", err + } + + if err := tmpl.Execute(buffer, tmplData); err != nil { + return "", err + } + } + + return buffer.String(), nil +} + func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) { buffer := new(bytes.Buffer) @@ -161,17 +208,16 @@ func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, er var predicates []predicate for _, predicateSpec := range query.Predicates { - structField, ok := g.StructModel.Fields.ByName(predicateSpec.Field) - if !ok { - return querySpec{}, fmt.Errorf("struct field %s not found", predicateSpec.Field) + bsonTag, err := g.bsonTagFromFieldName(predicateSpec.Field) + if err != nil { + return querySpec{}, err } - bsonTag, ok := structField.Tags["bson"] - if !ok { - return querySpec{}, BsonTagNotFoundError - } - - predicates = append(predicates, predicate{Field: bsonTag[0], Comparator: predicateSpec.Comparator}) + predicates = append(predicates, predicate{ + Field: bsonTag, + Comparator: predicateSpec.Comparator, + ParamIndex: predicateSpec.ParamIndex, + }) } return querySpec{ @@ -180,6 +226,20 @@ func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, er }, nil } +func (g RepositoryGenerator) bsonTagFromFieldName(fieldName string) (string, error) { + structField, ok := g.StructModel.Fields.ByName(fieldName) + if !ok { + return "", fmt.Errorf("struct field %s not found", fieldName) + } + + bsonTag, ok := structField.Tags["bson"] + if !ok { + return "", BsonTagNotFoundError + } + + return bsonTag[0], nil +} + func (g RepositoryGenerator) structName() string { return g.InterfaceName + "Mongo" } diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index 14ae9b8..bdfc73c 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -95,16 +95,16 @@ func TestGenerateMethod_Find(t *testing.T) { Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "ID"}, + {Comparator: spec.ComparatorEqual, Field: "ID", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) { +func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.ObjectID) (*UserModel, error) { var entity UserModel - if err := r.collection.FindOne(ctx, bson.M{ - "_id": arg0, + if err := r.collection.FindOne(arg0, bson.M{ + "_id": arg1, }).Decode(&entity); err != nil { return nil, err } @@ -128,21 +128,21 @@ func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.Objec Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "Gender"}, + {Comparator: spec.ComparatorEqual, Field: "Gender", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByGender(ctx context.Context, arg0 Gender) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "gender": arg0, +func (r *UserRepositoryMongo) FindByGender(arg0 context.Context, arg1 Gender) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "gender": arg1, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -167,23 +167,23 @@ func (r *UserRepositoryMongo) FindByGender(ctx context.Context, arg0 Gender) ([] Query: spec.QuerySpec{ Operator: spec.OperatorAnd, Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "Gender"}, - {Comparator: spec.ComparatorEqual, Field: "Age"}, + {Comparator: spec.ComparatorEqual, Field: "Gender", ParamIndex: 1}, + {Comparator: spec.ComparatorEqual, Field: "Age", ParamIndex: 2}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByGenderAndAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "gender": arg0, - "age": arg1, +func (r *UserRepositoryMongo) FindByGenderAndAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "gender": arg1, + "age": arg2, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -208,25 +208,25 @@ func (r *UserRepositoryMongo) FindByGenderAndAge(ctx context.Context, arg0 Gende Query: spec.QuerySpec{ Operator: spec.OperatorOr, Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "Gender"}, - {Comparator: spec.ComparatorEqual, Field: "Age"}, + {Comparator: spec.ComparatorEqual, Field: "Gender", ParamIndex: 1}, + {Comparator: spec.ComparatorEqual, Field: "Age", ParamIndex: 2}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ +func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ "$or": []bson.M{ - {"gender": arg0}, - {"age": arg1}, + {"gender": arg1}, + {"age": arg2}, }, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -249,21 +249,21 @@ func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorNot, Field: "Gender"}, + {Comparator: spec.ComparatorNot, Field: "Gender", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 Gender) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "gender": bson.M{"$ne": arg0}, +func (r *UserRepositoryMongo) FindByGenderNot(arg0 context.Context, arg1 Gender) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "gender": bson.M{"$ne": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -286,21 +286,21 @@ func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 Gender) Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorLessThan, Field: "Age"}, + {Comparator: spec.ComparatorLessThan, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByAgeLessThan(ctx context.Context, arg0 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$lt": arg0}, +func (r *UserRepositoryMongo) FindByAgeLessThan(arg0 context.Context, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$lt": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -323,21 +323,21 @@ func (r *UserRepositoryMongo) FindByAgeLessThan(ctx context.Context, arg0 int) ( Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorLessThanEqual, Field: "Age"}, + {Comparator: spec.ComparatorLessThanEqual, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$lte": arg0}, +func (r *UserRepositoryMongo) FindByAgeLessThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$lte": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -360,21 +360,21 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 i Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorGreaterThan, Field: "Age"}, + {Comparator: spec.ComparatorGreaterThan, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$gt": arg0}, +func (r *UserRepositoryMongo) FindByAgeGreaterThan(arg0 context.Context, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$gt": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -397,21 +397,21 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorGreaterThanEqual, Field: "Age"}, + {Comparator: spec.ComparatorGreaterThanEqual, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$gte": arg0}, +func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(arg0 context.Context, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$gte": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -435,21 +435,21 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorBetween, Field: "Age"}, + {Comparator: spec.ComparatorBetween, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByAgeBetween(ctx context.Context, arg0 int, arg1 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "age": bson.M{"$gte": arg0, "$lte": arg1}, +func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{"$gte": arg1, "$lte": arg2}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -472,21 +472,21 @@ func (r *UserRepositoryMongo) FindByAgeBetween(ctx context.Context, arg0 int, ar Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorIn, Field: "Gender"}, + {Comparator: spec.ComparatorIn, Field: "Gender", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "gender": bson.M{"$in": arg0}, +func (r *UserRepositoryMongo) FindByGenderIn(arg0 context.Context, arg1 []Gender) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "gender": bson.M{"$in": arg1}, }) if err != nil { return nil, err } var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil @@ -512,6 +512,109 @@ func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) } } +func TestGenerateMethod_Update(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "simple update one method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateAgeByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Fields: []spec.UpdateField{ + {Name: "Age", ParamIndex: 1}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) UpdateAgeByID(arg0 context.Context, arg1 int, arg2 primitive.ObjectID) (bool, error) { + result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$set": bson.M{ + "age": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, err +} +`, + }, + { + Name: "simple update many method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateAgeByGender", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Fields: []spec.UpdateField{ + {Name: "Age", ParamIndex: 1}, + }, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) UpdateAgeByGender(arg0 context.Context, arg1 int, arg2 Gender) (int, error) { + result, err := r.collection.UpdateMany(arg0, bson.M{ + "gender": arg2, + }, bson.M{ + "$set": bson.M{ + "age": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.MatchedCount), err +} +`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + buffer := new(bytes.Buffer) + + err := generator.GenerateMethod(testCase.MethodSpec, buffer) + + if err != nil { + t.Error(err) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil { + t.Error(err) + } + }) + } +} + func TestGenerateMethod_Delete(t *testing.T) { testTable := []GenerateMethodTestCase{ { @@ -527,15 +630,15 @@ func TestGenerateMethod_Delete(t *testing.T) { Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "ID"}, + {Comparator: spec.ComparatorEqual, Field: "ID", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByID(ctx context.Context, arg0 primitive.ObjectID) (bool, error) { - result, err := r.collection.DeleteOne(ctx, bson.M{ - "_id": arg0, +func (r *UserRepositoryMongo) DeleteByID(arg0 context.Context, arg1 primitive.ObjectID) (bool, error) { + result, err := r.collection.DeleteOne(arg0, bson.M{ + "_id": arg1, }) if err != nil { return false, err @@ -560,15 +663,15 @@ func (r *UserRepositoryMongo) DeleteByID(ctx context.Context, arg0 primitive.Obj Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "Gender"}, + {Comparator: spec.ComparatorEqual, Field: "Gender", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByGender(ctx context.Context, arg0 Gender) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "gender": arg0, +func (r *UserRepositoryMongo) DeleteByGender(arg0 context.Context, arg1 Gender) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "gender": arg1, }) if err != nil { return 0, err @@ -595,17 +698,17 @@ func (r *UserRepositoryMongo) DeleteByGender(ctx context.Context, arg0 Gender) ( Query: spec.QuerySpec{ Operator: spec.OperatorAnd, Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "Gender"}, - {Comparator: spec.ComparatorEqual, Field: "Age"}, + {Comparator: spec.ComparatorEqual, Field: "Gender", ParamIndex: 1}, + {Comparator: spec.ComparatorEqual, Field: "Age", ParamIndex: 2}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByGenderAndAge(ctx context.Context, arg0 Gender, arg1 int) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "gender": arg0, - "age": arg1, +func (r *UserRepositoryMongo) DeleteByGenderAndAge(arg0 context.Context, arg1 Gender, arg2 int) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "gender": arg1, + "age": arg2, }) if err != nil { return 0, err @@ -632,18 +735,18 @@ func (r *UserRepositoryMongo) DeleteByGenderAndAge(ctx context.Context, arg0 Gen Query: spec.QuerySpec{ Operator: spec.OperatorOr, Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorEqual, Field: "Gender"}, - {Comparator: spec.ComparatorEqual, Field: "Age"}, + {Comparator: spec.ComparatorEqual, Field: "Gender", ParamIndex: 1}, + {Comparator: spec.ComparatorEqual, Field: "Age", ParamIndex: 2}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ +func (r *UserRepositoryMongo) DeleteByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ "$or": []bson.M{ - {"gender": arg0}, - {"age": arg1}, + {"gender": arg1}, + {"age": arg2}, }, }) if err != nil { @@ -669,15 +772,15 @@ func (r *UserRepositoryMongo) DeleteByGenderOrAge(ctx context.Context, arg0 Gend Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorNot, Field: "Gender"}, + {Comparator: spec.ComparatorNot, Field: "Gender", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByGenderNot(ctx context.Context, arg0 Gender) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "gender": bson.M{"$ne": arg0}, +func (r *UserRepositoryMongo) DeleteByGenderNot(arg0 context.Context, arg1 Gender) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "gender": bson.M{"$ne": arg1}, }) if err != nil { return 0, err @@ -702,15 +805,15 @@ func (r *UserRepositoryMongo) DeleteByGenderNot(ctx context.Context, arg0 Gender Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorLessThan, Field: "Age"}, + {Comparator: spec.ComparatorLessThan, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByAgeLessThan(ctx context.Context, arg0 int) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "age": bson.M{"$lt": arg0}, +func (r *UserRepositoryMongo) DeleteByAgeLessThan(arg0 context.Context, arg1 int) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{"$lt": arg1}, }) if err != nil { return 0, err @@ -735,15 +838,15 @@ func (r *UserRepositoryMongo) DeleteByAgeLessThan(ctx context.Context, arg0 int) Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorLessThanEqual, Field: "Age"}, + {Comparator: spec.ComparatorLessThanEqual, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByAgeLessThanEqual(ctx context.Context, arg0 int) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "age": bson.M{"$lte": arg0}, +func (r *UserRepositoryMongo) DeleteByAgeLessThanEqual(arg0 context.Context, arg1 int) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{"$lte": arg1}, }) if err != nil { return 0, err @@ -768,15 +871,15 @@ func (r *UserRepositoryMongo) DeleteByAgeLessThanEqual(ctx context.Context, arg0 Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorGreaterThan, Field: "Age"}, + {Comparator: spec.ComparatorGreaterThan, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByAgeGreaterThan(ctx context.Context, arg0 int) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "age": bson.M{"$gt": arg0}, +func (r *UserRepositoryMongo) DeleteByAgeGreaterThan(arg0 context.Context, arg1 int) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{"$gt": arg1}, }) if err != nil { return 0, err @@ -801,15 +904,15 @@ func (r *UserRepositoryMongo) DeleteByAgeGreaterThan(ctx context.Context, arg0 i Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorGreaterThanEqual, Field: "Age"}, + {Comparator: spec.ComparatorGreaterThanEqual, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByAgeGreaterThanEqual(ctx context.Context, arg0 int) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "age": bson.M{"$gte": arg0}, +func (r *UserRepositoryMongo) DeleteByAgeGreaterThanEqual(arg0 context.Context, arg1 int) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{"$gte": arg1}, }) if err != nil { return 0, err @@ -835,15 +938,15 @@ func (r *UserRepositoryMongo) DeleteByAgeGreaterThanEqual(ctx context.Context, a Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorBetween, Field: "Age"}, + {Comparator: spec.ComparatorBetween, Field: "Age", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByAgeBetween(ctx context.Context, arg0 int, arg1 int) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "age": bson.M{"$gte": arg0, "$lte": arg1}, +func (r *UserRepositoryMongo) DeleteByAgeBetween(arg0 context.Context, arg1 int, arg2 int) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{"$gte": arg1, "$lte": arg2}, }) if err != nil { return 0, err @@ -868,15 +971,15 @@ func (r *UserRepositoryMongo) DeleteByAgeBetween(ctx context.Context, arg0 int, Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Comparator: spec.ComparatorIn, Field: "Gender"}, + {Comparator: spec.ComparatorIn, Field: "Gender", ParamIndex: 1}, }, }, }, }, ExpectedCode: ` -func (r *UserRepositoryMongo) DeleteByGenderIn(ctx context.Context, arg0 []Gender) (int, error) { - result, err := r.collection.DeleteMany(ctx, bson.M{ - "gender": bson.M{"$in": arg0}, +func (r *UserRepositoryMongo) DeleteByGenderIn(arg0 context.Context, arg1 []Gender) (int, error) { + result, err := r.collection.DeleteMany(arg0, bson.M{ + "gender": bson.M{"$in": arg1}, }) if err != nil { return 0, err @@ -929,7 +1032,7 @@ func TestGenerateMethod_Invalid(t *testing.T) { ExpectedError: mongo.OperationNotSupportedError, }, { - Name: "bson tag not found", + Name: "bson tag not found in query", Method: spec.MethodSpec{ Name: "FindByAccessToken", Params: []code.Param{ @@ -944,7 +1047,34 @@ func TestGenerateMethod_Invalid(t *testing.T) { Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {Field: "AccessToken", Comparator: spec.ComparatorEqual}, + {Field: "AccessToken", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedError: mongo.BsonTagNotFoundError, + }, + { + Name: "bson tag not found in update field", + Method: spec.MethodSpec{ + Name: "UpdateAccessTokenByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Fields: []spec.UpdateField{ + {Name: "AccessToken", ParamIndex: 1}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2}, }, }, }, diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 92d2cfd..4ef0726 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -7,6 +7,11 @@ import ( "github.com/sunboyy/repogen/internal/spec" ) +type updateField struct { + BsonTag string + ParamIndex int +} + type querySpec struct { Operator spec.Operator Predicates []predicate @@ -14,10 +19,8 @@ type querySpec struct { func (q querySpec) Code() string { var predicateCodes []string - var argIndex int for _, predicate := range q.Predicates { - predicateCodes = append(predicateCodes, predicate.Code(argIndex)) - argIndex += predicate.Comparator.NumberOfArguments() + predicateCodes = append(predicateCodes, predicate.Code()) } var lines []string @@ -39,26 +42,27 @@ func (q querySpec) Code() string { type predicate struct { Field string Comparator spec.Comparator + ParamIndex int } -func (p predicate) Code(argIndex int) string { +func (p predicate) Code() string { switch p.Comparator { case spec.ComparatorEqual: - return fmt.Sprintf(`"%s": arg%d`, p.Field, argIndex) + return fmt.Sprintf(`"%s": arg%d`, p.Field, p.ParamIndex) case spec.ComparatorNot: - return fmt.Sprintf(`"%s": bson.M{"$ne": arg%d}`, p.Field, argIndex) + return fmt.Sprintf(`"%s": bson.M{"$ne": arg%d}`, p.Field, p.ParamIndex) case spec.ComparatorLessThan: - return fmt.Sprintf(`"%s": bson.M{"$lt": arg%d}`, p.Field, argIndex) + return fmt.Sprintf(`"%s": bson.M{"$lt": arg%d}`, p.Field, p.ParamIndex) case spec.ComparatorLessThanEqual: - return fmt.Sprintf(`"%s": bson.M{"$lte": arg%d}`, p.Field, argIndex) + return fmt.Sprintf(`"%s": bson.M{"$lte": arg%d}`, p.Field, p.ParamIndex) case spec.ComparatorGreaterThan: - return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, argIndex) + return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, p.ParamIndex) case spec.ComparatorGreaterThanEqual: - return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, argIndex) + return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, p.ParamIndex) case spec.ComparatorBetween: - return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, argIndex, argIndex+1) + return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, p.ParamIndex, p.ParamIndex+1) case spec.ComparatorIn: - return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, argIndex) + return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, p.ParamIndex) } return "" } diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go index d5fb963..7a1028f 100644 --- a/internal/mongo/templates.go +++ b/internal/mongo/templates.go @@ -33,7 +33,7 @@ type mongoConstructorTemplateData struct { } const methodTemplate = ` -func (r *{{.StructName}}) {{.MethodName}}(ctx context.Context, {{.Parameters}}){{.Returns}} { +func (r *{{.StructName}}) {{.MethodName}}({{.Parameters}}){{.Returns}} { {{.Implementation}} } ` @@ -76,30 +76,59 @@ type mongoFindTemplateData struct { } const findOneTemplate = ` var entity {{.EntityType}} - if err := r.collection.FindOne(ctx, bson.M{ + if err := r.collection.FindOne(arg0, bson.M{ {{.QuerySpec.Code}} }).Decode(&entity); err != nil { return nil, err } return &entity, nil` -const findManyTemplate = ` cursor, err := r.collection.Find(ctx, bson.M{ +const findManyTemplate = ` cursor, err := r.collection.Find(arg0, bson.M{ {{.QuerySpec.Code}} }) if err != nil { return nil, err } var entities []*{{.EntityType}} - if err := cursor.All(ctx, &entities); err != nil { + if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil` +type mongoUpdateTemplateData struct { + UpdateFields []updateField + QuerySpec querySpec +} + +const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{ +{{.QuerySpec.Code}} + }, bson.M{ + "$set": bson.M{ +{{range $index, $element := .UpdateFields}} "{{$element.BsonTag}}": arg{{$element.ParamIndex}}, +{{end}} }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, err` + +const updateManyTemplate = ` result, err := r.collection.UpdateMany(arg0, bson.M{ +{{.QuerySpec.Code}} + }, bson.M{ + "$set": bson.M{ +{{range $index, $element := .UpdateFields}} "{{$element.BsonTag}}": arg{{$element.ParamIndex}}, +{{end}} }, + }) + if err != nil { + return 0, err + } + return int(result.MatchedCount), err` + type mongoDeleteTemplateData struct { QuerySpec querySpec } -const deleteOneTemplate = ` result, err := r.collection.DeleteOne(ctx, bson.M{ +const deleteOneTemplate = ` result, err := r.collection.DeleteOne(arg0, bson.M{ {{.QuerySpec.Code}} }) if err != nil { @@ -107,7 +136,7 @@ const deleteOneTemplate = ` result, err := r.collection.DeleteOne(ctx, bson.M{ } return result.DeletedCount > 0, nil` -const deleteManyTemplate = ` result, err := r.collection.DeleteMany(ctx, bson.M{ +const deleteManyTemplate = ` result, err := r.collection.DeleteMany(arg0, bson.M{ {{.QuerySpec.Code}} }) if err != nil { diff --git a/internal/spec/errors.go b/internal/spec/errors.go index 22f3495..d4d0dd9 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -13,6 +13,8 @@ func (err ParsingError) Error() string { return "invalid query" case InvalidParamError: return "parameters do not match the query" + case InvalidUpdateFieldsError: + return "update fields is invalid" case UnsupportedReturnError: return "this type of return is not supported" case ContextParamRequiredError: @@ -29,6 +31,7 @@ const ( UnsupportedNameError ParsingError = "ERROR_UNSUPPORTED" InvalidQueryError ParsingError = "ERROR_INVALID_QUERY" InvalidParamError ParsingError = "ERROR_INVALID_PARAM" + InvalidUpdateFieldsError ParsingError = "ERROR_INVALID_UPDATE_FIELDS" UnsupportedReturnError ParsingError = "ERROR_INVALID_RETURN" ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED" StructFieldNotFoundError ParsingError = "ERROR_STRUCT_FIELD_NOT_FOUND" diff --git a/internal/spec/models.go b/internal/spec/models.go index 66def43..b893fc1 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -33,6 +33,19 @@ type FindOperation struct { Query QuerySpec } +// UpdateOperation is a method specification for update operations +type UpdateOperation struct { + Fields []UpdateField + Mode QueryMode + Query QuerySpec +} + +// UpdateField stores mapping between field name in the model and the parameter index +type UpdateField struct { + Name string + ParamIndex int +} + // DeleteOperation is a method specification for delete operations type DeleteOperation struct { Mode QueryMode @@ -98,31 +111,32 @@ func (c Comparator) NumberOfArguments() int { type Predicate struct { Field string Comparator Comparator + ParamIndex int } type predicateToken []string -func (t predicateToken) ToPredicate() Predicate { +func (t predicateToken) ToPredicate(paramIndex int) Predicate { if len(t) > 1 && t[len(t)-1] == "Not" { - return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot} + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex} } if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" { - return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan} + return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex} } if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { - return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual} + return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex} } if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" { - return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan} + return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex} } if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { - return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual} + return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex} } if len(t) > 1 && t[len(t)-1] == "In" { - return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn} + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex} } if len(t) > 1 && t[len(t)-1] == "Between" { - return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween} + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex} } - return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual} + return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex} } diff --git a/internal/spec/parser.go b/internal/spec/parser.go index a069e81..d39d7be 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -25,6 +25,8 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) { switch methodNameTokens[0] { case "Find": return p.parseFindMethod(methodNameTokens[1:]) + case "Update": + return p.parseUpdateMethod(methodNameTokens[1:]) case "Delete": return p.parseDeleteMethod(methodNameTokens[1:]) } @@ -41,12 +43,16 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err return MethodSpec{}, err } - querySpec, err := p.parseQuery(tokens) + querySpec, err := p.parseQuery(tokens, 1) if err != nil { return MethodSpec{}, err } - if err := p.validateMethodSignature(querySpec); err != nil { + if err := p.validateContextParam(); err != nil { + return MethodSpec{}, err + } + + if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil { return MethodSpec{}, err } @@ -94,22 +100,94 @@ func (p interfaceMethodParser) extractFindReturns(returns []code.Type) (QueryMod return "", UnsupportedReturnError } +func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, error) { + if len(tokens) == 0 { + return MethodSpec{}, UnsupportedNameError + } + + mode, err := p.extractCountReturns(p.Method.Returns) + if err != nil { + return MethodSpec{}, err + } + + paramIndex := 1 + var fields []UpdateField + var aggregatedToken string + for i, token := range tokens { + if token == "By" || token == "All" { + tokens = tokens[i:] + break + } else if token != "And" { + aggregatedToken += token + } else if len(aggregatedToken) == 0 { + return MethodSpec{}, InvalidUpdateFieldsError + } else { + fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) + paramIndex++ + aggregatedToken = "" + } + } + if len(aggregatedToken) == 0 { + return MethodSpec{}, InvalidUpdateFieldsError + } + fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) + + querySpec, err := p.parseQuery(tokens, 1+len(fields)) + if err != nil { + return MethodSpec{}, err + } + + if err := p.validateContextParam(); err != nil { + return MethodSpec{}, err + } + + for _, field := range fields { + structField, ok := p.StructModel.Fields.ByName(field.Name) + if !ok { + return MethodSpec{}, StructFieldNotFoundError + } + + if structField.Type != p.Method.Params[field.ParamIndex].Type { + return MethodSpec{}, InvalidParamError + } + } + + if err := p.validateQueryFromParams(p.Method.Params[len(fields)+1:], querySpec); err != nil { + return MethodSpec{}, err + } + + return MethodSpec{ + Name: p.Method.Name, + Params: p.Method.Params, + Returns: p.Method.Returns, + Operation: UpdateOperation{ + Fields: fields, + Mode: mode, + Query: querySpec, + }, + }, nil +} + func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, error) { if len(tokens) == 0 { return MethodSpec{}, UnsupportedNameError } - mode, err := p.extractDeleteReturns(p.Method.Returns) + mode, err := p.extractCountReturns(p.Method.Returns) if err != nil { return MethodSpec{}, err } - querySpec, err := p.parseQuery(tokens) + querySpec, err := p.parseQuery(tokens, 1) if err != nil { return MethodSpec{}, err } - if err := p.validateMethodSignature(querySpec); err != nil { + if err := p.validateContextParam(); err != nil { + return MethodSpec{}, err + } + + if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil { return MethodSpec{}, err } @@ -124,7 +202,7 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e }, nil } -func (p interfaceMethodParser) extractDeleteReturns(returns []code.Type) (QueryMode, error) { +func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMode, error) { if len(returns) != 2 { return "", UnsupportedReturnError } @@ -146,7 +224,7 @@ func (p interfaceMethodParser) extractDeleteReturns(returns []code.Type) (QueryM return "", UnsupportedReturnError } -func (p interfaceMethodParser) parseQuery(tokens []string) (QuerySpec, error) { +func (p interfaceMethodParser) parseQuery(tokens []string, paramIndex int) (QuerySpec, error) { if len(tokens) == 0 { return QuerySpec{}, InvalidQueryError } @@ -176,11 +254,15 @@ func (p interfaceMethodParser) parseQuery(tokens []string) (QuerySpec, error) { return QuerySpec{}, InvalidQueryError } else if token == "And" && operator != OperatorOr { operator = OperatorAnd - predicates = append(predicates, aggregatedToken.ToPredicate()) + predicate := aggregatedToken.ToPredicate(paramIndex) + predicates = append(predicates, predicate) + paramIndex += predicate.Comparator.NumberOfArguments() aggregatedToken = predicateToken{} } else if token == "Or" && operator != OperatorAnd { operator = OperatorOr - predicates = append(predicates, aggregatedToken.ToPredicate()) + predicate := aggregatedToken.ToPredicate(paramIndex) + predicates = append(predicates, predicate) + paramIndex += predicate.Comparator.NumberOfArguments() aggregatedToken = predicateToken{} } else { return QuerySpec{}, InvalidQueryError @@ -189,22 +271,25 @@ func (p interfaceMethodParser) parseQuery(tokens []string) (QuerySpec, error) { if len(aggregatedToken) == 0 { return QuerySpec{}, InvalidQueryError } - predicates = append(predicates, aggregatedToken.ToPredicate()) + predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex)) return QuerySpec{Operator: operator, Predicates: predicates}, nil } -func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) error { +func (p interfaceMethodParser) validateContextParam() error { contextType := code.ExternalType{PackageAlias: "context", Name: "Context"} if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType { return ContextParamRequiredError } + return nil +} - if querySpec.NumberOfArguments()+1 != len(p.Method.Params) { +func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, querySpec QuerySpec) error { + if querySpec.NumberOfArguments() != len(params) { return InvalidParamError } - currentParamIndex := 1 + var currentParamIndex int for _, predicate := range querySpec.Predicates { structField, ok := p.StructModel.Fields.ByName(predicate.Field) if !ok { @@ -212,7 +297,7 @@ func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) erro } for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ { - if p.Method.Params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( + if params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( structField.Type) { return InvalidParamError } diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 09b6230..f283492 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -35,9 +35,9 @@ var structModel = code.Struct{ } type ParseInterfaceMethodTestCase struct { - Name string - Method code.Method - ExpectedOutput spec.MethodSpec + Name string + Method code.Method + ExpectedOperation spec.Operation } func TestParseInterfaceMethod_Find(t *testing.T) { @@ -55,22 +55,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindOneByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "ID", Comparator: spec.ComparatorEqual}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, }, }, { @@ -86,22 +75,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindOneByPhoneNumber", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "PhoneNumber", Comparator: spec.ComparatorEqual}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "PhoneNumber", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, }, }, { @@ -117,22 +95,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorEqual}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, }, }, { @@ -147,18 +114,8 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindAll", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, }, }, { @@ -175,25 +132,13 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByCityAndGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorAnd, - Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorEqual}, - {Field: "Gender", Comparator: spec.ComparatorEqual}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 2}, }, }, }, @@ -212,25 +157,13 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByCityOrGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorOr, - Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorEqual}, - {Field: "Gender", Comparator: spec.ComparatorEqual}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 2}, }, }, }, @@ -248,22 +181,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByCityNot", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorNot}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorNot, ParamIndex: 1}, + }}, }, }, { @@ -279,22 +201,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByAgeLessThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorLessThan}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorLessThan, ParamIndex: 1}, + }}, }, }, { @@ -310,22 +221,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByAgeLessThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorLessThanEqual}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorLessThanEqual, ParamIndex: 1}, + }}, }, }, { @@ -341,22 +241,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByAgeGreaterThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorGreaterThan}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThan, ParamIndex: 1}, + }}, }, }, { @@ -372,22 +261,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByAgeGreaterThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual, ParamIndex: 1}, + }}, }, }, { @@ -404,23 +282,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByAgeBetween", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorBetween}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorBetween, ParamIndex: 1}, + }}, }, }, { @@ -436,22 +302,11 @@ func TestParseInterfaceMethod_Find(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "FindByCityIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.SimpleType("string")}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorIn}, - }}, - }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorIn, ParamIndex: 1}, + }}, }, }, } @@ -463,8 +318,112 @@ func TestParseInterfaceMethod_Find(t *testing.T) { if err != nil { t.Errorf("Error = %s", err) } - if !reflect.DeepEqual(actualSpec, testCase.ExpectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedOutput, actualSpec) + expectedOutput := spec.MethodSpec{ + Name: testCase.Method.Name, + Params: testCase.Method.Params, + Returns: testCase.Method.Returns, + Operation: testCase.ExpectedOperation, + } + if !reflect.DeepEqual(actualSpec, expectedOutput) { + t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + } + }) + } +} + +func TestParseInterfaceMethod_Update(t *testing.T) { + testTable := []ParseInterfaceMethodTestCase{ + { + Name: "UpdateArgByArg one method", + Method: code.Method{ + Name: "UpdateGenderByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.UpdateOperation{ + Fields: []spec.UpdateField{ + {Name: "Gender", ParamIndex: 1}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }}, + }, + }, + { + Name: "UpdateArgByArg many method", + Method: code.Method{ + Name: "UpdateGenderByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.UpdateOperation{ + Fields: []spec.UpdateField{ + {Name: "Gender", ParamIndex: 1}, + }, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }}, + }, + }, + { + Name: "UpdateArgAndArgByArg method", + Method: code.Method{ + Name: "UpdateGenderAndCityByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.SimpleType("string")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.UpdateOperation{ + Fields: []spec.UpdateField{ + {Name: "Gender", ParamIndex: 1}, + {Name: "City", ParamIndex: 2}, + }, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 3}, + }}, + }, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(structModel, testCase.Method) + + if err != nil { + t.Errorf("Error = %s", err) + } + expectedOutput := spec.MethodSpec{ + Name: testCase.Method.Name, + Params: testCase.Method.Params, + Returns: testCase.Method.Returns, + Operation: testCase.ExpectedOperation, + } + if !reflect.DeepEqual(actualSpec, expectedOutput) { + t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) } }) } @@ -485,22 +444,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteOneByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.SimpleType("bool"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "ID", Comparator: spec.ComparatorEqual}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, }, }, { @@ -516,22 +464,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteOneByPhoneNumber", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.SimpleType("bool"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "PhoneNumber", Comparator: spec.ComparatorEqual}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "PhoneNumber", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, }, }, { @@ -547,22 +484,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorEqual}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }}, }, }, { @@ -577,18 +503,8 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteAll", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, }, }, { @@ -605,25 +521,13 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByCityAndGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorAnd, - Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorEqual}, - {Field: "Gender", Comparator: spec.ComparatorEqual}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 2}, }, }, }, @@ -642,25 +546,13 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByCityOrGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorOr, - Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorEqual}, - {Field: "Gender", Comparator: spec.ComparatorEqual}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 2}, }, }, }, @@ -678,22 +570,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByCityNot", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorNot}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorNot, ParamIndex: 1}, + }}, }, }, { @@ -709,22 +590,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByAgeLessThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorLessThan}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorLessThan, ParamIndex: 1}, + }}, }, }, { @@ -740,22 +610,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByAgeLessThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorLessThanEqual}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorLessThanEqual, ParamIndex: 1}, + }}, }, }, { @@ -771,22 +630,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByAgeGreaterThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorGreaterThan}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThan, ParamIndex: 1}, + }}, }, }, { @@ -802,22 +650,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByAgeGreaterThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual, ParamIndex: 1}, + }}, }, }, { @@ -834,23 +671,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByAgeBetween", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("int")}, - {Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "Age", Comparator: spec.ComparatorBetween}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorBetween, ParamIndex: 1}, + }}, }, }, { @@ -866,22 +691,11 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedOutput: spec.MethodSpec{ - Name: "DeleteByCityIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.SimpleType("string")}}, - }, - Returns: []code.Type{ - code.SimpleType("int"), - code.SimpleType("error"), - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {Field: "City", Comparator: spec.ComparatorIn}, - }}, - }, + ExpectedOperation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorIn, ParamIndex: 1}, + }}, }, }, } @@ -893,8 +707,14 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { if err != nil { t.Errorf("Error = %s", err) } - if !reflect.DeepEqual(actualSpec, testCase.ExpectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedOutput, actualSpec) + expectedOutput := spec.MethodSpec{ + Name: testCase.Method.Name, + Params: testCase.Method.Params, + Returns: testCase.Method.Returns, + Operation: testCase.ExpectedOperation, + } + if !reflect.DeepEqual(actualSpec, expectedOutput) { + t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) } }) } @@ -1091,6 +911,158 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { } } +func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { + testTable := []ParseInterfaceMethodInvalidTestCase{ + { + Name: "unsupported update method name", + Method: code.Method{ + Name: "Update", + }, + ExpectedError: spec.UnsupportedNameError, + }, + { + Name: "invalid number of returns", + Method: code.Method{ + Name: "UpdateAgeByID", + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "unsupported return values from find method", + Method: code.Method{ + Name: "UpdateAgeByID", + Returns: []code.Type{ + code.SimpleType("float64"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "error return not provided", + Method: code.Method{ + Name: "UpdateAgeByID", + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("bool"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "update with no field provided", + Method: code.Method{ + Name: "UpdateByID", + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidUpdateFieldsError, + }, + { + Name: "misplaced And token in update fields", + Method: code.Method{ + Name: "UpdateAgeAndAndGenderByID", + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidUpdateFieldsError, + }, + { + Name: "ambiguous query", + Method: code.Method{ + Name: "UpdateAgeByIDAndUsernameOrGender", + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidQueryError, + }, + { + Name: "no context parameter", + Method: code.Method{ + Name: "UpdateAgeByGender", + Params: []code.Param{ + {Type: code.SimpleType("int")}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.ContextParamRequiredError, + }, + { + Name: "struct field not found in update fields", + Method: code.Method{ + Name: "UpdateCountryByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.StructFieldNotFoundError, + }, + { + Name: "struct field does not match parameter in update fields", + Method: code.Method{ + Name: "UpdateAgeByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("float64")}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidParamError, + }, + { + Name: "struct field does not match parameter in query", + Method: code.Method{ + Name: "UpdateAgeByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidParamError, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(structModel, testCase.Method) + + if err != testCase.ExpectedError { + t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + } + }) + } +} + func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) { testTable := []ParseInterfaceMethodInvalidTestCase{ {