Segregate base code generation logic to a new package (#31)

This commit is contained in:
sunboyy 2022-10-18 19:37:50 +07:00 committed by GitHub
parent 737c1a4044
commit ec08a5a918
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 1761 additions and 937 deletions

View file

@ -78,7 +78,7 @@ type UserModel struct {
},
code.StructField{
Name: "Username",
Type: code.SimpleType("string"),
Type: code.TypeString,
Tags: map[string][]string{
"bson": {"username"},
"json": {"username"},
@ -120,7 +120,7 @@ type UserRepository interface {
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
code.TypeError,
},
},
{
@ -130,19 +130,19 @@ type UserRepository interface {
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
},
{
Name: "FindByAgeBetween",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.SimpleType("int")},
{Name: "toAge", Type: code.SimpleType("int")},
{Name: "fromAge", Type: code.TypeInt},
{Name: "toAge", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
},
{
@ -153,19 +153,19 @@ type UserRepository interface {
},
Returns: []code.Type{
code.InterfaceType{},
code.SimpleType("error"),
code.TypeError,
},
},
{
Name: "UpdateAgreementByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "agreement", Type: code.MapType{KeyType: code.SimpleType("string"), ValueType: code.SimpleType("bool")}},
{Name: "agreement", Type: code.MapType{KeyType: code.TypeString, ValueType: code.TypeBool}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.SimpleType("bool"),
code.SimpleType("error"),
code.TypeBool,
code.TypeError,
},
},
{
@ -178,7 +178,7 @@ type UserRepository interface {
{
Name: "Run",
Params: []code.Param{
{Name: "arg1", Type: code.SimpleType("int")},
{Name: "arg1", Type: code.TypeInt},
},
},
},
@ -191,7 +191,7 @@ type UserRepository interface {
{
Name: "Do",
Params: []code.Param{
{Name: "arg2", Type: code.SimpleType("string")},
{Name: "arg2", Type: code.TypeString},
},
},
},
@ -243,7 +243,7 @@ type UserRepository interface {
},
code.StructField{
Name: "Username",
Type: code.SimpleType("string"),
Type: code.TypeString,
Tags: map[string][]string{
"bson": {"username"},
"json": {"username"},
@ -264,7 +264,7 @@ type UserRepository interface {
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
code.TypeError,
},
},
{
@ -274,7 +274,7 @@ type UserRepository interface {
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
},
},

View file

@ -100,6 +100,15 @@ func (t SimpleType) IsNumber() bool {
t == "float32" || t == "float64"
}
// commonly-used types
const (
TypeBool = SimpleType("bool")
TypeInt = SimpleType("int")
TypeFloat64 = SimpleType("float64")
TypeString = SimpleType("string")
TypeError = SimpleType("error")
)
// ExternalType is a type that is called to another package
type ExternalType struct {
PackageAlias string

View file

@ -9,7 +9,7 @@ import (
func TestStructFieldsByName(t *testing.T) {
idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}
usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")}
usernameField := code.StructField{Name: "Username", Type: code.TypeString}
fields := code.StructFields{idField, usernameField}
t.Run("struct field found", func(t *testing.T) {
@ -91,7 +91,7 @@ func TestTypeIsNumber(t *testing.T) {
testTable := []TypeIsNumberTestCase{
{
Name: "simple type: int",
Type: code.SimpleType("int"),
Type: code.TypeInt,
IsNumber: true,
},
{
@ -116,12 +116,12 @@ func TestTypeIsNumber(t *testing.T) {
},
{
Name: "simple type: other float variants",
Type: code.SimpleType("float64"),
Type: code.TypeFloat64,
IsNumber: true,
},
{
Name: "simple type: non-number primitive type",
Type: code.SimpleType("string"),
Type: code.TypeString,
IsNumber: false,
},
{
@ -136,22 +136,22 @@ func TestTypeIsNumber(t *testing.T) {
},
{
Name: "pointer type: number",
Type: code.PointerType{ContainedType: code.SimpleType("int")},
Type: code.PointerType{ContainedType: code.TypeInt},
IsNumber: true,
},
{
Name: "pointer type: non-number",
Type: code.PointerType{ContainedType: code.SimpleType("string")},
Type: code.PointerType{ContainedType: code.TypeString},
IsNumber: false,
},
{
Name: "array type",
Type: code.ArrayType{ContainedType: code.SimpleType("int")},
Type: code.ArrayType{ContainedType: code.TypeInt},
IsNumber: false,
},
{
Name: "map type",
Type: code.MapType{KeyType: code.SimpleType("int"), ValueType: code.SimpleType("float64")},
Type: code.MapType{KeyType: code.TypeInt, ValueType: code.TypeFloat64},
IsNumber: false,
},
{

41
internal/codegen/base.go Normal file
View file

@ -0,0 +1,41 @@
package codegen
import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
)
const baseTemplate = `// Code generated by {{.Program}}. DO NOT EDIT.
package {{.PackageName}}
import (
{{.GenImports}}
)
`
type baseTemplateData struct {
Program string
PackageName string
Imports [][]code.Import
}
func (data baseTemplateData) GenImports() string {
var sections []string
for _, importSection := range data.Imports {
var section []string
for _, imp := range importSection {
section = append(section, data.generateImportLine(imp))
}
sections = append(sections, strings.Join(section, "\n"))
}
return strings.Join(sections, "\n\n")
}
func (data baseTemplateData) generateImportLine(imp code.Import) string {
if imp.Name == "" {
return fmt.Sprintf("\t\"%s\"", imp.Path)
}
return fmt.Sprintf("\t%s \"%s\"", imp.Name, imp.Path)
}

View file

@ -0,0 +1,83 @@
package codegen
import (
"bytes"
"text/template"
"github.com/sunboyy/repogen/internal/code"
"golang.org/x/tools/imports"
)
type Builder struct {
// Program defines generator program name in the generated file.
Program string
// PackageName defines the package name of the generated file.
PackageName string
// Imports defines necessary imports to reduce ambiguity when generating
// formatting the raw-generated code.
Imports [][]code.Import
implementers []Implementer
}
// Implementer is an interface that wraps the basic Impl method for code
// generation.
type Implementer interface {
Impl(buffer *bytes.Buffer) error
}
// NewBuilder is a constructor of Builder struct.
func NewBuilder(program string, packageName string, imports [][]code.Import) *Builder {
return &Builder{
Program: program,
PackageName: packageName,
Imports: imports,
}
}
// AddImplementer appends a new implemeneter to the implementer list.
func (b *Builder) AddImplementer(implementer Implementer) {
b.implementers = append(b.implementers, implementer)
}
// Build generates code from the previously provided specifications.
func (b Builder) Build() (string, error) {
buffer := new(bytes.Buffer)
if err := b.buildBase(buffer); err != nil {
return "", err
}
for _, impl := range b.implementers {
if err := impl.Impl(buffer); err != nil {
return "", err
}
}
formattedCode, err := imports.Process("", buffer.Bytes(), nil)
if err != nil {
return "", err
}
return string(formattedCode), nil
}
func (b Builder) buildBase(buffer *bytes.Buffer) error {
tmpl, err := template.New("file_base").Parse(baseTemplate)
if err != nil {
return err
}
tmplData := baseTemplateData{
Program: b.Program,
PackageName: b.PackageName,
Imports: b.Imports,
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, tmplData)
return nil
}

View file

@ -0,0 +1,102 @@
package codegen_test
import (
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
const expectedBuildCode = `// Code generated by repogen. DO NOT EDIT.
package user
import (
_ "context"
"go.mongodb.org/mongo-driver/bson/primitive"
_ "go.mongodb.org/mongo-driver/mongo/options"
)
type User struct {
ID primitive.ObjectID ` + "`bson:\"id\" json:\"id,omitempty\"`" + `
Username string
}
func NewUser(username string) User {
return User{
ID: primitive.NewObjectID(),
Username: username,
}
}
func (u User) IDHex() string {
return u.ID.Hex()
}
`
func TestBuilderBuild(t *testing.T) {
builder := codegen.NewBuilder("repogen", "user", [][]code.Import{
{
{
Name: "_",
Path: "context",
},
},
{
{
Path: "go.mongodb.org/mongo-driver/bson/primitive",
},
{
Name: "_",
Path: "go.mongodb.org/mongo-driver/mongo/options",
},
},
})
builder.AddImplementer(codegen.StructBuilder{
Name: "User",
Fields: code.StructFields{
{
Name: "ID",
Type: code.ExternalType{
PackageAlias: "primitive",
Name: "ObjectID",
},
Tags: map[string][]string{
"bson": {"id"},
"json": {"id", "omitempty"},
},
},
{
Name: "Username",
Type: code.TypeString,
},
},
})
builder.AddImplementer(codegen.FunctionBuilder{
Name: "NewUser",
Params: []code.Param{
{Name: "username", Type: code.TypeString},
},
Returns: []code.Type{code.SimpleType("User")},
Body: ` return User{
ID: primitive.NewObjectID(),
Username: username,
}`,
})
builder.AddImplementer(codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")},
Name: "IDHex",
Params: nil,
Returns: []code.Type{code.TypeString},
Body: " return u.ID.Hex()",
})
generatedCode, err := builder.Build()
if err != nil {
t.Fatal(err)
}
if err := testutils.ExpectMultiLineString(expectedBuildCode, generatedCode); err != nil {
t.Error(err)
}
}

View file

@ -0,0 +1,73 @@
package codegen
import (
"bytes"
"fmt"
"strings"
"text/template"
"github.com/sunboyy/repogen/internal/code"
)
const functionTemplate = `
func {{.Name}}({{.GenParams}}){{.GenReturns}} {
{{.Body}}
}
`
// FunctionBuilder is an implementer of a function.
type FunctionBuilder struct {
Name string
Params []code.Param
Returns []code.Type
Body string
}
// Impl writes function declatation code to the buffer.
func (fb FunctionBuilder) Impl(buffer *bytes.Buffer) error {
tmpl, err := template.New("function").Parse(functionTemplate)
if err != nil {
return err
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, fb)
return nil
}
func (fb FunctionBuilder) GenParams() string {
return generateParams(fb.Params)
}
func (fb FunctionBuilder) GenReturns() string {
return generateReturns(fb.Returns)
}
func generateParams(params []code.Param) string {
var paramList []string
for _, param := range params {
paramList = append(
paramList,
fmt.Sprintf("%s %s", param.Name, param.Type.Code()),
)
}
return strings.Join(paramList, ", ")
}
func generateReturns(returns []code.Type) string {
if len(returns) == 0 {
return ""
}
if len(returns) == 1 {
return " " + returns[0].Code()
}
var returnList []string
for _, ret := range returns {
returnList = append(returnList, ret.Code())
}
return fmt.Sprintf(" (%s)", strings.Join(returnList, ", "))
}

View file

@ -0,0 +1,125 @@
package codegen_test
import (
"bytes"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestFunctionBuilderBuild_NoReturn(t *testing.T) {
fb := codegen.FunctionBuilder{
Name: "init",
Params: nil,
Returns: nil,
Body: ` logrus.SetLevel(logrus.DebugLevel)`,
}
expectedCode := `
func init() {
logrus.SetLevel(logrus.DebugLevel)
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestFunctionBuilderBuild_OneReturn(t *testing.T) {
fb := codegen.FunctionBuilder{
Name: "NewUser",
Params: []code.Param{
{
Name: "username",
Type: code.TypeString,
},
{
Name: "age",
Type: code.TypeInt,
},
{
Name: "parent",
Type: code.PointerType{ContainedType: code.SimpleType("User")},
},
},
Returns: []code.Type{
code.SimpleType("User"),
},
Body: ` return User{
Username: username,
Age: age,
Parent: parent
}`,
}
expectedCode := `
func NewUser(username string, age int, parent *User) User {
return User{
Username: username,
Age: age,
Parent: parent
}
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestFunctionBuilderBuild_MultiReturn(t *testing.T) {
fb := codegen.FunctionBuilder{
Name: "Save",
Params: []code.Param{
{
Name: "user",
Type: code.SimpleType("User"),
},
},
Returns: []code.Type{
code.SimpleType("User"),
code.TypeError,
},
Body: ` return collection.Save(user)`,
}
expectedCode := `
func Save(user User) (User, error) {
return collection.Save(user)
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}

View file

@ -0,0 +1,66 @@
package codegen
import (
"bytes"
"fmt"
"text/template"
"github.com/sunboyy/repogen/internal/code"
)
const methodTemplate = `
func ({{.GenReceiver}}) {{.Name}}({{.GenParams}}){{.GenReturns}} {
{{.Body}}
}
`
// MethodBuilder is an implementer of a method.
type MethodBuilder struct {
Receiver MethodReceiver
Name string
Params []code.Param
Returns []code.Type
Body string
}
// MethodReceiver describes a specification of a method receiver.
type MethodReceiver struct {
Name string
Type code.SimpleType
Pointer bool
}
// Impl writes method declatation code to the buffer.
func (mb MethodBuilder) Impl(buffer *bytes.Buffer) error {
tmpl, err := template.New("function").Parse(methodTemplate)
if err != nil {
return err
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, mb)
return nil
}
func (mb MethodBuilder) GenReceiver() string {
if mb.Receiver.Name == "" {
return mb.generateReceiverType()
}
return fmt.Sprintf("%s %s", mb.Receiver.Name, mb.generateReceiverType())
}
func (mb MethodBuilder) generateReceiverType() string {
if !mb.Receiver.Pointer {
return mb.Receiver.Type.Code()
}
return code.PointerType{ContainedType: mb.Receiver.Type}.Code()
}
func (mb MethodBuilder) GenParams() string {
return generateParams(mb.Params)
}
func (mb MethodBuilder) GenReturns() string {
return generateReturns(mb.Returns)
}

View file

@ -0,0 +1,142 @@
package codegen_test
import (
"bytes"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestMethodBuilderBuild_IgnoreReceiverNoReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{Type: "User"},
Name: "Init",
Params: nil,
Returns: nil,
Body: ` db.Init(&User{})`,
}
expectedCode := `
func (User) Init() {
db.Init(&User{})
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestMethodBuilderBuild_IgnorePoinerReceiverOneReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{
Type: "User",
Pointer: true,
},
Name: "Init",
Params: nil,
Returns: []code.Type{code.TypeError},
Body: ` return db.Init(&User{})`,
}
expectedCode := `
func (*User) Init() error {
return db.Init(&User{})
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestMethodBuilderBuild_UseReceiverMultiReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{
Name: "u",
Type: "User",
},
Name: "WithAge",
Params: []code.Param{
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{code.SimpleType("User"), code.TypeError},
Body: ` u.Age = age
return u`,
}
expectedCode := `
func (u User) WithAge(age int) (User, error) {
u.Age = age
return u
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestMethodBuilderBuild_UsePointerReceiverNoReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{
Name: "u",
Type: "User",
Pointer: true,
},
Name: "SetAge",
Params: []code.Param{
{Name: "age", Type: code.TypeInt},
},
Returns: nil,
Body: ` u.Age = age`,
}
expectedCode := `
func (u *User) SetAge(age int) {
u.Age = age
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}

View file

@ -0,0 +1,63 @@
package codegen
import (
"bytes"
"fmt"
"sort"
"strings"
"text/template"
"github.com/sunboyy/repogen/internal/code"
)
const structTemplate = `
type {{.Name}} struct {
{{.GenFields}}
}
`
// StructBuilder is an implementer of a struct.
type StructBuilder struct {
Name string
Fields code.StructFields
}
// Impl writes struct declatation code to the buffer.
func (sb StructBuilder) Impl(buffer *bytes.Buffer) error {
tmpl, err := template.New("struct").Parse(structTemplate)
if err != nil {
return err
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, sb)
return nil
}
func (sb StructBuilder) GenFields() string {
var fieldLines []string
for _, field := range sb.Fields {
fieldLine := fmt.Sprintf("\t%s %s", field.Name, field.Type.Code())
if len(field.Tags) > 0 {
fieldLine += fmt.Sprintf(" `%s`", sb.generateStructTag(field.Tags))
}
fieldLines = append(fieldLines, fieldLine)
}
return strings.Join(fieldLines, "\n")
}
func (sb StructBuilder) generateStructTag(tags map[string][]string) string {
var tagKeys []string
for key := range tags {
tagKeys = append(tagKeys, key)
}
sort.Strings(tagKeys)
var tagGroups []string
for _, key := range tagKeys {
tagValue := strings.Join(tags[key], ",")
tagGroups = append(tagGroups, fmt.Sprintf("%s:\"%s\"", key, tagValue))
}
return strings.Join(tagGroups, " ")
}

View file

@ -0,0 +1,73 @@
package codegen_test
import (
"bytes"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
const expectedStructBuilderCode = `
type User struct {
ID primitive.ObjectID ` + "`bson:\"id,omitempty\" json:\"id,omitempty\"`" + `
Username string ` + "`bson:\"username\" json:\"username\"`" + `
Age int ` + "`bson:\"age\"`" + `
orderCount *int
}
`
func TestStructBuilderBuild(t *testing.T) {
sb := codegen.StructBuilder{
Name: "User",
Fields: []code.StructField{
{
Name: "ID",
Type: code.ExternalType{
PackageAlias: "primitive",
Name: "ObjectID",
},
Tags: map[string][]string{
"json": {"id", "omitempty"},
"bson": {"id", "omitempty"},
},
},
{
Name: "Username",
Type: code.TypeString,
Tags: map[string][]string{
"json": {"username"},
"bson": {"username"},
},
},
{
Name: "Age",
Type: code.TypeInt,
Tags: map[string][]string{
"bson": {"age"},
},
},
{
Name: "orderCount",
Type: code.PointerType{
ContainedType: code.TypeInt,
},
},
},
}
buffer := new(bytes.Buffer)
err := sb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedStructBuilderCode,
actual,
); err != nil {
t.Error(err)
}
}

View file

@ -1,75 +1,40 @@
package generator
import (
"bytes"
"html/template"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"golang.org/x/tools/imports"
)
// GenerateRepository generates repository implementation from repository interface specification
func GenerateRepository(packageName string, structModel code.Struct, interfaceName string,
methodSpecs []spec.MethodSpec) (string, error) {
// GenerateRepository generates repository implementation code from repository
// interface specification.
func GenerateRepository(packageName string, structModel code.Struct,
interfaceName string, methodSpecs []spec.MethodSpec) (string, error) {
repositoryGenerator := repositoryGenerator{
PackageName: packageName,
StructModel: structModel,
InterfaceName: interfaceName,
MethodSpecs: methodSpecs,
Generator: mongo.NewGenerator(structModel, interfaceName),
}
generator := mongo.NewGenerator(structModel, interfaceName)
return repositoryGenerator.Generate()
}
codeBuilder := codegen.NewBuilder(
"repogen",
packageName,
generator.Imports(),
)
type repositoryGenerator struct {
PackageName string
StructModel code.Struct
InterfaceName string
MethodSpecs []spec.MethodSpec
Generator mongo.RepositoryGenerator
}
func (g repositoryGenerator) Generate() (string, error) {
buffer := new(bytes.Buffer)
if err := g.generateBase(buffer); err != nil {
return "", err
}
if err := g.Generator.GenerateConstructor(buffer); err != nil {
return "", err
}
for _, method := range g.MethodSpecs {
if err := g.Generator.GenerateMethod(method, buffer); err != nil {
return "", err
}
}
formattedCode, err := imports.Process("", buffer.Bytes(), nil)
constructorBuilder, err := generator.GenerateConstructor()
if err != nil {
return "", err
}
return string(formattedCode), nil
}
codeBuilder.AddImplementer(constructorBuilder)
codeBuilder.AddImplementer(generator.GenerateStruct())
func (g repositoryGenerator) generateBase(buffer *bytes.Buffer) error {
tmpl, err := template.New("file_base").Parse(baseTemplate)
for _, method := range methodSpecs {
methodBuilder, err := generator.GenerateMethod(method)
if err != nil {
return err
return "", err
}
codeBuilder.AddImplementer(methodBuilder)
}
tmplData := baseTemplateData{
PackageName: g.PackageName,
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
return codeBuilder.Build()
}

View file

@ -22,7 +22,7 @@ var (
}
ageField = code.StructField{
Name: "Age",
Type: code.SimpleType("int"),
Type: code.TypeInt,
Tags: map[string][]string{"bson": {"age"}},
}
)
@ -34,7 +34,7 @@ func TestGenerateMongoRepository(t *testing.T) {
idField,
code.StructField{
Name: "Username",
Type: code.SimpleType("string"),
Type: code.TypeString,
Tags: map[string][]string{"bson": {"username"}},
},
genderField,
@ -49,7 +49,7 @@ func TestGenerateMongoRepository(t *testing.T) {
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")},
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.TypeError},
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
@ -65,11 +65,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
@ -86,11 +86,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeLessThanEqualOrderByAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
@ -108,11 +108,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeGreaterThanOrderByAgeAsc",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
@ -130,11 +130,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeGreaterThanEqualOrderByAgeDesc",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
@ -152,12 +152,12 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeBetween",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.SimpleType("int")},
{Name: "toAge", Type: code.SimpleType("int")},
{Name: "fromAge", Type: code.TypeInt},
{Name: "toAge", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
@ -173,11 +173,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,

View file

@ -1,9 +0,0 @@
package generator
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
package {{.PackageName}}
`
type baseTemplateData struct {
PackageName string
}

View file

@ -2,11 +2,12 @@ package mongo
import (
"bytes"
"io"
"fmt"
"strings"
"text/template"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec"
)
@ -24,55 +25,104 @@ type RepositoryGenerator struct {
InterfaceName string
}
// GenerateConstructor generates mongo repository struct implementation and constructor for the struct
func (g RepositoryGenerator) GenerateConstructor(buffer io.Writer) error {
tmpl, err := template.New("mongo_repository_base").Parse(constructorTemplate)
if err != nil {
return err
// Imports returns necessary imports for the mongo repository implementation.
func (g RepositoryGenerator) Imports() [][]code.Import {
return [][]code.Import{
{
{Path: "context"},
},
{
{Path: "go.mongodb.org/mongo-driver/bson"},
{Path: "go.mongodb.org/mongo-driver/bson/primitive"},
{Path: "go.mongodb.org/mongo-driver/mongo"},
{Path: "go.mongodb.org/mongo-driver/mongo/options"},
},
}
tmplData := mongoConstructorTemplateData{
InterfaceName: g.InterfaceName,
ImplStructName: g.structName(),
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
}
// GenerateMethod generates implementation of from provided method specification
func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer io.Writer) error {
tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate)
// GenerateStruct creates codegen.StructBuilder of mongo repository
// implementation struct.
func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder {
return codegen.StructBuilder{
Name: g.repoImplStructName(),
Fields: code.StructFields{
{
Name: "collection",
Type: code.PointerType{
ContainedType: code.ExternalType{
PackageAlias: "mongo",
Name: "Collection",
},
},
},
},
}
}
// GenerateConstructor creates codegen.FunctionBuilder of a constructor for
// mongo repository implementation struct.
func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) {
tmpl, err := template.New("mongo_constructor_body").Parse(constructorBody)
if err != nil {
return err
return codegen.FunctionBuilder{}, err
}
tmplData := constructorBodyData{
ImplStructName: g.repoImplStructName(),
}
buffer := new(bytes.Buffer)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return codegen.FunctionBuilder{}, err
}
return codegen.FunctionBuilder{
Name: "New" + g.InterfaceName,
Params: []code.Param{
{
Name: "collection",
Type: code.PointerType{
ContainedType: code.ExternalType{
PackageAlias: "mongo",
Name: "Collection",
},
},
},
},
Returns: []code.Type{
code.SimpleType(g.InterfaceName),
},
Body: buffer.String(),
}, nil
}
// GenerateMethod creates codegen.MethodBuilder of repository method from the
// provided method specification.
func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen.MethodBuilder, error) {
var params []code.Param
for i, param := range methodSpec.Params {
params = append(params, code.Param{
Name: fmt.Sprintf("arg%d", i),
Type: param.Type,
})
}
implementation, err := g.generateMethodImplementation(methodSpec)
if err != nil {
return err
return codegen.MethodBuilder{}, err
}
var paramTypes []code.Type
for _, param := range methodSpec.Params {
paramTypes = append(paramTypes, param.Type)
}
tmplData := mongoMethodTemplateData{
StructName: g.structName(),
MethodName: methodSpec.Name,
ParamTypes: paramTypes,
ReturnTypes: methodSpec.Returns,
Implementation: implementation,
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
return codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{
Name: "r",
Type: code.SimpleType(g.repoImplStructName()),
Pointer: true,
},
Name: methodSpec.Name,
Params: params,
Returns: methodSpec.Returns,
Body: implementation,
}, nil
}
func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
@ -275,7 +325,7 @@ func (g RepositoryGenerator) bsonTagFromField(field code.StructField) (string, e
return bsonTag[0], nil
}
func (g RepositoryGenerator) structName() string {
func (g RepositoryGenerator) repoImplStructName() string {
return g.InterfaceName + "Mongo"
}

File diff suppressed because it is too large Load diff

View file

@ -1,77 +1,17 @@
package mongo
import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/spec"
)
const constructorTemplate = `
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} {
return &{{.ImplStructName}}{
const constructorBody = ` return &{{.ImplStructName}}{
collection: collection,
}
}
}`
type {{.ImplStructName}} struct {
collection *mongo.Collection
}
`
type mongoConstructorTemplateData struct {
InterfaceName string
type constructorBodyData struct {
ImplStructName string
}
const methodTemplate = `
func (r *{{.StructName}}) {{.MethodName}}({{.Parameters}}){{.Returns}} {
{{.Implementation}}
}
`
type mongoMethodTemplateData struct {
StructName string
MethodName string
ParamTypes []code.Type
ReturnTypes []code.Type
Implementation string
}
func (data mongoMethodTemplateData) Parameters() string {
var params []string
for i, paramType := range data.ParamTypes {
params = append(params, fmt.Sprintf("arg%d %s", i, paramType.Code()))
}
return strings.Join(params, ", ")
}
func (data mongoMethodTemplateData) Returns() string {
if len(data.ReturnTypes) == 0 {
return ""
}
if len(data.ReturnTypes) == 1 {
return fmt.Sprintf(" %s", data.ReturnTypes[0].Code())
}
var returns []string
for _, returnType := range data.ReturnTypes {
returns = append(returns, returnType.Code())
}
return fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
}
const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1)
if err != nil {
return nil, err

View file

@ -44,7 +44,7 @@ func TestError(t *testing.T) {
Name: "IncompatibleComparatorError",
Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{
Name: "Age",
Type: code.SimpleType("int"),
Type: code.TypeInt,
}),
ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'",
},
@ -55,7 +55,7 @@ func TestError(t *testing.T) {
},
{
Name: "ArgumentTypeNotMatchedError",
Error: spec.NewArgumentTypeNotMatchedError("Age", code.SimpleType("int"), code.SimpleType("float64")),
Error: spec.NewArgumentTypeNotMatchedError("Age", code.TypeInt, code.TypeFloat64),
ExpectedString: "field 'Age' requires an argument of type 'int' (got 'float64')",
},
{
@ -63,7 +63,7 @@ func TestError(t *testing.T) {
Error: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{
code.StructField{
Name: "City",
Type: code.SimpleType("string"),
Type: code.TypeString,
},
}),
ExpectedString: "cannot use update operator INC with struct field 'City' of type 'string'",

View file

@ -85,7 +85,7 @@ func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryM
return "", NewOperationReturnCountUnmatchedError(2)
}
if returns[1] != code.SimpleType("error") {
if returns[1] != code.TypeError {
return "", NewUnsupportedReturnError(returns[1], 1)
}
@ -203,7 +203,7 @@ func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (
return "", NewOperationReturnCountUnmatchedError(2)
}
if returns[1] != code.SimpleType("error") {
if returns[1] != code.TypeError {
return "", NewUnsupportedReturnError(returns[1], 1)
}
@ -318,11 +318,11 @@ func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error {
return NewOperationReturnCountUnmatchedError(2)
}
if returns[0] != code.SimpleType("int") {
if returns[0] != code.TypeInt {
return NewUnsupportedReturnError(returns[0], 0)
}
if returns[1] != code.SimpleType("error") {
if returns[1] != code.TypeError {
return NewUnsupportedReturnError(returns[1], 1)
}
@ -334,16 +334,16 @@ func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (Que
return "", NewOperationReturnCountUnmatchedError(2)
}
if returns[1] != code.SimpleType("error") {
if returns[1] != code.TypeError {
return "", NewUnsupportedReturnError(returns[1], 1)
}
simpleType, ok := returns[0].(code.SimpleType)
if ok {
if simpleType == code.SimpleType("bool") {
if simpleType == code.TypeBool {
return QueryModeOne, nil
}
if simpleType == code.SimpleType("int") {
if simpleType == code.TypeInt {
return QueryModeMany, nil
}
}
@ -367,7 +367,7 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
var currentParamIndex int
for _, predicate := range querySpec.Predicates {
if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) &&
predicate.FieldReference.ReferencedField().Type != code.SimpleType("bool") {
predicate.FieldReference.ReferencedField().Type != code.TypeBool {
return NewIncompatibleComparatorError(predicate.Comparator,
predicate.FieldReference.ReferencedField())
}

File diff suppressed because it is too large Load diff

View file

@ -21,7 +21,8 @@ const usageText = `repogen generates MongoDB repository implementation from repo
Supported options:`
const version = "v0.2.1"
// version indicates the version of repogen.
const version = "v0.3-next"
func main() {
flag.Usage = printUsage