diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f47cb20 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.out diff --git a/README.md b/README.md index 6b17c55..71b2d43 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # repogen -![build status badge](https://github.com/sunboyy/repogen/workflows/build/badge.svg) + + build status badge + Repogen is a code generator for database repository in Golang. (WIP) diff --git a/go.mod b/go.mod index 2548e97..261b72d 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/sunboyy/repogen go 1.15 -require github.com/fatih/camelcase v1.0.0 +require ( + github.com/fatih/camelcase v1.0.0 + golang.org/x/tools v0.0.0-20210115202250-e0d201561e39 +) diff --git a/go.sum b/go.sum index 315a92c..88b1cc0 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,26 @@ github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8= github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20210115202250-e0d201561e39 h1:BTs2GMGSMWpgtCpv1CE7vkJTv7XcHdcLLnAMu7UbgTY= +golang.org/x/tools v0.0.0-20210115202250-e0d201561e39/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go new file mode 100644 index 0000000..b3e43ea --- /dev/null +++ b/internal/mongo/generator.go @@ -0,0 +1,168 @@ +package mongo + +import ( + "bytes" + "errors" + "fmt" + "text/template" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/spec" + "golang.org/x/tools/imports" +) + +// GenerateMongoRepository generates mongodb repository +func GenerateMongoRepository(packageName string, structModel code.Struct, intf code.Interface) (string, error) { + repositorySpec, err := spec.ParseRepositoryInterface(structModel, intf) + if err != nil { + return "", err + } + + generator := mongoRepositoryGenerator{ + PackageName: packageName, + StructModel: structModel, + RepositorySpec: repositorySpec, + } + + output, err := generator.Generate() + if err != nil { + return "", err + } + + return output, nil +} + +type mongoRepositoryGenerator struct { + PackageName string + StructModel code.Struct + RepositorySpec spec.RepositorySpec +} + +func (g mongoRepositoryGenerator) Generate() (string, error) { + buffer := new(bytes.Buffer) + if err := g.generateBaseContent(buffer); err != nil { + return "", err + } + + for _, method := range g.RepositorySpec.Methods { + if err := g.generateMethod(buffer, method); err != nil { + return "", err + } + } + + newOutput, err := imports.Process("", buffer.Bytes(), nil) + if err != nil { + return "", err + } + + return string(newOutput), nil +} + +func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) error { + tmpl, err := template.New("mongo_repository_base").Parse(baseTemplate) + if err != nil { + return err + } + + tmplData := mongoBaseTemplateData{ + PackageName: g.PackageName, + InterfaceName: g.RepositorySpec.InterfaceName, + StructName: g.structName(), + } + + if err := tmpl.Execute(buffer, tmplData); err != nil { + return err + } + + return nil +} + +func (g mongoRepositoryGenerator) generateMethod(buffer *bytes.Buffer, method spec.MethodSpec) error { + tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate) + if err != nil { + return err + } + + implementation, err := g.generateMethodImplementation(method) + if err != nil { + return err + } + + var paramTypes []code.Type + for _, param := range method.Params[1:] { + paramTypes = append(paramTypes, param.Type) + } + + tmplData := mongoMethodTemplateData{ + StructName: g.structName(), + MethodName: method.Name, + ParamTypes: paramTypes, + ReturnTypes: method.Returns, + Implementation: implementation, + } + + if err := tmpl.Execute(buffer, tmplData); err != nil { + return err + } + + return nil +} + +func (g mongoRepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) { + switch operation := methodSpec.Operation.(type) { + case spec.FindOperation: + return g.generateFindImplementation(operation) + } + + return "", errors.New("method spec not supported") +} + +func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) { + buffer := new(bytes.Buffer) + + var queryFields []string + for _, fieldName := range operation.Query.Fields { + structField, ok := g.StructModel.Fields.ByName(fieldName) + if !ok { + return "", fmt.Errorf("struct field %s not found", fieldName) + } + + bsonTag, ok := structField.Tags["bson"] + if !ok { + return "", fmt.Errorf("struct field %s does not have bson tag", fieldName) + } + + queryFields = append(queryFields, bsonTag[0]) + } + + tmplData := mongoFindTemplateData{ + EntityType: g.StructModel.Name, + QueryFields: queryFields, + } + + if operation.Mode == spec.QueryModeOne { + tmpl, err := template.New("mongo_repository_findone").Parse(findOneTemplate) + if err != nil { + return "", err + } + + if err := tmpl.Execute(buffer, tmplData); err != nil { + return "", err + } + } else { + tmpl, err := template.New("mongo_repository_findmany").Parse(findManyTemplate) + if err != nil { + return "", err + } + + if err := tmpl.Execute(buffer, tmplData); err != nil { + return "", err + } + } + + return buffer.String(), nil +} + +func (g mongoRepositoryGenerator) structName() string { + return g.RepositorySpec.InterfaceName + "Mongo" +} diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go new file mode 100644 index 0000000..0a77702 --- /dev/null +++ b/internal/mongo/generator_test.go @@ -0,0 +1,154 @@ +package mongo_test + +import ( + "strings" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/mongo" +) + +func TestGenerateMongoRepository(t *testing.T) { + userModel := code.Struct{ + Name: "UserModel", + Fields: code.StructFields{ + { + Name: "ID", + Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, + Tags: map[string][]string{"bson": {"_id", "omitempty"}}, + }, + { + Name: "Username", + Type: code.SimpleType("string"), + Tags: map[string][]string{"bson": {"username"}}, + }, + }, + } + intf := code.Interface{ + Name: "UserRepository", + Methods: []code.Method{ + { + Name: "FindByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")}, + }, + { + Name: "FindOneByUsername", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "username", Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.SimpleType("error"), + }, + }, + { + Name: "FindByUsername", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "username", Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + { + Name: "FindByIDAndUsername", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + {Name: "username", Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.SimpleType("error"), + }, + }, + }, + } + + code, err := mongo.GenerateMongoRepository("user", userModel, intf) + + if err != nil { + t.Error(err) + } + expectedCode := `// Code generated by repogen. DO NOT EDIT. +package user + +import ( + "context" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" +) + +func NewUserRepository(collection *mongo.Collection) UserRepository { + return &UserRepositoryMongo{ + collection: collection, + } +} + +type UserRepositoryMongo struct { + collection *mongo.Collection +} + +func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) { + var entity UserModel + if err := r.collection.FindOne(ctx, bson.M{ + "_id": arg0, + }).Decode(&entity); err != nil { + return nil, err + } + return &entity, nil +} + +func (r *UserRepositoryMongo) FindOneByUsername(ctx context.Context, arg0 string) (*UserModel, error) { + var entity UserModel + if err := r.collection.FindOne(ctx, bson.M{ + "username": arg0, + }).Decode(&entity); err != nil { + return nil, err + } + return &entity, nil +} + +func (r *UserRepositoryMongo) FindByUsername(ctx context.Context, arg0 string) ([]*UserModel, error) { + cursor, err := r.collection.Find(ctx, bson.M{ + "username": arg0, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil +} + +func (r *UserRepositoryMongo) FindByIDAndUsername(ctx context.Context, arg0 primitive.ObjectID, arg1 string) (*UserModel, error) { + var entity UserModel + if err := r.collection.FindOne(ctx, bson.M{ + "_id": arg0, + "username": arg1, + }).Decode(&entity); err != nil { + return nil, err + } + return &entity, nil +} +` + expectedCodeLines := strings.Split(expectedCode, "\n") + actualCodeLines := strings.Split(code, "\n") + + for i, line := range expectedCodeLines { + if line != actualCodeLines[i] { + t.Errorf("On line %d\nExpected = %v\nActual = %v", i, line, actualCodeLines[i]) + } + } +} diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go new file mode 100644 index 0000000..41b10cf --- /dev/null +++ b/internal/mongo/templates.go @@ -0,0 +1,99 @@ +package mongo + +import ( + "fmt" + "strings" + + "github.com/sunboyy/repogen/internal/code" +) + +const baseTemplate = `// Code generated by repogen. DO NOT EDIT. +package {{.PackageName}} + +import ( + "context" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" +) + +func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} { + return &{{.StructName}}{ + collection: collection, + } +} + +type {{.StructName}} struct { + collection *mongo.Collection +} +` + +type mongoBaseTemplateData struct { + PackageName string + InterfaceName string + StructName string +} + +const methodTemplate = ` +func (r *{{.StructName}}) {{.MethodName}}(ctx context.Context, {{.Parameters}}){{.Returns}} { +{{.Implementation}} +} +` + +type mongoMethodTemplateData struct { + StructName string + MethodName string + ParamTypes []code.Type + ReturnTypes []code.Type + Implementation string +} + +func (data mongoMethodTemplateData) Parameters() string { + var params []string + for i, paramType := range data.ParamTypes { + params = append(params, fmt.Sprintf("arg%d %s", i, paramType.Code())) + } + return strings.Join(params, ", ") +} + +func (data mongoMethodTemplateData) Returns() string { + if len(data.ReturnTypes) == 0 { + return "" + } + + if len(data.ReturnTypes) == 1 { + return fmt.Sprintf(" %s", data.ReturnTypes[0].Code()) + } + + var returns []string + for _, returnType := range data.ReturnTypes { + returns = append(returns, returnType.Code()) + } + return fmt.Sprintf(" (%s)", strings.Join(returns, ", ")) +} + +const findOneTemplate = ` var entity {{.EntityType}} + if err := r.collection.FindOne(ctx, bson.M{ +{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}}, +{{end}} }).Decode(&entity); err != nil { + return nil, err + } + return &entity, nil` + +type mongoFindTemplateData struct { + EntityType string + QueryFields []string +} + +const findManyTemplate = ` cursor, err := r.collection.Find(ctx, bson.M{ +{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}}, +{{end}} }) + if err != nil { + return nil, err + } + var entities []*{{.EntityType}} + if err := cursor.All(ctx, &entities); err != nil { + return nil, err + } + return entities, nil` diff --git a/internal/spec/models.go b/internal/spec/models.go index 11e9cb9..171e3bd 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -32,15 +32,15 @@ type Operation interface { // FindOperation is a method specification for find operations type FindOperation struct { Mode QueryMode - Query Query + Query QuerySpec } -// Query is a condition of querying the database -type Query struct { +// QuerySpec is a condition of querying the database +type QuerySpec struct { Fields []string } // NumberOfArguments returns number of arguments required to perform the query -func (q Query) NumberOfArguments() int { +func (q QuerySpec) NumberOfArguments() int { return len(q.Fields) } diff --git a/internal/spec/parser.go b/internal/spec/parser.go index 3369195..d08642c 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -58,12 +58,12 @@ func (p repositoryInterfaceParser) parseFindMethod(method code.Method, tokens [] return MethodSpec{}, err } - query, err := p.parseQuery(tokens) + querySpec, err := p.parseQuery(tokens) if err != nil { return MethodSpec{}, err } - if query.NumberOfArguments()+1 != len(method.Params) { + if querySpec.NumberOfArguments()+1 != len(method.Params) { return MethodSpec{}, errors.New("method parameter not supported") } @@ -73,7 +73,7 @@ func (p repositoryInterfaceParser) parseFindMethod(method code.Method, tokens [] Returns: method.Returns, Operation: FindOperation{ Mode: mode, - Query: query, + Query: querySpec, }, }, nil } @@ -111,13 +111,13 @@ func (p repositoryInterfaceParser) extractFindReturns(returns []code.Type) (Quer return "", errors.New("method return not supported") } -func (p repositoryInterfaceParser) parseQuery(tokens []string) (Query, error) { +func (p repositoryInterfaceParser) parseQuery(tokens []string) (QuerySpec, error) { if len(tokens) == 0 { - return Query{}, errors.New("method name not supported") + return QuerySpec{}, errors.New("method name not supported") } if len(tokens) == 1 && tokens[0] == "All" { - return Query{}, nil + return QuerySpec{}, nil } if tokens[0] == "One" { @@ -128,7 +128,7 @@ func (p repositoryInterfaceParser) parseQuery(tokens []string) (Query, error) { } if tokens[0] == "And" { - return Query{}, errors.New("method name not supported") + return QuerySpec{}, errors.New("method name not supported") } var queryFields []string var aggregatedToken string @@ -141,9 +141,9 @@ func (p repositoryInterfaceParser) parseQuery(tokens []string) (Query, error) { } } if aggregatedToken == "" { - return Query{}, errors.New("method name not supported") + return QuerySpec{}, errors.New("method name not supported") } queryFields = append(queryFields, aggregatedToken) - return Query{Fields: queryFields}, nil + return QuerySpec{Fields: queryFields}, nil } diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 56160c3..d5a5f9a 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -58,7 +58,7 @@ func TestParseRepositoryInterface(t *testing.T) { }, Operation: spec.FindOperation{ Mode: spec.QueryModeOne, - Query: spec.Query{Fields: []string{"ID"}}, + Query: spec.QuerySpec{Fields: []string{"ID"}}, }, }, }, @@ -98,7 +98,7 @@ func TestParseRepositoryInterface(t *testing.T) { Operation: spec.FindOperation{ Mode: spec.QueryModeOne, - Query: spec.Query{Fields: []string{"PhoneNumber"}}, + Query: spec.QuerySpec{Fields: []string{"PhoneNumber"}}, }, }, }, @@ -137,7 +137,7 @@ func TestParseRepositoryInterface(t *testing.T) { }, Operation: spec.FindOperation{ Mode: spec.QueryModeMany, - Query: spec.Query{Fields: []string{"City"}}, + Query: spec.QuerySpec{Fields: []string{"City"}}, }, }, }, @@ -214,7 +214,7 @@ func TestParseRepositoryInterface(t *testing.T) { }, Operation: spec.FindOperation{ Mode: spec.QueryModeMany, - Query: spec.Query{Fields: []string{"City", "Gender"}}, + Query: spec.QuerySpec{Fields: []string{"City", "Gender"}}, }, }, },