diff --git a/codecov.yml b/codecov.yml index 83c628e..b93815e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,7 +2,7 @@ coverage: status: project: default: - target: 70% + target: 75% threshold: 5% patch: default: diff --git a/internal/generator/generator.go b/internal/generator/generator.go new file mode 100644 index 0000000..42cc54d --- /dev/null +++ b/internal/generator/generator.go @@ -0,0 +1,75 @@ +package generator + +import ( + "bytes" + "html/template" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" + "golang.org/x/tools/imports" +) + +// GenerateRepository generates repository implementation from repository interface specification +func GenerateRepository(packageName string, structModel code.Struct, interfaceName string, + methodSpecs []spec.MethodSpec) (string, error) { + + repositoryGenerator := repositoryGenerator{ + PackageName: packageName, + StructModel: structModel, + InterfaceName: interfaceName, + MethodSpecs: methodSpecs, + Generator: mongo.NewGenerator(structModel, interfaceName), + } + + return repositoryGenerator.Generate() +} + +type repositoryGenerator struct { + PackageName string + StructModel code.Struct + InterfaceName string + MethodSpecs []spec.MethodSpec + Generator mongo.RepositoryGenerator +} + +func (g repositoryGenerator) Generate() (string, error) { + buffer := new(bytes.Buffer) + if err := g.generateBase(buffer); err != nil { + return "", err + } + + if err := g.Generator.GenerateConstructor(buffer); err != nil { + return "", err + } + + for _, method := range g.MethodSpecs { + if err := g.Generator.GenerateMethod(method, buffer); err != nil { + return "", err + } + } + + formattedCode, err := imports.Process("", buffer.Bytes(), nil) + if err != nil { + return "", err + } + + return string(formattedCode), nil +} + +func (g repositoryGenerator) generateBase(buffer *bytes.Buffer) error { + tmpl, err := template.New("file_base").Parse(baseTemplate) + if err != nil { + return err + } + + tmplData := baseTemplateData{ + PackageName: g.PackageName, + } + + if err := tmpl.Execute(buffer, tmplData); err != nil { + return err + } + + return nil +} diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go new file mode 100644 index 0000000..00be3d7 --- /dev/null +++ b/internal/generator/generator_test.go @@ -0,0 +1,278 @@ +package generator_test + +import ( + "strings" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/generator" + "github.com/sunboyy/repogen/internal/spec" +) + +func TestGenerateMongoRepository(t *testing.T) { + userModel := code.Struct{ + Name: "UserModel", + Fields: code.StructFields{ + { + Name: "ID", + Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, + Tags: map[string][]string{"bson": {"_id", "omitempty"}}, + }, + { + Name: "Username", + Type: code.SimpleType("string"), + Tags: map[string][]string{"bson": {"username"}}, + }, + { + Name: "Gender", + Type: code.SimpleType("Gender"), + Tags: map[string][]string{"bson": {"gender"}}, + }, + { + Name: "Age", + Type: code.SimpleType("int"), + Tags: map[string][]string{"bson": {"age"}}, + }, + }, + } + methods := []spec.MethodSpec{ + // test find: One mode + { + Name: "FindByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")}, + Operation: spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual}, + }, + }, + }, + }, + // test find: Many mode, And operator, NOT and LessThan comparator + { + Name: "FindByGenderNotAndAgeLessThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorNot}, + {Field: "Age", Comparator: spec.ComparatorLessThan}, + }, + }, + }, + }, + { + Name: "FindByAgeLessThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorLessThanEqual}, + }, + }, + }, + }, + { + Name: "FindByAgeGreaterThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThan}, + }, + }, + }, + }, + { + Name: "FindByAgeGreaterThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual}, + }, + }, + }, + }, + { + Name: "FindByGenderOrAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorEqual}, + {Field: "Age", Comparator: spec.ComparatorEqual}, + }, + }, + }, + }, + } + + code, err := generator.GenerateRepository("user", userModel, "UserRepository", methods) + + if err != nil { + t.Error(err) + } + expectedCode := `// Code generated by repogen. DO NOT EDIT. +package user + +import ( + "context" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" +) + +func NewUserRepository(collection *mongo.Collection) UserRepository { + return &UserRepositoryMongo{ + collection: collection, + } +} + +type UserRepositoryMongo struct { + collection *mongo.Collection +} + +func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) { + var entity UserModel + if err := r.collection.FindOne(ctx, bson.M{ + "_id": arg0, + }).Decode(&entity); err != nil { + return nil, err + } + return &entity, nil +} + +func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(ctx context.Context, arg0 Gender, arg1 int) (*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "gender": bson.M{"$ne": arg0}, + "age": bson.M{"$lt": arg1}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} + +func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "age": bson.M{"$lte": arg0}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} + +func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "age": bson.M{"$gt": arg0}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} + +func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "age": bson.M{"$gte": arg0}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} + +func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "$or": []bson.M{ + {"gender": arg0}, + {"age": arg1}, + }, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} +` + expectedCodeLines := strings.Split(expectedCode, "\n") + actualCodeLines := strings.Split(code, "\n") + + for i, line := range expectedCodeLines { + if line != actualCodeLines[i] { + t.Errorf("On line %d\nExpected = %v\nActual = %v", i, line, actualCodeLines[i]) + } + } +} diff --git a/internal/generator/templates.go b/internal/generator/templates.go new file mode 100644 index 0000000..3fb9b3d --- /dev/null +++ b/internal/generator/templates.go @@ -0,0 +1,9 @@ +package generator + +const baseTemplate = `// Code generated by repogen. DO NOT EDIT. +package {{.PackageName}} +` + +type baseTemplateData struct { + PackageName string +} diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index c127871..31ea396 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -4,76 +4,37 @@ import ( "bytes" "errors" "fmt" + "io" "text/template" "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/spec" - "golang.org/x/tools/imports" ) -// GenerateMongoRepository generates mongodb repository -func GenerateMongoRepository(packageName string, structModel code.Struct, intf code.Interface) (string, error) { - var methodSpecs []spec.MethodSpec - for _, method := range intf.Methods { - methodSpec, err := spec.ParseInterfaceMethod(structModel, method) - if err != nil { - return "", err - } - methodSpecs = append(methodSpecs, methodSpec) - } - - generator := mongoRepositoryGenerator{ - PackageName: packageName, +// NewGenerator creates a new instance of MongoDB repository generator +func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGenerator { + return RepositoryGenerator{ StructModel: structModel, - InterfaceName: intf.Name, - MethodSpecs: methodSpecs, + InterfaceName: interfaceName, } - - output, err := generator.Generate() - if err != nil { - return "", err - } - - return output, nil } -type mongoRepositoryGenerator struct { - PackageName string +// RepositoryGenerator provides repository constructor and method generation from provided specification +type RepositoryGenerator struct { StructModel code.Struct InterfaceName string - MethodSpecs []spec.MethodSpec } -func (g mongoRepositoryGenerator) Generate() (string, error) { - buffer := new(bytes.Buffer) - if err := g.generateBaseContent(buffer); err != nil { - return "", err - } - - for _, method := range g.MethodSpecs { - if err := g.generateMethod(buffer, method); err != nil { - return "", err - } - } - - newOutput, err := imports.Process("", buffer.Bytes(), nil) - if err != nil { - return "", err - } - - return string(newOutput), nil -} - -func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) error { - tmpl, err := template.New("mongo_repository_base").Parse(baseTemplate) +// GenerateConstructor generates mongo repository struct implementation and constructor for the struct +func (g RepositoryGenerator) GenerateConstructor(buffer io.Writer) error { + tmpl, err := template.New("mongo_repository_base").Parse(constructorTemplate) if err != nil { return err } - tmplData := mongoBaseTemplateData{ - PackageName: g.PackageName, - InterfaceName: g.InterfaceName, - StructName: g.structName(), + tmplData := mongoConstructorTemplateData{ + InterfaceName: g.InterfaceName, + ImplStructName: g.structName(), } if err := tmpl.Execute(buffer, tmplData); err != nil { @@ -83,27 +44,28 @@ func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) erro return nil } -func (g mongoRepositoryGenerator) generateMethod(buffer *bytes.Buffer, method spec.MethodSpec) error { +// GenerateMethod generates implementation of from provided method specification +func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer io.Writer) error { tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate) if err != nil { return err } - implementation, err := g.generateMethodImplementation(method) + implementation, err := g.generateMethodImplementation(methodSpec) if err != nil { return err } var paramTypes []code.Type - for _, param := range method.Params[1:] { + for _, param := range methodSpec.Params[1:] { paramTypes = append(paramTypes, param.Type) } tmplData := mongoMethodTemplateData{ StructName: g.structName(), - MethodName: method.Name, + MethodName: methodSpec.Name, ParamTypes: paramTypes, - ReturnTypes: method.Returns, + ReturnTypes: methodSpec.Returns, Implementation: implementation, } @@ -114,7 +76,7 @@ func (g mongoRepositoryGenerator) generateMethod(buffer *bytes.Buffer, method sp return nil } -func (g mongoRepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) { +func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) { switch operation := methodSpec.Operation.(type) { case spec.FindOperation: return g.generateFindImplementation(operation) @@ -123,7 +85,7 @@ func (g mongoRepositoryGenerator) generateMethodImplementation(methodSpec spec.M return "", errors.New("method spec not supported") } -func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) { +func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) { buffer := new(bytes.Buffer) var predicates []predicate @@ -172,6 +134,6 @@ func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.Find return buffer.String(), nil } -func (g mongoRepositoryGenerator) structName() string { +func (g RepositoryGenerator) structName() string { return g.InterfaceName + "Mongo" } diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index 51ac3a9..a581d7e 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -1,14 +1,36 @@ package mongo_test import ( - "strings" + "bytes" "testing" "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" ) -func TestGenerateMongoRepository(t *testing.T) { +const expectedConstructorResult = ` +import ( + "context" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" +) + +func NewUserRepository(collection *mongo.Collection) UserRepository { + return &UserRepositoryMongo{ + collection: collection, + } +} + +type UserRepositoryMongo struct { + collection *mongo.Collection +} +` + +func TestGenerateConstructor(t *testing.T) { userModel := code.Struct{ Name: "UserModel", Fields: code.StructFields{ @@ -34,158 +56,46 @@ func TestGenerateMongoRepository(t *testing.T) { }, }, } - intf := code.Interface{ - Name: "UserRepository", - Methods: []code.Method{ - { + generator := mongo.NewGenerator(userModel, "UserRepository") + buffer := new(bytes.Buffer) + + err := generator.GenerateConstructor(buffer) + + if err != nil { + t.Error(err) + } + if err := testutils.ExpectMultiLineString(expectedConstructorResult, buffer.String()); err != nil { + t.Error(err) + } +} + +type GenerateMethodTestCase struct { + Name string + MethodSpec spec.MethodSpec + ExpectedCode string +} + +func TestGenerateMethod(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "simple find one method", + MethodSpec: spec.MethodSpec{ Name: "FindByID", Params: []code.Param{ {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, }, Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")}, - }, - { - Name: "FindOneByUsername", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "username", Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.SimpleType("error"), + Operation: spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorEqual, Field: "ID"}, + }, + }, }, }, - { - Name: "FindByUsername", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "username", Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByIDAndUsername", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - {Name: "username", Type: code.SimpleType("string")}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByGenderNot", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByAgeLessThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByAgeLessThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByAgeGreaterThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByAgeGreaterThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByGenderOrAgeLessThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.SimpleType("int")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - { - Name: "FindByGenderIn", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.SimpleType("error"), - }, - }, - }, - } - - code, err := mongo.GenerateMongoRepository("user", userModel, intf) - - if err != nil { - t.Error(err) - } - expectedCode := `// Code generated by repogen. DO NOT EDIT. -package user - -import ( - "context" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" -) - -func NewUserRepository(collection *mongo.Collection) UserRepository { - return &UserRepositoryMongo{ - collection: collection, - } -} - -type UserRepositoryMongo struct { - collection *mongo.Collection -} - + ExpectedCode: ` func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) { var entity UserModel if err := r.collection.FindOne(ctx, bson.M{ @@ -195,20 +105,33 @@ func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.Objec } return &entity, nil } - -func (r *UserRepositoryMongo) FindOneByUsername(ctx context.Context, arg0 string) (*UserModel, error) { - var entity UserModel - if err := r.collection.FindOne(ctx, bson.M{ - "username": arg0, - }).Decode(&entity); err != nil { - return nil, err - } - return &entity, nil -} - -func (r *UserRepositoryMongo) FindByUsername(ctx context.Context, arg0 string) ([]*UserModel, error) { +`, + }, + { + Name: "simple find many method", + MethodSpec: spec.MethodSpec{ + Name: "FindByGender", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorEqual, Field: "Gender"}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindByGender(ctx context.Context, arg0 Gender) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ - "username": arg0, + "gender": arg0, }) if err != nil { return nil, err @@ -219,18 +142,114 @@ func (r *UserRepositoryMongo) FindByUsername(ctx context.Context, arg0 string) ( } return entities, nil } - -func (r *UserRepositoryMongo) FindByIDAndUsername(ctx context.Context, arg0 primitive.ObjectID, arg1 string) (*UserModel, error) { - var entity UserModel - if err := r.collection.FindOne(ctx, bson.M{ - "_id": arg0, - "username": arg1, - }).Decode(&entity); err != nil { +`, + }, + { + Name: "find with And operator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderAndAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorEqual, Field: "Gender"}, + {Comparator: spec.ComparatorEqual, Field: "Age"}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindByGenderAndAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "gender": arg0, + "age": arg1, + }) + if err != nil { return nil, err } - return &entity, nil + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil } - +`, + }, + { + Name: "find with Or operator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderOrAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorEqual, Field: "Gender"}, + {Comparator: spec.ComparatorEqual, Field: "Age"}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "$or": []bson.M{ + {"gender": arg0}, + {"age": arg1}, + }, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} +`, + }, + { + Name: "find with Not comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderNot", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorNot, Field: "Gender"}, + }, + }, + }, + }, + ExpectedCode: ` func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 Gender) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ "gender": bson.M{"$ne": arg0}, @@ -244,7 +263,30 @@ func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 Gender) } return entities, nil } - +`, + }, + { + Name: "find with LessThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeLessThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorLessThan, Field: "Age"}, + }, + }, + }, + }, + ExpectedCode: ` func (r *UserRepositoryMongo) FindByAgeLessThan(ctx context.Context, arg0 int) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ "age": bson.M{"$lt": arg0}, @@ -258,7 +300,30 @@ func (r *UserRepositoryMongo) FindByAgeLessThan(ctx context.Context, arg0 int) ( } return entities, nil } - +`, + }, + { + Name: "find with LessThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeLessThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorLessThanEqual, Field: "Age"}, + }, + }, + }, + }, + ExpectedCode: ` func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ "age": bson.M{"$lte": arg0}, @@ -272,7 +337,30 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 i } return entities, nil } - +`, + }, + { + Name: "find with GreaterThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeGreaterThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorGreaterThan, Field: "Age"}, + }, + }, + }, + }, + ExpectedCode: ` func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ "age": bson.M{"$gt": arg0}, @@ -286,7 +374,30 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int } return entities, nil } - +`, + }, + { + Name: "find with GreaterThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeGreaterThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorGreaterThanEqual, Field: "Age"}, + }, + }, + }, + }, + ExpectedCode: ` func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ "age": bson.M{"$gte": arg0}, @@ -300,24 +411,30 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg } return entities, nil } - -func (r *UserRepositoryMongo) FindByGenderOrAgeLessThan(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) { - cursor, err := r.collection.Find(ctx, bson.M{ - "$or": []bson.M{ - {"gender": arg0}, - {"age": bson.M{"$lt": arg1}}, +`, }, - }) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(ctx, &entities); err != nil { - return nil, err - } - return entities, nil -} - + { + Name: "find with In comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderIn", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorIn, Field: "Gender"}, + }, + }, + }, + }, + ExpectedCode: ` func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) ([]*UserModel, error) { cursor, err := r.collection.Find(ctx, bson.M{ "gender": bson.M{"$in": arg0}, @@ -331,13 +448,47 @@ func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) } return entities, nil } -` - expectedCodeLines := strings.Split(expectedCode, "\n") - actualCodeLines := strings.Split(code, "\n") +`, + }, + } + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + userModel := code.Struct{ + Name: "UserModel", + Fields: code.StructFields{ + { + Name: "ID", + Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, + Tags: map[string][]string{"bson": {"_id", "omitempty"}}, + }, + { + Name: "Username", + Type: code.SimpleType("string"), + Tags: map[string][]string{"bson": {"username"}}, + }, + { + Name: "Gender", + Type: code.SimpleType("Gender"), + Tags: map[string][]string{"bson": {"gender"}}, + }, + { + Name: "Age", + Type: code.SimpleType("int"), + Tags: map[string][]string{"bson": {"age"}}, + }, + }, + } + generator := mongo.NewGenerator(userModel, "UserRepository") + buffer := new(bytes.Buffer) - for i, line := range expectedCodeLines { - if line != actualCodeLines[i] { - t.Errorf("On line %d\nExpected = %v\nActual = %v", i, line, actualCodeLines[i]) - } + err := generator.GenerateMethod(testCase.MethodSpec, buffer) + + if err != nil { + t.Error(err) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil { + t.Error(err) + } + }) } } diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go index 4f73e02..3e87a12 100644 --- a/internal/mongo/templates.go +++ b/internal/mongo/templates.go @@ -7,9 +7,7 @@ import ( "github.com/sunboyy/repogen/internal/code" ) -const baseTemplate = `// Code generated by repogen. DO NOT EDIT. -package {{.PackageName}} - +const constructorTemplate = ` import ( "context" @@ -19,20 +17,19 @@ import ( ) func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} { - return &{{.StructName}}{ + return &{{.ImplStructName}}{ collection: collection, } } -type {{.StructName}} struct { +type {{.ImplStructName}} struct { collection *mongo.Collection } ` -type mongoBaseTemplateData struct { - PackageName string - InterfaceName string - StructName string +type mongoConstructorTemplateData struct { + InterfaceName string + ImplStructName string } const methodTemplate = ` diff --git a/internal/testutils/multilines.go b/internal/testutils/multilines.go new file mode 100644 index 0000000..d435a2b --- /dev/null +++ b/internal/testutils/multilines.go @@ -0,0 +1,31 @@ +package testutils + +import ( + "fmt" + "strings" +) + +// ExpectMultiLineString compares two multi-line strings and report the difference +func ExpectMultiLineString(expected, actual string) error { + expectedLines := strings.Split(expected, "\n") + actualLines := strings.Split(actual, "\n") + + numberOfComparableLines := len(expectedLines) + if len(actualLines) < numberOfComparableLines { + numberOfComparableLines = len(actualLines) + } + + for i := 0; i < numberOfComparableLines; i++ { + if expectedLines[i] != actualLines[i] { + return fmt.Errorf("On line %d\nExpected: %v\nReceived: %v", i+1, expectedLines[i], actualLines[i]) + } + } + + if len(expectedLines) < len(actualLines) { + return fmt.Errorf("Unexpected lines:\n%s", strings.Join(actualLines[len(expectedLines):], "\n")) + } else if len(expectedLines) > len(actualLines) { + return fmt.Errorf("Missing lines:\n%s", strings.Join(expectedLines[len(actualLines):], "\n")) + } + + return nil +} diff --git a/internal/testutils/multilines_test.go b/internal/testutils/multilines_test.go new file mode 100644 index 0000000..5a6a26c --- /dev/null +++ b/internal/testutils/multilines_test.go @@ -0,0 +1,72 @@ +package testutils_test + +import ( + "testing" + + "github.com/sunboyy/repogen/internal/testutils" +) + +func TestExpectMultiLineString(t *testing.T) { + t.Run("same string should return nil", func(t *testing.T) { + text := ` Hello world + this is a test text ` + + err := testutils.ExpectMultiLineString(text, text) + + if err != nil { + t.Errorf("Expected = \nReceived = %s", err.Error()) + } + }) + + t.Run("different string with same number of lines", func(t *testing.T) { + expectedText := ` Hello world +this is an expected text +how are you?` + actualText := ` Hello world +this is a real text +How are you?` + + err := testutils.ExpectMultiLineString(expectedText, actualText) + + expectedError := "On line 2\nExpected: this is an expected text\nReceived: this is a real text" + if err == nil || err.Error() != expectedError { + t.Errorf("Expected = %s\nReceived = %s", expectedError, err.Error()) + } + }) + + t.Run("expected text longer than actual text", func(t *testing.T) { + expectedText := ` Hello world +this is an expected text +how are you? +I'm fine... +Thank you...` + actualText := ` Hello world +this is an expected text +how are you?` + + err := testutils.ExpectMultiLineString(expectedText, actualText) + + expectedError := "Missing lines:\nI'm fine...\nThank you..." + if err == nil || err.Error() != expectedError { + t.Errorf("Expected = %s\nReceived = %s", expectedError, err.Error()) + } + }) + + t.Run("actual text longer than expected text", func(t *testing.T) { + expectedText := ` Hello world +this is an expected text +how are you?` + actualText := ` Hello world +this is an expected text +how are you? +I'm fine... +Thank you...` + + err := testutils.ExpectMultiLineString(expectedText, actualText) + + expectedError := "Unexpected lines:\nI'm fine...\nThank you..." + if err == nil || err.Error() != expectedError { + t.Errorf("Expected = %s\nReceived = %s", expectedError, err.Error()) + } + }) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..8c187cf --- /dev/null +++ b/main.go @@ -0,0 +1,86 @@ +package main + +import ( + "errors" + "flag" + "go/parser" + "go/token" + "os" + "path/filepath" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/generator" + "github.com/sunboyy/repogen/internal/spec" +) + +func main() { + sourcePtr := flag.String("src", "", "source file") + destPtr := flag.String("dest", "", "destination file") + modelPtr := flag.String("model", "", "model struct name") + repoPtr := flag.String("repo", "", "repository interface name") + + flag.Parse() + + if *sourcePtr == "" { + panic("-source flag required") + } + if *modelPtr == "" { + panic("-model flag required") + } + if *repoPtr == "" { + panic("-repo flag required") + } + + dest := os.Stdout + if *destPtr != "" { + if err := os.MkdirAll(filepath.Dir(*destPtr), os.ModePerm); err != nil { + panic(err) + } + file, err := os.Create(*destPtr) + if err != nil { + panic(err) + } + defer file.Close() + dest = file + } + + code, err := generateFromRequest(*sourcePtr, *modelPtr, *repoPtr) + if err != nil { + panic(err) + } + + if _, err := dest.WriteString(code); err != nil { + panic(err) + } +} + +func generateFromRequest(fileName, structModelName, repositoryInterfaceName string) (string, error) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, fileName, nil, parser.ParseComments) + if err != nil { + panic(err) + } + + file := code.ExtractComponents(f) + + structModel, ok := file.Structs.ByName(structModelName) + if !ok { + return "", errors.New("struct model not found") + } + + intf, ok := file.Interfaces.ByName(repositoryInterfaceName) + if !ok { + return "", errors.New("interface model not found") + } + + var methodSpecs []spec.MethodSpec + for _, method := range intf.Methods { + methodSpec, err := spec.ParseInterfaceMethod(structModel, method) + if err != nil { + return "", err + } + methodSpecs = append(methodSpecs, methodSpec) + } + + return generator.GenerateRepository(file.PackageName, structModel, intf.Name, methodSpecs) +}