Generate Mongo repository code from repository spec

This commit is contained in:
sunboyy 2021-01-17 10:29:50 +07:00
parent 3808684ed0
commit a24a9f81d7
10 changed files with 470 additions and 19 deletions

168
internal/mongo/generator.go Normal file
View file

@ -0,0 +1,168 @@
package mongo
import (
"bytes"
"errors"
"fmt"
"text/template"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/spec"
"golang.org/x/tools/imports"
)
// GenerateMongoRepository generates mongodb repository
func GenerateMongoRepository(packageName string, structModel code.Struct, intf code.Interface) (string, error) {
repositorySpec, err := spec.ParseRepositoryInterface(structModel, intf)
if err != nil {
return "", err
}
generator := mongoRepositoryGenerator{
PackageName: packageName,
StructModel: structModel,
RepositorySpec: repositorySpec,
}
output, err := generator.Generate()
if err != nil {
return "", err
}
return output, nil
}
type mongoRepositoryGenerator struct {
PackageName string
StructModel code.Struct
RepositorySpec spec.RepositorySpec
}
func (g mongoRepositoryGenerator) Generate() (string, error) {
buffer := new(bytes.Buffer)
if err := g.generateBaseContent(buffer); err != nil {
return "", err
}
for _, method := range g.RepositorySpec.Methods {
if err := g.generateMethod(buffer, method); err != nil {
return "", err
}
}
newOutput, err := imports.Process("", buffer.Bytes(), nil)
if err != nil {
return "", err
}
return string(newOutput), nil
}
func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) error {
tmpl, err := template.New("mongo_repository_base").Parse(baseTemplate)
if err != nil {
return err
}
tmplData := mongoBaseTemplateData{
PackageName: g.PackageName,
InterfaceName: g.RepositorySpec.InterfaceName,
StructName: g.structName(),
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
}
func (g mongoRepositoryGenerator) generateMethod(buffer *bytes.Buffer, method spec.MethodSpec) error {
tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate)
if err != nil {
return err
}
implementation, err := g.generateMethodImplementation(method)
if err != nil {
return err
}
var paramTypes []code.Type
for _, param := range method.Params[1:] {
paramTypes = append(paramTypes, param.Type)
}
tmplData := mongoMethodTemplateData{
StructName: g.structName(),
MethodName: method.Name,
ParamTypes: paramTypes,
ReturnTypes: method.Returns,
Implementation: implementation,
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
}
func (g mongoRepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
switch operation := methodSpec.Operation.(type) {
case spec.FindOperation:
return g.generateFindImplementation(operation)
}
return "", errors.New("method spec not supported")
}
func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
buffer := new(bytes.Buffer)
var queryFields []string
for _, fieldName := range operation.Query.Fields {
structField, ok := g.StructModel.Fields.ByName(fieldName)
if !ok {
return "", fmt.Errorf("struct field %s not found", fieldName)
}
bsonTag, ok := structField.Tags["bson"]
if !ok {
return "", fmt.Errorf("struct field %s does not have bson tag", fieldName)
}
queryFields = append(queryFields, bsonTag[0])
}
tmplData := mongoFindTemplateData{
EntityType: g.StructModel.Name,
QueryFields: queryFields,
}
if operation.Mode == spec.QueryModeOne {
tmpl, err := template.New("mongo_repository_findone").Parse(findOneTemplate)
if err != nil {
return "", err
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
} else {
tmpl, err := template.New("mongo_repository_findmany").Parse(findManyTemplate)
if err != nil {
return "", err
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
}
return buffer.String(), nil
}
func (g mongoRepositoryGenerator) structName() string {
return g.RepositorySpec.InterfaceName + "Mongo"
}

View file

@ -0,0 +1,154 @@
package mongo_test
import (
"strings"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/mongo"
)
func TestGenerateMongoRepository(t *testing.T) {
userModel := code.Struct{
Name: "UserModel",
Fields: code.StructFields{
{
Name: "ID",
Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"},
Tags: map[string][]string{"bson": {"_id", "omitempty"}},
},
{
Name: "Username",
Type: code.SimpleType("string"),
Tags: map[string][]string{"bson": {"username"}},
},
},
}
intf := code.Interface{
Name: "UserRepository",
Methods: []code.Method{
{
Name: "FindByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")},
},
{
Name: "FindOneByUsername",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "username", Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
},
},
{
Name: "FindByUsername",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "username", Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByIDAndUsername",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
{Name: "username", Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
},
},
},
}
code, err := mongo.GenerateMongoRepository("user", userModel, intf)
if err != nil {
t.Error(err)
}
expectedCode := `// Code generated by repogen. DO NOT EDIT.
package user
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
func NewUserRepository(collection *mongo.Collection) UserRepository {
return &UserRepositoryMongo{
collection: collection,
}
}
type UserRepositoryMongo struct {
collection *mongo.Collection
}
func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) {
var entity UserModel
if err := r.collection.FindOne(ctx, bson.M{
"_id": arg0,
}).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil
}
func (r *UserRepositoryMongo) FindOneByUsername(ctx context.Context, arg0 string) (*UserModel, error) {
var entity UserModel
if err := r.collection.FindOne(ctx, bson.M{
"username": arg0,
}).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil
}
func (r *UserRepositoryMongo) FindByUsername(ctx context.Context, arg0 string) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"username": arg0,
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByIDAndUsername(ctx context.Context, arg0 primitive.ObjectID, arg1 string) (*UserModel, error) {
var entity UserModel
if err := r.collection.FindOne(ctx, bson.M{
"_id": arg0,
"username": arg1,
}).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil
}
`
expectedCodeLines := strings.Split(expectedCode, "\n")
actualCodeLines := strings.Split(code, "\n")
for i, line := range expectedCodeLines {
if line != actualCodeLines[i] {
t.Errorf("On line %d\nExpected = %v\nActual = %v", i, line, actualCodeLines[i])
}
}
}

View file

@ -0,0 +1,99 @@
package mongo
import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
)
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
package {{.PackageName}}
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} {
return &{{.StructName}}{
collection: collection,
}
}
type {{.StructName}} struct {
collection *mongo.Collection
}
`
type mongoBaseTemplateData struct {
PackageName string
InterfaceName string
StructName string
}
const methodTemplate = `
func (r *{{.StructName}}) {{.MethodName}}(ctx context.Context, {{.Parameters}}){{.Returns}} {
{{.Implementation}}
}
`
type mongoMethodTemplateData struct {
StructName string
MethodName string
ParamTypes []code.Type
ReturnTypes []code.Type
Implementation string
}
func (data mongoMethodTemplateData) Parameters() string {
var params []string
for i, paramType := range data.ParamTypes {
params = append(params, fmt.Sprintf("arg%d %s", i, paramType.Code()))
}
return strings.Join(params, ", ")
}
func (data mongoMethodTemplateData) Returns() string {
if len(data.ReturnTypes) == 0 {
return ""
}
if len(data.ReturnTypes) == 1 {
return fmt.Sprintf(" %s", data.ReturnTypes[0].Code())
}
var returns []string
for _, returnType := range data.ReturnTypes {
returns = append(returns, returnType.Code())
}
return fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
}
const findOneTemplate = ` var entity {{.EntityType}}
if err := r.collection.FindOne(ctx, bson.M{
{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}},
{{end}} }).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil`
type mongoFindTemplateData struct {
EntityType string
QueryFields []string
}
const findManyTemplate = ` cursor, err := r.collection.Find(ctx, bson.M{
{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}},
{{end}} })
if err != nil {
return nil, err
}
var entities []*{{.EntityType}}
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil`

View file

@ -32,15 +32,15 @@ type Operation interface {
// FindOperation is a method specification for find operations
type FindOperation struct {
Mode QueryMode
Query Query
Query QuerySpec
}
// Query is a condition of querying the database
type Query struct {
// QuerySpec is a condition of querying the database
type QuerySpec struct {
Fields []string
}
// NumberOfArguments returns number of arguments required to perform the query
func (q Query) NumberOfArguments() int {
func (q QuerySpec) NumberOfArguments() int {
return len(q.Fields)
}

View file

@ -58,12 +58,12 @@ func (p repositoryInterfaceParser) parseFindMethod(method code.Method, tokens []
return MethodSpec{}, err
}
query, err := p.parseQuery(tokens)
querySpec, err := p.parseQuery(tokens)
if err != nil {
return MethodSpec{}, err
}
if query.NumberOfArguments()+1 != len(method.Params) {
if querySpec.NumberOfArguments()+1 != len(method.Params) {
return MethodSpec{}, errors.New("method parameter not supported")
}
@ -73,7 +73,7 @@ func (p repositoryInterfaceParser) parseFindMethod(method code.Method, tokens []
Returns: method.Returns,
Operation: FindOperation{
Mode: mode,
Query: query,
Query: querySpec,
},
}, nil
}
@ -111,13 +111,13 @@ func (p repositoryInterfaceParser) extractFindReturns(returns []code.Type) (Quer
return "", errors.New("method return not supported")
}
func (p repositoryInterfaceParser) parseQuery(tokens []string) (Query, error) {
func (p repositoryInterfaceParser) parseQuery(tokens []string) (QuerySpec, error) {
if len(tokens) == 0 {
return Query{}, errors.New("method name not supported")
return QuerySpec{}, errors.New("method name not supported")
}
if len(tokens) == 1 && tokens[0] == "All" {
return Query{}, nil
return QuerySpec{}, nil
}
if tokens[0] == "One" {
@ -128,7 +128,7 @@ func (p repositoryInterfaceParser) parseQuery(tokens []string) (Query, error) {
}
if tokens[0] == "And" {
return Query{}, errors.New("method name not supported")
return QuerySpec{}, errors.New("method name not supported")
}
var queryFields []string
var aggregatedToken string
@ -141,9 +141,9 @@ func (p repositoryInterfaceParser) parseQuery(tokens []string) (Query, error) {
}
}
if aggregatedToken == "" {
return Query{}, errors.New("method name not supported")
return QuerySpec{}, errors.New("method name not supported")
}
queryFields = append(queryFields, aggregatedToken)
return Query{Fields: queryFields}, nil
return QuerySpec{Fields: queryFields}, nil
}

View file

@ -58,7 +58,7 @@ func TestParseRepositoryInterface(t *testing.T) {
},
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.Query{Fields: []string{"ID"}},
Query: spec.QuerySpec{Fields: []string{"ID"}},
},
},
},
@ -98,7 +98,7 @@ func TestParseRepositoryInterface(t *testing.T) {
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.Query{Fields: []string{"PhoneNumber"}},
Query: spec.QuerySpec{Fields: []string{"PhoneNumber"}},
},
},
},
@ -137,7 +137,7 @@ func TestParseRepositoryInterface(t *testing.T) {
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.Query{Fields: []string{"City"}},
Query: spec.QuerySpec{Fields: []string{"City"}},
},
},
},
@ -214,7 +214,7 @@ func TestParseRepositoryInterface(t *testing.T) {
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.Query{Fields: []string{"City", "Gender"}},
Query: spec.QuerySpec{Fields: []string{"City", "Gender"}},
},
},
},