Generate Mongo repository code from repository spec
This commit is contained in:
parent
3808684ed0
commit
a24a9f81d7
10 changed files with 470 additions and 19 deletions
internal/mongo
168
internal/mongo/generator.go
Normal file
168
internal/mongo/generator.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
package mongo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"text/template"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/spec"
|
||||
"golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
// GenerateMongoRepository generates mongodb repository
|
||||
func GenerateMongoRepository(packageName string, structModel code.Struct, intf code.Interface) (string, error) {
|
||||
repositorySpec, err := spec.ParseRepositoryInterface(structModel, intf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
generator := mongoRepositoryGenerator{
|
||||
PackageName: packageName,
|
||||
StructModel: structModel,
|
||||
RepositorySpec: repositorySpec,
|
||||
}
|
||||
|
||||
output, err := generator.Generate()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
type mongoRepositoryGenerator struct {
|
||||
PackageName string
|
||||
StructModel code.Struct
|
||||
RepositorySpec spec.RepositorySpec
|
||||
}
|
||||
|
||||
func (g mongoRepositoryGenerator) Generate() (string, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
if err := g.generateBaseContent(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, method := range g.RepositorySpec.Methods {
|
||||
if err := g.generateMethod(buffer, method); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
newOutput, err := imports.Process("", buffer.Bytes(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(newOutput), nil
|
||||
}
|
||||
|
||||
func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) error {
|
||||
tmpl, err := template.New("mongo_repository_base").Parse(baseTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmplData := mongoBaseTemplateData{
|
||||
PackageName: g.PackageName,
|
||||
InterfaceName: g.RepositorySpec.InterfaceName,
|
||||
StructName: g.structName(),
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g mongoRepositoryGenerator) generateMethod(buffer *bytes.Buffer, method spec.MethodSpec) error {
|
||||
tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
implementation, err := g.generateMethodImplementation(method)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var paramTypes []code.Type
|
||||
for _, param := range method.Params[1:] {
|
||||
paramTypes = append(paramTypes, param.Type)
|
||||
}
|
||||
|
||||
tmplData := mongoMethodTemplateData{
|
||||
StructName: g.structName(),
|
||||
MethodName: method.Name,
|
||||
ParamTypes: paramTypes,
|
||||
ReturnTypes: method.Returns,
|
||||
Implementation: implementation,
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g mongoRepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
|
||||
switch operation := methodSpec.Operation.(type) {
|
||||
case spec.FindOperation:
|
||||
return g.generateFindImplementation(operation)
|
||||
}
|
||||
|
||||
return "", errors.New("method spec not supported")
|
||||
}
|
||||
|
||||
func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
var queryFields []string
|
||||
for _, fieldName := range operation.Query.Fields {
|
||||
structField, ok := g.StructModel.Fields.ByName(fieldName)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("struct field %s not found", fieldName)
|
||||
}
|
||||
|
||||
bsonTag, ok := structField.Tags["bson"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("struct field %s does not have bson tag", fieldName)
|
||||
}
|
||||
|
||||
queryFields = append(queryFields, bsonTag[0])
|
||||
}
|
||||
|
||||
tmplData := mongoFindTemplateData{
|
||||
EntityType: g.StructModel.Name,
|
||||
QueryFields: queryFields,
|
||||
}
|
||||
|
||||
if operation.Mode == spec.QueryModeOne {
|
||||
tmpl, err := template.New("mongo_repository_findone").Parse(findOneTemplate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
tmpl, err := template.New("mongo_repository_findmany").Parse(findManyTemplate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return buffer.String(), nil
|
||||
}
|
||||
|
||||
func (g mongoRepositoryGenerator) structName() string {
|
||||
return g.RepositorySpec.InterfaceName + "Mongo"
|
||||
}
|
154
internal/mongo/generator_test.go
Normal file
154
internal/mongo/generator_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package mongo_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/mongo"
|
||||
)
|
||||
|
||||
func TestGenerateMongoRepository(t *testing.T) {
|
||||
userModel := code.Struct{
|
||||
Name: "UserModel",
|
||||
Fields: code.StructFields{
|
||||
{
|
||||
Name: "ID",
|
||||
Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"},
|
||||
Tags: map[string][]string{"bson": {"_id", "omitempty"}},
|
||||
},
|
||||
{
|
||||
Name: "Username",
|
||||
Type: code.SimpleType("string"),
|
||||
Tags: map[string][]string{"bson": {"username"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
intf := code.Interface{
|
||||
Name: "UserRepository",
|
||||
Methods: []code.Method{
|
||||
{
|
||||
Name: "FindByID",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||
},
|
||||
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")},
|
||||
},
|
||||
{
|
||||
Name: "FindOneByUsername",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "username", Type: code.SimpleType("string")},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.PointerType{ContainedType: code.SimpleType("UserModel")},
|
||||
code.SimpleType("error"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "FindByUsername",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "username", Type: code.SimpleType("string")},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "FindByIDAndUsername",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||
{Name: "username", Type: code.SimpleType("string")},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.PointerType{ContainedType: code.SimpleType("UserModel")},
|
||||
code.SimpleType("error"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
code, err := mongo.GenerateMongoRepository("user", userModel, intf)
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
expectedCode := `// Code generated by repogen. DO NOT EDIT.
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
func NewUserRepository(collection *mongo.Collection) UserRepository {
|
||||
return &UserRepositoryMongo{
|
||||
collection: collection,
|
||||
}
|
||||
}
|
||||
|
||||
type UserRepositoryMongo struct {
|
||||
collection *mongo.Collection
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) {
|
||||
var entity UserModel
|
||||
if err := r.collection.FindOne(ctx, bson.M{
|
||||
"_id": arg0,
|
||||
}).Decode(&entity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMongo) FindOneByUsername(ctx context.Context, arg0 string) (*UserModel, error) {
|
||||
var entity UserModel
|
||||
if err := r.collection.FindOne(ctx, bson.M{
|
||||
"username": arg0,
|
||||
}).Decode(&entity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMongo) FindByUsername(ctx context.Context, arg0 string) ([]*UserModel, error) {
|
||||
cursor, err := r.collection.Find(ctx, bson.M{
|
||||
"username": arg0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entities []*UserModel
|
||||
if err := cursor.All(ctx, &entities); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entities, nil
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMongo) FindByIDAndUsername(ctx context.Context, arg0 primitive.ObjectID, arg1 string) (*UserModel, error) {
|
||||
var entity UserModel
|
||||
if err := r.collection.FindOne(ctx, bson.M{
|
||||
"_id": arg0,
|
||||
"username": arg1,
|
||||
}).Decode(&entity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
`
|
||||
expectedCodeLines := strings.Split(expectedCode, "\n")
|
||||
actualCodeLines := strings.Split(code, "\n")
|
||||
|
||||
for i, line := range expectedCodeLines {
|
||||
if line != actualCodeLines[i] {
|
||||
t.Errorf("On line %d\nExpected = %v\nActual = %v", i, line, actualCodeLines[i])
|
||||
}
|
||||
}
|
||||
}
|
99
internal/mongo/templates.go
Normal file
99
internal/mongo/templates.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package mongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
|
||||
package {{.PackageName}}
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} {
|
||||
return &{{.StructName}}{
|
||||
collection: collection,
|
||||
}
|
||||
}
|
||||
|
||||
type {{.StructName}} struct {
|
||||
collection *mongo.Collection
|
||||
}
|
||||
`
|
||||
|
||||
type mongoBaseTemplateData struct {
|
||||
PackageName string
|
||||
InterfaceName string
|
||||
StructName string
|
||||
}
|
||||
|
||||
const methodTemplate = `
|
||||
func (r *{{.StructName}}) {{.MethodName}}(ctx context.Context, {{.Parameters}}){{.Returns}} {
|
||||
{{.Implementation}}
|
||||
}
|
||||
`
|
||||
|
||||
type mongoMethodTemplateData struct {
|
||||
StructName string
|
||||
MethodName string
|
||||
ParamTypes []code.Type
|
||||
ReturnTypes []code.Type
|
||||
Implementation string
|
||||
}
|
||||
|
||||
func (data mongoMethodTemplateData) Parameters() string {
|
||||
var params []string
|
||||
for i, paramType := range data.ParamTypes {
|
||||
params = append(params, fmt.Sprintf("arg%d %s", i, paramType.Code()))
|
||||
}
|
||||
return strings.Join(params, ", ")
|
||||
}
|
||||
|
||||
func (data mongoMethodTemplateData) Returns() string {
|
||||
if len(data.ReturnTypes) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(data.ReturnTypes) == 1 {
|
||||
return fmt.Sprintf(" %s", data.ReturnTypes[0].Code())
|
||||
}
|
||||
|
||||
var returns []string
|
||||
for _, returnType := range data.ReturnTypes {
|
||||
returns = append(returns, returnType.Code())
|
||||
}
|
||||
return fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
|
||||
}
|
||||
|
||||
const findOneTemplate = ` var entity {{.EntityType}}
|
||||
if err := r.collection.FindOne(ctx, bson.M{
|
||||
{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}},
|
||||
{{end}} }).Decode(&entity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil`
|
||||
|
||||
type mongoFindTemplateData struct {
|
||||
EntityType string
|
||||
QueryFields []string
|
||||
}
|
||||
|
||||
const findManyTemplate = ` cursor, err := r.collection.Find(ctx, bson.M{
|
||||
{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}},
|
||||
{{end}} })
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entities []*{{.EntityType}}
|
||||
if err := cursor.All(ctx, &entities); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entities, nil`
|
Loading…
Add table
Add a link
Reference in a new issue