Segregate base code generation logic to a new package (#31)

This commit is contained in:
sunboyy 2022-10-18 19:37:50 +07:00 committed by GitHub
parent 737c1a4044
commit ec08a5a918
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 1761 additions and 937 deletions

View file

@ -78,7 +78,7 @@ type UserModel struct {
}, },
code.StructField{ code.StructField{
Name: "Username", Name: "Username",
Type: code.SimpleType("string"), Type: code.TypeString,
Tags: map[string][]string{ Tags: map[string][]string{
"bson": {"username"}, "bson": {"username"},
"json": {"username"}, "json": {"username"},
@ -120,7 +120,7 @@ type UserRepository interface {
}, },
Returns: []code.Type{ Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"), code.TypeError,
}, },
}, },
{ {
@ -130,19 +130,19 @@ type UserRepository interface {
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
}, },
{ {
Name: "FindByAgeBetween", Name: "FindByAgeBetween",
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.SimpleType("int")}, {Name: "fromAge", Type: code.TypeInt},
{Name: "toAge", Type: code.SimpleType("int")}, {Name: "toAge", Type: code.TypeInt},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
}, },
{ {
@ -153,19 +153,19 @@ type UserRepository interface {
}, },
Returns: []code.Type{ Returns: []code.Type{
code.InterfaceType{}, code.InterfaceType{},
code.SimpleType("error"), code.TypeError,
}, },
}, },
{ {
Name: "UpdateAgreementByID", Name: "UpdateAgreementByID",
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "agreement", Type: code.MapType{KeyType: code.SimpleType("string"), ValueType: code.SimpleType("bool")}}, {Name: "agreement", Type: code.MapType{KeyType: code.TypeString, ValueType: code.TypeBool}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.SimpleType("bool"), code.TypeBool,
code.SimpleType("error"), code.TypeError,
}, },
}, },
{ {
@ -178,7 +178,7 @@ type UserRepository interface {
{ {
Name: "Run", Name: "Run",
Params: []code.Param{ Params: []code.Param{
{Name: "arg1", Type: code.SimpleType("int")}, {Name: "arg1", Type: code.TypeInt},
}, },
}, },
}, },
@ -191,7 +191,7 @@ type UserRepository interface {
{ {
Name: "Do", Name: "Do",
Params: []code.Param{ Params: []code.Param{
{Name: "arg2", Type: code.SimpleType("string")}, {Name: "arg2", Type: code.TypeString},
}, },
}, },
}, },
@ -243,7 +243,7 @@ type UserRepository interface {
}, },
code.StructField{ code.StructField{
Name: "Username", Name: "Username",
Type: code.SimpleType("string"), Type: code.TypeString,
Tags: map[string][]string{ Tags: map[string][]string{
"bson": {"username"}, "bson": {"username"},
"json": {"username"}, "json": {"username"},
@ -264,7 +264,7 @@ type UserRepository interface {
}, },
Returns: []code.Type{ Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"), code.TypeError,
}, },
}, },
{ {
@ -274,7 +274,7 @@ type UserRepository interface {
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
}, },
}, },

View file

@ -100,6 +100,15 @@ func (t SimpleType) IsNumber() bool {
t == "float32" || t == "float64" t == "float32" || t == "float64"
} }
// commonly-used types
const (
TypeBool = SimpleType("bool")
TypeInt = SimpleType("int")
TypeFloat64 = SimpleType("float64")
TypeString = SimpleType("string")
TypeError = SimpleType("error")
)
// ExternalType is a type that is called to another package // ExternalType is a type that is called to another package
type ExternalType struct { type ExternalType struct {
PackageAlias string PackageAlias string

View file

@ -9,7 +9,7 @@ import (
func TestStructFieldsByName(t *testing.T) { func TestStructFieldsByName(t *testing.T) {
idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}} idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}
usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")} usernameField := code.StructField{Name: "Username", Type: code.TypeString}
fields := code.StructFields{idField, usernameField} fields := code.StructFields{idField, usernameField}
t.Run("struct field found", func(t *testing.T) { t.Run("struct field found", func(t *testing.T) {
@ -91,7 +91,7 @@ func TestTypeIsNumber(t *testing.T) {
testTable := []TypeIsNumberTestCase{ testTable := []TypeIsNumberTestCase{
{ {
Name: "simple type: int", Name: "simple type: int",
Type: code.SimpleType("int"), Type: code.TypeInt,
IsNumber: true, IsNumber: true,
}, },
{ {
@ -116,12 +116,12 @@ func TestTypeIsNumber(t *testing.T) {
}, },
{ {
Name: "simple type: other float variants", Name: "simple type: other float variants",
Type: code.SimpleType("float64"), Type: code.TypeFloat64,
IsNumber: true, IsNumber: true,
}, },
{ {
Name: "simple type: non-number primitive type", Name: "simple type: non-number primitive type",
Type: code.SimpleType("string"), Type: code.TypeString,
IsNumber: false, IsNumber: false,
}, },
{ {
@ -136,22 +136,22 @@ func TestTypeIsNumber(t *testing.T) {
}, },
{ {
Name: "pointer type: number", Name: "pointer type: number",
Type: code.PointerType{ContainedType: code.SimpleType("int")}, Type: code.PointerType{ContainedType: code.TypeInt},
IsNumber: true, IsNumber: true,
}, },
{ {
Name: "pointer type: non-number", Name: "pointer type: non-number",
Type: code.PointerType{ContainedType: code.SimpleType("string")}, Type: code.PointerType{ContainedType: code.TypeString},
IsNumber: false, IsNumber: false,
}, },
{ {
Name: "array type", Name: "array type",
Type: code.ArrayType{ContainedType: code.SimpleType("int")}, Type: code.ArrayType{ContainedType: code.TypeInt},
IsNumber: false, IsNumber: false,
}, },
{ {
Name: "map type", Name: "map type",
Type: code.MapType{KeyType: code.SimpleType("int"), ValueType: code.SimpleType("float64")}, Type: code.MapType{KeyType: code.TypeInt, ValueType: code.TypeFloat64},
IsNumber: false, IsNumber: false,
}, },
{ {

41
internal/codegen/base.go Normal file
View file

@ -0,0 +1,41 @@
package codegen
import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
)
const baseTemplate = `// Code generated by {{.Program}}. DO NOT EDIT.
package {{.PackageName}}
import (
{{.GenImports}}
)
`
type baseTemplateData struct {
Program string
PackageName string
Imports [][]code.Import
}
func (data baseTemplateData) GenImports() string {
var sections []string
for _, importSection := range data.Imports {
var section []string
for _, imp := range importSection {
section = append(section, data.generateImportLine(imp))
}
sections = append(sections, strings.Join(section, "\n"))
}
return strings.Join(sections, "\n\n")
}
func (data baseTemplateData) generateImportLine(imp code.Import) string {
if imp.Name == "" {
return fmt.Sprintf("\t\"%s\"", imp.Path)
}
return fmt.Sprintf("\t%s \"%s\"", imp.Name, imp.Path)
}

View file

@ -0,0 +1,83 @@
package codegen
import (
"bytes"
"text/template"
"github.com/sunboyy/repogen/internal/code"
"golang.org/x/tools/imports"
)
type Builder struct {
// Program defines generator program name in the generated file.
Program string
// PackageName defines the package name of the generated file.
PackageName string
// Imports defines necessary imports to reduce ambiguity when generating
// formatting the raw-generated code.
Imports [][]code.Import
implementers []Implementer
}
// Implementer is an interface that wraps the basic Impl method for code
// generation.
type Implementer interface {
Impl(buffer *bytes.Buffer) error
}
// NewBuilder is a constructor of Builder struct.
func NewBuilder(program string, packageName string, imports [][]code.Import) *Builder {
return &Builder{
Program: program,
PackageName: packageName,
Imports: imports,
}
}
// AddImplementer appends a new implemeneter to the implementer list.
func (b *Builder) AddImplementer(implementer Implementer) {
b.implementers = append(b.implementers, implementer)
}
// Build generates code from the previously provided specifications.
func (b Builder) Build() (string, error) {
buffer := new(bytes.Buffer)
if err := b.buildBase(buffer); err != nil {
return "", err
}
for _, impl := range b.implementers {
if err := impl.Impl(buffer); err != nil {
return "", err
}
}
formattedCode, err := imports.Process("", buffer.Bytes(), nil)
if err != nil {
return "", err
}
return string(formattedCode), nil
}
func (b Builder) buildBase(buffer *bytes.Buffer) error {
tmpl, err := template.New("file_base").Parse(baseTemplate)
if err != nil {
return err
}
tmplData := baseTemplateData{
Program: b.Program,
PackageName: b.PackageName,
Imports: b.Imports,
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, tmplData)
return nil
}

View file

@ -0,0 +1,102 @@
package codegen_test
import (
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
const expectedBuildCode = `// Code generated by repogen. DO NOT EDIT.
package user
import (
_ "context"
"go.mongodb.org/mongo-driver/bson/primitive"
_ "go.mongodb.org/mongo-driver/mongo/options"
)
type User struct {
ID primitive.ObjectID ` + "`bson:\"id\" json:\"id,omitempty\"`" + `
Username string
}
func NewUser(username string) User {
return User{
ID: primitive.NewObjectID(),
Username: username,
}
}
func (u User) IDHex() string {
return u.ID.Hex()
}
`
func TestBuilderBuild(t *testing.T) {
builder := codegen.NewBuilder("repogen", "user", [][]code.Import{
{
{
Name: "_",
Path: "context",
},
},
{
{
Path: "go.mongodb.org/mongo-driver/bson/primitive",
},
{
Name: "_",
Path: "go.mongodb.org/mongo-driver/mongo/options",
},
},
})
builder.AddImplementer(codegen.StructBuilder{
Name: "User",
Fields: code.StructFields{
{
Name: "ID",
Type: code.ExternalType{
PackageAlias: "primitive",
Name: "ObjectID",
},
Tags: map[string][]string{
"bson": {"id"},
"json": {"id", "omitempty"},
},
},
{
Name: "Username",
Type: code.TypeString,
},
},
})
builder.AddImplementer(codegen.FunctionBuilder{
Name: "NewUser",
Params: []code.Param{
{Name: "username", Type: code.TypeString},
},
Returns: []code.Type{code.SimpleType("User")},
Body: ` return User{
ID: primitive.NewObjectID(),
Username: username,
}`,
})
builder.AddImplementer(codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")},
Name: "IDHex",
Params: nil,
Returns: []code.Type{code.TypeString},
Body: " return u.ID.Hex()",
})
generatedCode, err := builder.Build()
if err != nil {
t.Fatal(err)
}
if err := testutils.ExpectMultiLineString(expectedBuildCode, generatedCode); err != nil {
t.Error(err)
}
}

View file

@ -0,0 +1,73 @@
package codegen
import (
"bytes"
"fmt"
"strings"
"text/template"
"github.com/sunboyy/repogen/internal/code"
)
const functionTemplate = `
func {{.Name}}({{.GenParams}}){{.GenReturns}} {
{{.Body}}
}
`
// FunctionBuilder is an implementer of a function.
type FunctionBuilder struct {
Name string
Params []code.Param
Returns []code.Type
Body string
}
// Impl writes function declatation code to the buffer.
func (fb FunctionBuilder) Impl(buffer *bytes.Buffer) error {
tmpl, err := template.New("function").Parse(functionTemplate)
if err != nil {
return err
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, fb)
return nil
}
func (fb FunctionBuilder) GenParams() string {
return generateParams(fb.Params)
}
func (fb FunctionBuilder) GenReturns() string {
return generateReturns(fb.Returns)
}
func generateParams(params []code.Param) string {
var paramList []string
for _, param := range params {
paramList = append(
paramList,
fmt.Sprintf("%s %s", param.Name, param.Type.Code()),
)
}
return strings.Join(paramList, ", ")
}
func generateReturns(returns []code.Type) string {
if len(returns) == 0 {
return ""
}
if len(returns) == 1 {
return " " + returns[0].Code()
}
var returnList []string
for _, ret := range returns {
returnList = append(returnList, ret.Code())
}
return fmt.Sprintf(" (%s)", strings.Join(returnList, ", "))
}

View file

@ -0,0 +1,125 @@
package codegen_test
import (
"bytes"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestFunctionBuilderBuild_NoReturn(t *testing.T) {
fb := codegen.FunctionBuilder{
Name: "init",
Params: nil,
Returns: nil,
Body: ` logrus.SetLevel(logrus.DebugLevel)`,
}
expectedCode := `
func init() {
logrus.SetLevel(logrus.DebugLevel)
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestFunctionBuilderBuild_OneReturn(t *testing.T) {
fb := codegen.FunctionBuilder{
Name: "NewUser",
Params: []code.Param{
{
Name: "username",
Type: code.TypeString,
},
{
Name: "age",
Type: code.TypeInt,
},
{
Name: "parent",
Type: code.PointerType{ContainedType: code.SimpleType("User")},
},
},
Returns: []code.Type{
code.SimpleType("User"),
},
Body: ` return User{
Username: username,
Age: age,
Parent: parent
}`,
}
expectedCode := `
func NewUser(username string, age int, parent *User) User {
return User{
Username: username,
Age: age,
Parent: parent
}
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestFunctionBuilderBuild_MultiReturn(t *testing.T) {
fb := codegen.FunctionBuilder{
Name: "Save",
Params: []code.Param{
{
Name: "user",
Type: code.SimpleType("User"),
},
},
Returns: []code.Type{
code.SimpleType("User"),
code.TypeError,
},
Body: ` return collection.Save(user)`,
}
expectedCode := `
func Save(user User) (User, error) {
return collection.Save(user)
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}

View file

@ -0,0 +1,66 @@
package codegen
import (
"bytes"
"fmt"
"text/template"
"github.com/sunboyy/repogen/internal/code"
)
const methodTemplate = `
func ({{.GenReceiver}}) {{.Name}}({{.GenParams}}){{.GenReturns}} {
{{.Body}}
}
`
// MethodBuilder is an implementer of a method.
type MethodBuilder struct {
Receiver MethodReceiver
Name string
Params []code.Param
Returns []code.Type
Body string
}
// MethodReceiver describes a specification of a method receiver.
type MethodReceiver struct {
Name string
Type code.SimpleType
Pointer bool
}
// Impl writes method declatation code to the buffer.
func (mb MethodBuilder) Impl(buffer *bytes.Buffer) error {
tmpl, err := template.New("function").Parse(methodTemplate)
if err != nil {
return err
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, mb)
return nil
}
func (mb MethodBuilder) GenReceiver() string {
if mb.Receiver.Name == "" {
return mb.generateReceiverType()
}
return fmt.Sprintf("%s %s", mb.Receiver.Name, mb.generateReceiverType())
}
func (mb MethodBuilder) generateReceiverType() string {
if !mb.Receiver.Pointer {
return mb.Receiver.Type.Code()
}
return code.PointerType{ContainedType: mb.Receiver.Type}.Code()
}
func (mb MethodBuilder) GenParams() string {
return generateParams(mb.Params)
}
func (mb MethodBuilder) GenReturns() string {
return generateReturns(mb.Returns)
}

View file

@ -0,0 +1,142 @@
package codegen_test
import (
"bytes"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestMethodBuilderBuild_IgnoreReceiverNoReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{Type: "User"},
Name: "Init",
Params: nil,
Returns: nil,
Body: ` db.Init(&User{})`,
}
expectedCode := `
func (User) Init() {
db.Init(&User{})
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestMethodBuilderBuild_IgnorePoinerReceiverOneReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{
Type: "User",
Pointer: true,
},
Name: "Init",
Params: nil,
Returns: []code.Type{code.TypeError},
Body: ` return db.Init(&User{})`,
}
expectedCode := `
func (*User) Init() error {
return db.Init(&User{})
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestMethodBuilderBuild_UseReceiverMultiReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{
Name: "u",
Type: "User",
},
Name: "WithAge",
Params: []code.Param{
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{code.SimpleType("User"), code.TypeError},
Body: ` u.Age = age
return u`,
}
expectedCode := `
func (u User) WithAge(age int) (User, error) {
u.Age = age
return u
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}
func TestMethodBuilderBuild_UsePointerReceiverNoReturn(t *testing.T) {
fb := codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{
Name: "u",
Type: "User",
Pointer: true,
},
Name: "SetAge",
Params: []code.Param{
{Name: "age", Type: code.TypeInt},
},
Returns: nil,
Body: ` u.Age = age`,
}
expectedCode := `
func (u *User) SetAge(age int) {
u.Age = age
}
`
buffer := new(bytes.Buffer)
err := fb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedCode,
actual,
); err != nil {
t.Error(err)
}
}

View file

@ -0,0 +1,63 @@
package codegen
import (
"bytes"
"fmt"
"sort"
"strings"
"text/template"
"github.com/sunboyy/repogen/internal/code"
)
const structTemplate = `
type {{.Name}} struct {
{{.GenFields}}
}
`
// StructBuilder is an implementer of a struct.
type StructBuilder struct {
Name string
Fields code.StructFields
}
// Impl writes struct declatation code to the buffer.
func (sb StructBuilder) Impl(buffer *bytes.Buffer) error {
tmpl, err := template.New("struct").Parse(structTemplate)
if err != nil {
return err
}
// writing to a buffer should not cause errors.
_ = tmpl.Execute(buffer, sb)
return nil
}
func (sb StructBuilder) GenFields() string {
var fieldLines []string
for _, field := range sb.Fields {
fieldLine := fmt.Sprintf("\t%s %s", field.Name, field.Type.Code())
if len(field.Tags) > 0 {
fieldLine += fmt.Sprintf(" `%s`", sb.generateStructTag(field.Tags))
}
fieldLines = append(fieldLines, fieldLine)
}
return strings.Join(fieldLines, "\n")
}
func (sb StructBuilder) generateStructTag(tags map[string][]string) string {
var tagKeys []string
for key := range tags {
tagKeys = append(tagKeys, key)
}
sort.Strings(tagKeys)
var tagGroups []string
for _, key := range tagKeys {
tagValue := strings.Join(tags[key], ",")
tagGroups = append(tagGroups, fmt.Sprintf("%s:\"%s\"", key, tagValue))
}
return strings.Join(tagGroups, " ")
}

View file

@ -0,0 +1,73 @@
package codegen_test
import (
"bytes"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/testutils"
)
const expectedStructBuilderCode = `
type User struct {
ID primitive.ObjectID ` + "`bson:\"id,omitempty\" json:\"id,omitempty\"`" + `
Username string ` + "`bson:\"username\" json:\"username\"`" + `
Age int ` + "`bson:\"age\"`" + `
orderCount *int
}
`
func TestStructBuilderBuild(t *testing.T) {
sb := codegen.StructBuilder{
Name: "User",
Fields: []code.StructField{
{
Name: "ID",
Type: code.ExternalType{
PackageAlias: "primitive",
Name: "ObjectID",
},
Tags: map[string][]string{
"json": {"id", "omitempty"},
"bson": {"id", "omitempty"},
},
},
{
Name: "Username",
Type: code.TypeString,
Tags: map[string][]string{
"json": {"username"},
"bson": {"username"},
},
},
{
Name: "Age",
Type: code.TypeInt,
Tags: map[string][]string{
"bson": {"age"},
},
},
{
Name: "orderCount",
Type: code.PointerType{
ContainedType: code.TypeInt,
},
},
},
}
buffer := new(bytes.Buffer)
err := sb.Impl(buffer)
if err != nil {
t.Fatal(err)
}
actual := buffer.String()
if err := testutils.ExpectMultiLineString(
expectedStructBuilderCode,
actual,
); err != nil {
t.Error(err)
}
}

View file

@ -1,75 +1,40 @@
package generator package generator
import ( import (
"bytes"
"html/template"
"github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/mongo" "github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec" "github.com/sunboyy/repogen/internal/spec"
"golang.org/x/tools/imports"
) )
// GenerateRepository generates repository implementation from repository interface specification // GenerateRepository generates repository implementation code from repository
func GenerateRepository(packageName string, structModel code.Struct, interfaceName string, // interface specification.
methodSpecs []spec.MethodSpec) (string, error) { func GenerateRepository(packageName string, structModel code.Struct,
interfaceName string, methodSpecs []spec.MethodSpec) (string, error) {
repositoryGenerator := repositoryGenerator{ generator := mongo.NewGenerator(structModel, interfaceName)
PackageName: packageName,
StructModel: structModel,
InterfaceName: interfaceName,
MethodSpecs: methodSpecs,
Generator: mongo.NewGenerator(structModel, interfaceName),
}
return repositoryGenerator.Generate() codeBuilder := codegen.NewBuilder(
} "repogen",
packageName,
generator.Imports(),
)
type repositoryGenerator struct { constructorBuilder, err := generator.GenerateConstructor()
PackageName string
StructModel code.Struct
InterfaceName string
MethodSpecs []spec.MethodSpec
Generator mongo.RepositoryGenerator
}
func (g repositoryGenerator) Generate() (string, error) {
buffer := new(bytes.Buffer)
if err := g.generateBase(buffer); err != nil {
return "", err
}
if err := g.Generator.GenerateConstructor(buffer); err != nil {
return "", err
}
for _, method := range g.MethodSpecs {
if err := g.Generator.GenerateMethod(method, buffer); err != nil {
return "", err
}
}
formattedCode, err := imports.Process("", buffer.Bytes(), nil)
if err != nil { if err != nil {
return "", err return "", err
} }
return string(formattedCode), nil codeBuilder.AddImplementer(constructorBuilder)
} codeBuilder.AddImplementer(generator.GenerateStruct())
func (g repositoryGenerator) generateBase(buffer *bytes.Buffer) error { for _, method := range methodSpecs {
tmpl, err := template.New("file_base").Parse(baseTemplate) methodBuilder, err := generator.GenerateMethod(method)
if err != nil { if err != nil {
return err return "", err
}
codeBuilder.AddImplementer(methodBuilder)
} }
tmplData := baseTemplateData{ return codeBuilder.Build()
PackageName: g.PackageName,
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
} }

View file

@ -22,7 +22,7 @@ var (
} }
ageField = code.StructField{ ageField = code.StructField{
Name: "Age", Name: "Age",
Type: code.SimpleType("int"), Type: code.TypeInt,
Tags: map[string][]string{"bson": {"age"}}, Tags: map[string][]string{"bson": {"age"}},
} }
) )
@ -34,7 +34,7 @@ func TestGenerateMongoRepository(t *testing.T) {
idField, idField,
code.StructField{ code.StructField{
Name: "Username", Name: "Username",
Type: code.SimpleType("string"), Type: code.TypeString,
Tags: map[string][]string{"bson": {"username"}}, Tags: map[string][]string{"bson": {"username"}},
}, },
genderField, genderField,
@ -49,7 +49,7 @@ func TestGenerateMongoRepository(t *testing.T) {
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
}, },
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")}, Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.TypeError},
Operation: spec.FindOperation{ Operation: spec.FindOperation{
Mode: spec.QueryModeOne, Mode: spec.QueryModeOne,
Query: spec.QuerySpec{ Query: spec.QuerySpec{
@ -65,11 +65,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")}, {Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")}, {Name: "age", Type: code.TypeInt},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"), code.TypeError,
}, },
Operation: spec.FindOperation{ Operation: spec.FindOperation{
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,
@ -86,11 +86,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeLessThanEqualOrderByAge", Name: "FindByAgeLessThanEqualOrderByAge",
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")}, {Name: "age", Type: code.TypeInt},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
Operation: spec.FindOperation{ Operation: spec.FindOperation{
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,
@ -108,11 +108,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeGreaterThanOrderByAgeAsc", Name: "FindByAgeGreaterThanOrderByAgeAsc",
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")}, {Name: "age", Type: code.TypeInt},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
Operation: spec.FindOperation{ Operation: spec.FindOperation{
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,
@ -130,11 +130,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeGreaterThanEqualOrderByAgeDesc", Name: "FindByAgeGreaterThanEqualOrderByAgeDesc",
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")}, {Name: "age", Type: code.TypeInt},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
Operation: spec.FindOperation{ Operation: spec.FindOperation{
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,
@ -152,12 +152,12 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByAgeBetween", Name: "FindByAgeBetween",
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.SimpleType("int")}, {Name: "fromAge", Type: code.TypeInt},
{Name: "toAge", Type: code.SimpleType("int")}, {Name: "toAge", Type: code.TypeInt},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
Operation: spec.FindOperation{ Operation: spec.FindOperation{
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,
@ -173,11 +173,11 @@ func TestGenerateMongoRepository(t *testing.T) {
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")}, {Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")}, {Name: "age", Type: code.TypeInt},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"), code.TypeError,
}, },
Operation: spec.FindOperation{ Operation: spec.FindOperation{
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,

View file

@ -1,9 +0,0 @@
package generator
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
package {{.PackageName}}
`
type baseTemplateData struct {
PackageName string
}

View file

@ -2,11 +2,12 @@ package mongo
import ( import (
"bytes" "bytes"
"io" "fmt"
"strings" "strings"
"text/template" "text/template"
"github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec" "github.com/sunboyy/repogen/internal/spec"
) )
@ -24,55 +25,104 @@ type RepositoryGenerator struct {
InterfaceName string InterfaceName string
} }
// GenerateConstructor generates mongo repository struct implementation and constructor for the struct // Imports returns necessary imports for the mongo repository implementation.
func (g RepositoryGenerator) GenerateConstructor(buffer io.Writer) error { func (g RepositoryGenerator) Imports() [][]code.Import {
tmpl, err := template.New("mongo_repository_base").Parse(constructorTemplate) return [][]code.Import{
if err != nil { {
return err {Path: "context"},
},
{
{Path: "go.mongodb.org/mongo-driver/bson"},
{Path: "go.mongodb.org/mongo-driver/bson/primitive"},
{Path: "go.mongodb.org/mongo-driver/mongo"},
{Path: "go.mongodb.org/mongo-driver/mongo/options"},
},
} }
tmplData := mongoConstructorTemplateData{
InterfaceName: g.InterfaceName,
ImplStructName: g.structName(),
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
} }
// GenerateMethod generates implementation of from provided method specification // GenerateStruct creates codegen.StructBuilder of mongo repository
func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer io.Writer) error { // implementation struct.
tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate) func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder {
return codegen.StructBuilder{
Name: g.repoImplStructName(),
Fields: code.StructFields{
{
Name: "collection",
Type: code.PointerType{
ContainedType: code.ExternalType{
PackageAlias: "mongo",
Name: "Collection",
},
},
},
},
}
}
// GenerateConstructor creates codegen.FunctionBuilder of a constructor for
// mongo repository implementation struct.
func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) {
tmpl, err := template.New("mongo_constructor_body").Parse(constructorBody)
if err != nil { if err != nil {
return err return codegen.FunctionBuilder{}, err
}
tmplData := constructorBodyData{
ImplStructName: g.repoImplStructName(),
}
buffer := new(bytes.Buffer)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return codegen.FunctionBuilder{}, err
}
return codegen.FunctionBuilder{
Name: "New" + g.InterfaceName,
Params: []code.Param{
{
Name: "collection",
Type: code.PointerType{
ContainedType: code.ExternalType{
PackageAlias: "mongo",
Name: "Collection",
},
},
},
},
Returns: []code.Type{
code.SimpleType(g.InterfaceName),
},
Body: buffer.String(),
}, nil
}
// GenerateMethod creates codegen.MethodBuilder of repository method from the
// provided method specification.
func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen.MethodBuilder, error) {
var params []code.Param
for i, param := range methodSpec.Params {
params = append(params, code.Param{
Name: fmt.Sprintf("arg%d", i),
Type: param.Type,
})
} }
implementation, err := g.generateMethodImplementation(methodSpec) implementation, err := g.generateMethodImplementation(methodSpec)
if err != nil { if err != nil {
return err return codegen.MethodBuilder{}, err
} }
var paramTypes []code.Type return codegen.MethodBuilder{
for _, param := range methodSpec.Params { Receiver: codegen.MethodReceiver{
paramTypes = append(paramTypes, param.Type) Name: "r",
} Type: code.SimpleType(g.repoImplStructName()),
Pointer: true,
tmplData := mongoMethodTemplateData{ },
StructName: g.structName(), Name: methodSpec.Name,
MethodName: methodSpec.Name, Params: params,
ParamTypes: paramTypes, Returns: methodSpec.Returns,
ReturnTypes: methodSpec.Returns, Body: implementation,
Implementation: implementation, }, nil
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
} }
func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) { func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
@ -275,7 +325,7 @@ func (g RepositoryGenerator) bsonTagFromField(field code.StructField) (string, e
return bsonTag[0], nil return bsonTag[0], nil
} }
func (g RepositoryGenerator) structName() string { func (g RepositoryGenerator) repoImplStructName() string {
return g.InterfaceName + "Mongo" return g.InterfaceName + "Mongo"
} }

File diff suppressed because it is too large Load diff

View file

@ -1,77 +1,17 @@
package mongo package mongo
import ( import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/spec" "github.com/sunboyy/repogen/internal/spec"
) )
const constructorTemplate = ` const constructorBody = ` return &{{.ImplStructName}}{
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} {
return &{{.ImplStructName}}{
collection: collection, collection: collection,
} }`
}
type {{.ImplStructName}} struct { type constructorBodyData struct {
collection *mongo.Collection
}
`
type mongoConstructorTemplateData struct {
InterfaceName string
ImplStructName string ImplStructName string
} }
const methodTemplate = `
func (r *{{.StructName}}) {{.MethodName}}({{.Parameters}}){{.Returns}} {
{{.Implementation}}
}
`
type mongoMethodTemplateData struct {
StructName string
MethodName string
ParamTypes []code.Type
ReturnTypes []code.Type
Implementation string
}
func (data mongoMethodTemplateData) Parameters() string {
var params []string
for i, paramType := range data.ParamTypes {
params = append(params, fmt.Sprintf("arg%d %s", i, paramType.Code()))
}
return strings.Join(params, ", ")
}
func (data mongoMethodTemplateData) Returns() string {
if len(data.ReturnTypes) == 0 {
return ""
}
if len(data.ReturnTypes) == 1 {
return fmt.Sprintf(" %s", data.ReturnTypes[0].Code())
}
var returns []string
for _, returnType := range data.ReturnTypes {
returns = append(returns, returnType.Code())
}
return fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
}
const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1) const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -44,7 +44,7 @@ func TestError(t *testing.T) {
Name: "IncompatibleComparatorError", Name: "IncompatibleComparatorError",
Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{ Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{
Name: "Age", Name: "Age",
Type: code.SimpleType("int"), Type: code.TypeInt,
}), }),
ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'", ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'",
}, },
@ -55,7 +55,7 @@ func TestError(t *testing.T) {
}, },
{ {
Name: "ArgumentTypeNotMatchedError", Name: "ArgumentTypeNotMatchedError",
Error: spec.NewArgumentTypeNotMatchedError("Age", code.SimpleType("int"), code.SimpleType("float64")), Error: spec.NewArgumentTypeNotMatchedError("Age", code.TypeInt, code.TypeFloat64),
ExpectedString: "field 'Age' requires an argument of type 'int' (got 'float64')", ExpectedString: "field 'Age' requires an argument of type 'int' (got 'float64')",
}, },
{ {
@ -63,7 +63,7 @@ func TestError(t *testing.T) {
Error: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{ Error: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{
code.StructField{ code.StructField{
Name: "City", Name: "City",
Type: code.SimpleType("string"), Type: code.TypeString,
}, },
}), }),
ExpectedString: "cannot use update operator INC with struct field 'City' of type 'string'", ExpectedString: "cannot use update operator INC with struct field 'City' of type 'string'",

View file

@ -85,7 +85,7 @@ func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryM
return "", NewOperationReturnCountUnmatchedError(2) return "", NewOperationReturnCountUnmatchedError(2)
} }
if returns[1] != code.SimpleType("error") { if returns[1] != code.TypeError {
return "", NewUnsupportedReturnError(returns[1], 1) return "", NewUnsupportedReturnError(returns[1], 1)
} }
@ -203,7 +203,7 @@ func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (
return "", NewOperationReturnCountUnmatchedError(2) return "", NewOperationReturnCountUnmatchedError(2)
} }
if returns[1] != code.SimpleType("error") { if returns[1] != code.TypeError {
return "", NewUnsupportedReturnError(returns[1], 1) return "", NewUnsupportedReturnError(returns[1], 1)
} }
@ -318,11 +318,11 @@ func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error {
return NewOperationReturnCountUnmatchedError(2) return NewOperationReturnCountUnmatchedError(2)
} }
if returns[0] != code.SimpleType("int") { if returns[0] != code.TypeInt {
return NewUnsupportedReturnError(returns[0], 0) return NewUnsupportedReturnError(returns[0], 0)
} }
if returns[1] != code.SimpleType("error") { if returns[1] != code.TypeError {
return NewUnsupportedReturnError(returns[1], 1) return NewUnsupportedReturnError(returns[1], 1)
} }
@ -334,16 +334,16 @@ func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (Que
return "", NewOperationReturnCountUnmatchedError(2) return "", NewOperationReturnCountUnmatchedError(2)
} }
if returns[1] != code.SimpleType("error") { if returns[1] != code.TypeError {
return "", NewUnsupportedReturnError(returns[1], 1) return "", NewUnsupportedReturnError(returns[1], 1)
} }
simpleType, ok := returns[0].(code.SimpleType) simpleType, ok := returns[0].(code.SimpleType)
if ok { if ok {
if simpleType == code.SimpleType("bool") { if simpleType == code.TypeBool {
return QueryModeOne, nil return QueryModeOne, nil
} }
if simpleType == code.SimpleType("int") { if simpleType == code.TypeInt {
return QueryModeMany, nil return QueryModeMany, nil
} }
} }
@ -367,7 +367,7 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
var currentParamIndex int var currentParamIndex int
for _, predicate := range querySpec.Predicates { for _, predicate := range querySpec.Predicates {
if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) && if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) &&
predicate.FieldReference.ReferencedField().Type != code.SimpleType("bool") { predicate.FieldReference.ReferencedField().Type != code.TypeBool {
return NewIncompatibleComparatorError(predicate.Comparator, return NewIncompatibleComparatorError(predicate.Comparator,
predicate.FieldReference.ReferencedField()) predicate.FieldReference.ReferencedField())
} }

File diff suppressed because it is too large Load diff

View file

@ -21,7 +21,8 @@ const usageText = `repogen generates MongoDB repository implementation from repo
Supported options:` Supported options:`
const version = "v0.2.1" // version indicates the version of repogen.
const version = "v0.3-next"
func main() { func main() {
flag.Usage = printUsage flag.Usage = printUsage