Segregate base code generation logic to a new package (#31)
This commit is contained in:
parent
737c1a4044
commit
ec08a5a918
22 changed files with 1761 additions and 937 deletions
|
@ -78,7 +78,7 @@ type UserModel struct {
|
|||
},
|
||||
code.StructField{
|
||||
Name: "Username",
|
||||
Type: code.SimpleType("string"),
|
||||
Type: code.TypeString,
|
||||
Tags: map[string][]string{
|
||||
"bson": {"username"},
|
||||
"json": {"username"},
|
||||
|
@ -120,7 +120,7 @@ type UserRepository interface {
|
|||
},
|
||||
Returns: []code.Type{
|
||||
code.PointerType{ContainedType: code.SimpleType("UserModel")},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -130,19 +130,19 @@ type UserRepository interface {
|
|||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "FindByAgeBetween",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "fromAge", Type: code.SimpleType("int")},
|
||||
{Name: "toAge", Type: code.SimpleType("int")},
|
||||
{Name: "fromAge", Type: code.TypeInt},
|
||||
{Name: "toAge", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -153,19 +153,19 @@ type UserRepository interface {
|
|||
},
|
||||
Returns: []code.Type{
|
||||
code.InterfaceType{},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "UpdateAgreementByID",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "agreement", Type: code.MapType{KeyType: code.SimpleType("string"), ValueType: code.SimpleType("bool")}},
|
||||
{Name: "agreement", Type: code.MapType{KeyType: code.TypeString, ValueType: code.TypeBool}},
|
||||
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.SimpleType("bool"),
|
||||
code.SimpleType("error"),
|
||||
code.TypeBool,
|
||||
code.TypeError,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -178,7 +178,7 @@ type UserRepository interface {
|
|||
{
|
||||
Name: "Run",
|
||||
Params: []code.Param{
|
||||
{Name: "arg1", Type: code.SimpleType("int")},
|
||||
{Name: "arg1", Type: code.TypeInt},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -191,7 +191,7 @@ type UserRepository interface {
|
|||
{
|
||||
Name: "Do",
|
||||
Params: []code.Param{
|
||||
{Name: "arg2", Type: code.SimpleType("string")},
|
||||
{Name: "arg2", Type: code.TypeString},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -243,7 +243,7 @@ type UserRepository interface {
|
|||
},
|
||||
code.StructField{
|
||||
Name: "Username",
|
||||
Type: code.SimpleType("string"),
|
||||
Type: code.TypeString,
|
||||
Tags: map[string][]string{
|
||||
"bson": {"username"},
|
||||
"json": {"username"},
|
||||
|
@ -264,7 +264,7 @@ type UserRepository interface {
|
|||
},
|
||||
Returns: []code.Type{
|
||||
code.PointerType{ContainedType: code.SimpleType("UserModel")},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -274,7 +274,7 @@ type UserRepository interface {
|
|||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -100,6 +100,15 @@ func (t SimpleType) IsNumber() bool {
|
|||
t == "float32" || t == "float64"
|
||||
}
|
||||
|
||||
// commonly-used types
|
||||
const (
|
||||
TypeBool = SimpleType("bool")
|
||||
TypeInt = SimpleType("int")
|
||||
TypeFloat64 = SimpleType("float64")
|
||||
TypeString = SimpleType("string")
|
||||
TypeError = SimpleType("error")
|
||||
)
|
||||
|
||||
// ExternalType is a type that is called to another package
|
||||
type ExternalType struct {
|
||||
PackageAlias string
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
func TestStructFieldsByName(t *testing.T) {
|
||||
idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}
|
||||
usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")}
|
||||
usernameField := code.StructField{Name: "Username", Type: code.TypeString}
|
||||
fields := code.StructFields{idField, usernameField}
|
||||
|
||||
t.Run("struct field found", func(t *testing.T) {
|
||||
|
@ -91,7 +91,7 @@ func TestTypeIsNumber(t *testing.T) {
|
|||
testTable := []TypeIsNumberTestCase{
|
||||
{
|
||||
Name: "simple type: int",
|
||||
Type: code.SimpleType("int"),
|
||||
Type: code.TypeInt,
|
||||
IsNumber: true,
|
||||
},
|
||||
{
|
||||
|
@ -116,12 +116,12 @@ func TestTypeIsNumber(t *testing.T) {
|
|||
},
|
||||
{
|
||||
Name: "simple type: other float variants",
|
||||
Type: code.SimpleType("float64"),
|
||||
Type: code.TypeFloat64,
|
||||
IsNumber: true,
|
||||
},
|
||||
{
|
||||
Name: "simple type: non-number primitive type",
|
||||
Type: code.SimpleType("string"),
|
||||
Type: code.TypeString,
|
||||
IsNumber: false,
|
||||
},
|
||||
{
|
||||
|
@ -136,22 +136,22 @@ func TestTypeIsNumber(t *testing.T) {
|
|||
},
|
||||
{
|
||||
Name: "pointer type: number",
|
||||
Type: code.PointerType{ContainedType: code.SimpleType("int")},
|
||||
Type: code.PointerType{ContainedType: code.TypeInt},
|
||||
IsNumber: true,
|
||||
},
|
||||
{
|
||||
Name: "pointer type: non-number",
|
||||
Type: code.PointerType{ContainedType: code.SimpleType("string")},
|
||||
Type: code.PointerType{ContainedType: code.TypeString},
|
||||
IsNumber: false,
|
||||
},
|
||||
{
|
||||
Name: "array type",
|
||||
Type: code.ArrayType{ContainedType: code.SimpleType("int")},
|
||||
Type: code.ArrayType{ContainedType: code.TypeInt},
|
||||
IsNumber: false,
|
||||
},
|
||||
{
|
||||
Name: "map type",
|
||||
Type: code.MapType{KeyType: code.SimpleType("int"), ValueType: code.SimpleType("float64")},
|
||||
Type: code.MapType{KeyType: code.TypeInt, ValueType: code.TypeFloat64},
|
||||
IsNumber: false,
|
||||
},
|
||||
{
|
||||
|
|
41
internal/codegen/base.go
Normal file
41
internal/codegen/base.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package codegen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
const baseTemplate = `// Code generated by {{.Program}}. DO NOT EDIT.
|
||||
package {{.PackageName}}
|
||||
|
||||
import (
|
||||
{{.GenImports}}
|
||||
)
|
||||
`
|
||||
|
||||
type baseTemplateData struct {
|
||||
Program string
|
||||
PackageName string
|
||||
Imports [][]code.Import
|
||||
}
|
||||
|
||||
func (data baseTemplateData) GenImports() string {
|
||||
var sections []string
|
||||
for _, importSection := range data.Imports {
|
||||
var section []string
|
||||
for _, imp := range importSection {
|
||||
section = append(section, data.generateImportLine(imp))
|
||||
}
|
||||
sections = append(sections, strings.Join(section, "\n"))
|
||||
}
|
||||
return strings.Join(sections, "\n\n")
|
||||
}
|
||||
|
||||
func (data baseTemplateData) generateImportLine(imp code.Import) string {
|
||||
if imp.Name == "" {
|
||||
return fmt.Sprintf("\t\"%s\"", imp.Path)
|
||||
}
|
||||
return fmt.Sprintf("\t%s \"%s\"", imp.Name, imp.Path)
|
||||
}
|
83
internal/codegen/builder.go
Normal file
83
internal/codegen/builder.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package codegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"text/template"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
type Builder struct {
|
||||
// Program defines generator program name in the generated file.
|
||||
Program string
|
||||
|
||||
// PackageName defines the package name of the generated file.
|
||||
PackageName string
|
||||
|
||||
// Imports defines necessary imports to reduce ambiguity when generating
|
||||
// formatting the raw-generated code.
|
||||
Imports [][]code.Import
|
||||
|
||||
implementers []Implementer
|
||||
}
|
||||
|
||||
// Implementer is an interface that wraps the basic Impl method for code
|
||||
// generation.
|
||||
type Implementer interface {
|
||||
Impl(buffer *bytes.Buffer) error
|
||||
}
|
||||
|
||||
// NewBuilder is a constructor of Builder struct.
|
||||
func NewBuilder(program string, packageName string, imports [][]code.Import) *Builder {
|
||||
return &Builder{
|
||||
Program: program,
|
||||
PackageName: packageName,
|
||||
Imports: imports,
|
||||
}
|
||||
}
|
||||
|
||||
// AddImplementer appends a new implemeneter to the implementer list.
|
||||
func (b *Builder) AddImplementer(implementer Implementer) {
|
||||
b.implementers = append(b.implementers, implementer)
|
||||
}
|
||||
|
||||
// Build generates code from the previously provided specifications.
|
||||
func (b Builder) Build() (string, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
if err := b.buildBase(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, impl := range b.implementers {
|
||||
if err := impl.Impl(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
formattedCode, err := imports.Process("", buffer.Bytes(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(formattedCode), nil
|
||||
}
|
||||
|
||||
func (b Builder) buildBase(buffer *bytes.Buffer) error {
|
||||
tmpl, err := template.New("file_base").Parse(baseTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmplData := baseTemplateData{
|
||||
Program: b.Program,
|
||||
PackageName: b.PackageName,
|
||||
Imports: b.Imports,
|
||||
}
|
||||
|
||||
// writing to a buffer should not cause errors.
|
||||
_ = tmpl.Execute(buffer, tmplData)
|
||||
|
||||
return nil
|
||||
}
|
102
internal/codegen/builder_test.go
Normal file
102
internal/codegen/builder_test.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package codegen_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/codegen"
|
||||
"github.com/sunboyy/repogen/internal/testutils"
|
||||
)
|
||||
|
||||
const expectedBuildCode = `// Code generated by repogen. DO NOT EDIT.
|
||||
package user
|
||||
|
||||
import (
|
||||
_ "context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
_ "go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID primitive.ObjectID ` + "`bson:\"id\" json:\"id,omitempty\"`" + `
|
||||
Username string
|
||||
}
|
||||
|
||||
func NewUser(username string) User {
|
||||
return User{
|
||||
ID: primitive.NewObjectID(),
|
||||
Username: username,
|
||||
}
|
||||
}
|
||||
|
||||
func (u User) IDHex() string {
|
||||
return u.ID.Hex()
|
||||
}
|
||||
`
|
||||
|
||||
func TestBuilderBuild(t *testing.T) {
|
||||
builder := codegen.NewBuilder("repogen", "user", [][]code.Import{
|
||||
{
|
||||
{
|
||||
Name: "_",
|
||||
Path: "context",
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
Path: "go.mongodb.org/mongo-driver/bson/primitive",
|
||||
},
|
||||
{
|
||||
Name: "_",
|
||||
Path: "go.mongodb.org/mongo-driver/mongo/options",
|
||||
},
|
||||
},
|
||||
})
|
||||
builder.AddImplementer(codegen.StructBuilder{
|
||||
Name: "User",
|
||||
Fields: code.StructFields{
|
||||
{
|
||||
Name: "ID",
|
||||
Type: code.ExternalType{
|
||||
PackageAlias: "primitive",
|
||||
Name: "ObjectID",
|
||||
},
|
||||
Tags: map[string][]string{
|
||||
"bson": {"id"},
|
||||
"json": {"id", "omitempty"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Username",
|
||||
Type: code.TypeString,
|
||||
},
|
||||
},
|
||||
})
|
||||
builder.AddImplementer(codegen.FunctionBuilder{
|
||||
Name: "NewUser",
|
||||
Params: []code.Param{
|
||||
{Name: "username", Type: code.TypeString},
|
||||
},
|
||||
Returns: []code.Type{code.SimpleType("User")},
|
||||
Body: ` return User{
|
||||
ID: primitive.NewObjectID(),
|
||||
Username: username,
|
||||
}`,
|
||||
})
|
||||
builder.AddImplementer(codegen.MethodBuilder{
|
||||
Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")},
|
||||
Name: "IDHex",
|
||||
Params: nil,
|
||||
Returns: []code.Type{code.TypeString},
|
||||
Body: " return u.ID.Hex()",
|
||||
})
|
||||
|
||||
generatedCode, err := builder.Build()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := testutils.ExpectMultiLineString(expectedBuildCode, generatedCode); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
73
internal/codegen/function.go
Normal file
73
internal/codegen/function.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package codegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
const functionTemplate = `
|
||||
func {{.Name}}({{.GenParams}}){{.GenReturns}} {
|
||||
{{.Body}}
|
||||
}
|
||||
`
|
||||
|
||||
// FunctionBuilder is an implementer of a function.
|
||||
type FunctionBuilder struct {
|
||||
Name string
|
||||
Params []code.Param
|
||||
Returns []code.Type
|
||||
Body string
|
||||
}
|
||||
|
||||
// Impl writes function declatation code to the buffer.
|
||||
func (fb FunctionBuilder) Impl(buffer *bytes.Buffer) error {
|
||||
tmpl, err := template.New("function").Parse(functionTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// writing to a buffer should not cause errors.
|
||||
_ = tmpl.Execute(buffer, fb)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fb FunctionBuilder) GenParams() string {
|
||||
return generateParams(fb.Params)
|
||||
}
|
||||
|
||||
func (fb FunctionBuilder) GenReturns() string {
|
||||
return generateReturns(fb.Returns)
|
||||
}
|
||||
|
||||
func generateParams(params []code.Param) string {
|
||||
var paramList []string
|
||||
for _, param := range params {
|
||||
paramList = append(
|
||||
paramList,
|
||||
fmt.Sprintf("%s %s", param.Name, param.Type.Code()),
|
||||
)
|
||||
}
|
||||
return strings.Join(paramList, ", ")
|
||||
}
|
||||
|
||||
func generateReturns(returns []code.Type) string {
|
||||
if len(returns) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(returns) == 1 {
|
||||
return " " + returns[0].Code()
|
||||
}
|
||||
|
||||
var returnList []string
|
||||
for _, ret := range returns {
|
||||
returnList = append(returnList, ret.Code())
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" (%s)", strings.Join(returnList, ", "))
|
||||
}
|
125
internal/codegen/function_test.go
Normal file
125
internal/codegen/function_test.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package codegen_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/codegen"
|
||||
"github.com/sunboyy/repogen/internal/testutils"
|
||||
)
|
||||
|
||||
func TestFunctionBuilderBuild_NoReturn(t *testing.T) {
|
||||
fb := codegen.FunctionBuilder{
|
||||
Name: "init",
|
||||
Params: nil,
|
||||
Returns: nil,
|
||||
Body: ` logrus.SetLevel(logrus.DebugLevel)`,
|
||||
}
|
||||
expectedCode := `
|
||||
func init() {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
`
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := fb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionBuilderBuild_OneReturn(t *testing.T) {
|
||||
fb := codegen.FunctionBuilder{
|
||||
Name: "NewUser",
|
||||
Params: []code.Param{
|
||||
{
|
||||
Name: "username",
|
||||
Type: code.TypeString,
|
||||
},
|
||||
{
|
||||
Name: "age",
|
||||
Type: code.TypeInt,
|
||||
},
|
||||
{
|
||||
Name: "parent",
|
||||
Type: code.PointerType{ContainedType: code.SimpleType("User")},
|
||||
},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.SimpleType("User"),
|
||||
},
|
||||
Body: ` return User{
|
||||
Username: username,
|
||||
Age: age,
|
||||
Parent: parent
|
||||
}`,
|
||||
}
|
||||
expectedCode := `
|
||||
func NewUser(username string, age int, parent *User) User {
|
||||
return User{
|
||||
Username: username,
|
||||
Age: age,
|
||||
Parent: parent
|
||||
}
|
||||
}
|
||||
`
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := fb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionBuilderBuild_MultiReturn(t *testing.T) {
|
||||
fb := codegen.FunctionBuilder{
|
||||
Name: "Save",
|
||||
Params: []code.Param{
|
||||
{
|
||||
Name: "user",
|
||||
Type: code.SimpleType("User"),
|
||||
},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.SimpleType("User"),
|
||||
code.TypeError,
|
||||
},
|
||||
Body: ` return collection.Save(user)`,
|
||||
}
|
||||
expectedCode := `
|
||||
func Save(user User) (User, error) {
|
||||
return collection.Save(user)
|
||||
}
|
||||
`
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := fb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
66
internal/codegen/method.go
Normal file
66
internal/codegen/method.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package codegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"text/template"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
const methodTemplate = `
|
||||
func ({{.GenReceiver}}) {{.Name}}({{.GenParams}}){{.GenReturns}} {
|
||||
{{.Body}}
|
||||
}
|
||||
`
|
||||
|
||||
// MethodBuilder is an implementer of a method.
|
||||
type MethodBuilder struct {
|
||||
Receiver MethodReceiver
|
||||
Name string
|
||||
Params []code.Param
|
||||
Returns []code.Type
|
||||
Body string
|
||||
}
|
||||
|
||||
// MethodReceiver describes a specification of a method receiver.
|
||||
type MethodReceiver struct {
|
||||
Name string
|
||||
Type code.SimpleType
|
||||
Pointer bool
|
||||
}
|
||||
|
||||
// Impl writes method declatation code to the buffer.
|
||||
func (mb MethodBuilder) Impl(buffer *bytes.Buffer) error {
|
||||
tmpl, err := template.New("function").Parse(methodTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// writing to a buffer should not cause errors.
|
||||
_ = tmpl.Execute(buffer, mb)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb MethodBuilder) GenReceiver() string {
|
||||
if mb.Receiver.Name == "" {
|
||||
return mb.generateReceiverType()
|
||||
}
|
||||
return fmt.Sprintf("%s %s", mb.Receiver.Name, mb.generateReceiverType())
|
||||
}
|
||||
|
||||
func (mb MethodBuilder) generateReceiverType() string {
|
||||
if !mb.Receiver.Pointer {
|
||||
return mb.Receiver.Type.Code()
|
||||
}
|
||||
return code.PointerType{ContainedType: mb.Receiver.Type}.Code()
|
||||
}
|
||||
|
||||
func (mb MethodBuilder) GenParams() string {
|
||||
return generateParams(mb.Params)
|
||||
}
|
||||
|
||||
func (mb MethodBuilder) GenReturns() string {
|
||||
return generateReturns(mb.Returns)
|
||||
}
|
142
internal/codegen/method_test.go
Normal file
142
internal/codegen/method_test.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package codegen_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/codegen"
|
||||
"github.com/sunboyy/repogen/internal/testutils"
|
||||
)
|
||||
|
||||
func TestMethodBuilderBuild_IgnoreReceiverNoReturn(t *testing.T) {
|
||||
fb := codegen.MethodBuilder{
|
||||
Receiver: codegen.MethodReceiver{Type: "User"},
|
||||
Name: "Init",
|
||||
Params: nil,
|
||||
Returns: nil,
|
||||
Body: ` db.Init(&User{})`,
|
||||
}
|
||||
expectedCode := `
|
||||
func (User) Init() {
|
||||
db.Init(&User{})
|
||||
}
|
||||
`
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := fb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodBuilderBuild_IgnorePoinerReceiverOneReturn(t *testing.T) {
|
||||
fb := codegen.MethodBuilder{
|
||||
Receiver: codegen.MethodReceiver{
|
||||
Type: "User",
|
||||
Pointer: true,
|
||||
},
|
||||
Name: "Init",
|
||||
Params: nil,
|
||||
Returns: []code.Type{code.TypeError},
|
||||
Body: ` return db.Init(&User{})`,
|
||||
}
|
||||
expectedCode := `
|
||||
func (*User) Init() error {
|
||||
return db.Init(&User{})
|
||||
}
|
||||
`
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := fb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodBuilderBuild_UseReceiverMultiReturn(t *testing.T) {
|
||||
fb := codegen.MethodBuilder{
|
||||
Receiver: codegen.MethodReceiver{
|
||||
Name: "u",
|
||||
Type: "User",
|
||||
},
|
||||
Name: "WithAge",
|
||||
Params: []code.Param{
|
||||
{Name: "age", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{code.SimpleType("User"), code.TypeError},
|
||||
Body: ` u.Age = age
|
||||
return u`,
|
||||
}
|
||||
expectedCode := `
|
||||
func (u User) WithAge(age int) (User, error) {
|
||||
u.Age = age
|
||||
return u
|
||||
}
|
||||
`
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := fb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodBuilderBuild_UsePointerReceiverNoReturn(t *testing.T) {
|
||||
fb := codegen.MethodBuilder{
|
||||
Receiver: codegen.MethodReceiver{
|
||||
Name: "u",
|
||||
Type: "User",
|
||||
Pointer: true,
|
||||
},
|
||||
Name: "SetAge",
|
||||
Params: []code.Param{
|
||||
{Name: "age", Type: code.TypeInt},
|
||||
},
|
||||
Returns: nil,
|
||||
Body: ` u.Age = age`,
|
||||
}
|
||||
expectedCode := `
|
||||
func (u *User) SetAge(age int) {
|
||||
u.Age = age
|
||||
}
|
||||
`
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := fb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
63
internal/codegen/struct.go
Normal file
63
internal/codegen/struct.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package codegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
const structTemplate = `
|
||||
type {{.Name}} struct {
|
||||
{{.GenFields}}
|
||||
}
|
||||
`
|
||||
|
||||
// StructBuilder is an implementer of a struct.
|
||||
type StructBuilder struct {
|
||||
Name string
|
||||
Fields code.StructFields
|
||||
}
|
||||
|
||||
// Impl writes struct declatation code to the buffer.
|
||||
func (sb StructBuilder) Impl(buffer *bytes.Buffer) error {
|
||||
tmpl, err := template.New("struct").Parse(structTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// writing to a buffer should not cause errors.
|
||||
_ = tmpl.Execute(buffer, sb)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sb StructBuilder) GenFields() string {
|
||||
var fieldLines []string
|
||||
for _, field := range sb.Fields {
|
||||
fieldLine := fmt.Sprintf("\t%s %s", field.Name, field.Type.Code())
|
||||
if len(field.Tags) > 0 {
|
||||
fieldLine += fmt.Sprintf(" `%s`", sb.generateStructTag(field.Tags))
|
||||
}
|
||||
fieldLines = append(fieldLines, fieldLine)
|
||||
}
|
||||
return strings.Join(fieldLines, "\n")
|
||||
}
|
||||
|
||||
func (sb StructBuilder) generateStructTag(tags map[string][]string) string {
|
||||
var tagKeys []string
|
||||
for key := range tags {
|
||||
tagKeys = append(tagKeys, key)
|
||||
}
|
||||
sort.Strings(tagKeys)
|
||||
|
||||
var tagGroups []string
|
||||
for _, key := range tagKeys {
|
||||
tagValue := strings.Join(tags[key], ",")
|
||||
tagGroups = append(tagGroups, fmt.Sprintf("%s:\"%s\"", key, tagValue))
|
||||
}
|
||||
return strings.Join(tagGroups, " ")
|
||||
}
|
73
internal/codegen/struct_test.go
Normal file
73
internal/codegen/struct_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package codegen_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/codegen"
|
||||
"github.com/sunboyy/repogen/internal/testutils"
|
||||
)
|
||||
|
||||
const expectedStructBuilderCode = `
|
||||
type User struct {
|
||||
ID primitive.ObjectID ` + "`bson:\"id,omitempty\" json:\"id,omitempty\"`" + `
|
||||
Username string ` + "`bson:\"username\" json:\"username\"`" + `
|
||||
Age int ` + "`bson:\"age\"`" + `
|
||||
orderCount *int
|
||||
}
|
||||
`
|
||||
|
||||
func TestStructBuilderBuild(t *testing.T) {
|
||||
sb := codegen.StructBuilder{
|
||||
Name: "User",
|
||||
Fields: []code.StructField{
|
||||
{
|
||||
Name: "ID",
|
||||
Type: code.ExternalType{
|
||||
PackageAlias: "primitive",
|
||||
Name: "ObjectID",
|
||||
},
|
||||
Tags: map[string][]string{
|
||||
"json": {"id", "omitempty"},
|
||||
"bson": {"id", "omitempty"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Username",
|
||||
Type: code.TypeString,
|
||||
Tags: map[string][]string{
|
||||
"json": {"username"},
|
||||
"bson": {"username"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Age",
|
||||
Type: code.TypeInt,
|
||||
Tags: map[string][]string{
|
||||
"bson": {"age"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "orderCount",
|
||||
Type: code.PointerType{
|
||||
ContainedType: code.TypeInt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
buffer := new(bytes.Buffer)
|
||||
|
||||
err := sb.Impl(buffer)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
actual := buffer.String()
|
||||
if err := testutils.ExpectMultiLineString(
|
||||
expectedStructBuilderCode,
|
||||
actual,
|
||||
); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
|
@ -1,75 +1,40 @@
|
|||
package generator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"html/template"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/codegen"
|
||||
"github.com/sunboyy/repogen/internal/mongo"
|
||||
"github.com/sunboyy/repogen/internal/spec"
|
||||
"golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
// GenerateRepository generates repository implementation from repository interface specification
|
||||
func GenerateRepository(packageName string, structModel code.Struct, interfaceName string,
|
||||
methodSpecs []spec.MethodSpec) (string, error) {
|
||||
// GenerateRepository generates repository implementation code from repository
|
||||
// interface specification.
|
||||
func GenerateRepository(packageName string, structModel code.Struct,
|
||||
interfaceName string, methodSpecs []spec.MethodSpec) (string, error) {
|
||||
|
||||
repositoryGenerator := repositoryGenerator{
|
||||
PackageName: packageName,
|
||||
StructModel: structModel,
|
||||
InterfaceName: interfaceName,
|
||||
MethodSpecs: methodSpecs,
|
||||
Generator: mongo.NewGenerator(structModel, interfaceName),
|
||||
}
|
||||
generator := mongo.NewGenerator(structModel, interfaceName)
|
||||
|
||||
return repositoryGenerator.Generate()
|
||||
}
|
||||
codeBuilder := codegen.NewBuilder(
|
||||
"repogen",
|
||||
packageName,
|
||||
generator.Imports(),
|
||||
)
|
||||
|
||||
type repositoryGenerator struct {
|
||||
PackageName string
|
||||
StructModel code.Struct
|
||||
InterfaceName string
|
||||
MethodSpecs []spec.MethodSpec
|
||||
Generator mongo.RepositoryGenerator
|
||||
}
|
||||
|
||||
func (g repositoryGenerator) Generate() (string, error) {
|
||||
buffer := new(bytes.Buffer)
|
||||
if err := g.generateBase(buffer); err != nil {
|
||||
constructorBuilder, err := generator.GenerateConstructor()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := g.Generator.GenerateConstructor(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
codeBuilder.AddImplementer(constructorBuilder)
|
||||
codeBuilder.AddImplementer(generator.GenerateStruct())
|
||||
|
||||
for _, method := range g.MethodSpecs {
|
||||
if err := g.Generator.GenerateMethod(method, buffer); err != nil {
|
||||
for _, method := range methodSpecs {
|
||||
methodBuilder, err := generator.GenerateMethod(method)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
codeBuilder.AddImplementer(methodBuilder)
|
||||
}
|
||||
|
||||
formattedCode, err := imports.Process("", buffer.Bytes(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(formattedCode), nil
|
||||
}
|
||||
|
||||
func (g repositoryGenerator) generateBase(buffer *bytes.Buffer) error {
|
||||
tmpl, err := template.New("file_base").Parse(baseTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmplData := baseTemplateData{
|
||||
PackageName: g.PackageName,
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return codeBuilder.Build()
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ var (
|
|||
}
|
||||
ageField = code.StructField{
|
||||
Name: "Age",
|
||||
Type: code.SimpleType("int"),
|
||||
Type: code.TypeInt,
|
||||
Tags: map[string][]string{"bson": {"age"}},
|
||||
}
|
||||
)
|
||||
|
@ -34,7 +34,7 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
idField,
|
||||
code.StructField{
|
||||
Name: "Username",
|
||||
Type: code.SimpleType("string"),
|
||||
Type: code.TypeString,
|
||||
Tags: map[string][]string{"bson": {"username"}},
|
||||
},
|
||||
genderField,
|
||||
|
@ -49,7 +49,7 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||
},
|
||||
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")},
|
||||
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.TypeError},
|
||||
Operation: spec.FindOperation{
|
||||
Mode: spec.QueryModeOne,
|
||||
Query: spec.QuerySpec{
|
||||
|
@ -65,11 +65,11 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "gender", Type: code.SimpleType("Gender")},
|
||||
{Name: "age", Type: code.SimpleType("int")},
|
||||
{Name: "age", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.PointerType{ContainedType: code.SimpleType("UserModel")},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
Operation: spec.FindOperation{
|
||||
Mode: spec.QueryModeMany,
|
||||
|
@ -86,11 +86,11 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
Name: "FindByAgeLessThanEqualOrderByAge",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "age", Type: code.SimpleType("int")},
|
||||
{Name: "age", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
Operation: spec.FindOperation{
|
||||
Mode: spec.QueryModeMany,
|
||||
|
@ -108,11 +108,11 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
Name: "FindByAgeGreaterThanOrderByAgeAsc",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "age", Type: code.SimpleType("int")},
|
||||
{Name: "age", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
Operation: spec.FindOperation{
|
||||
Mode: spec.QueryModeMany,
|
||||
|
@ -130,11 +130,11 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
Name: "FindByAgeGreaterThanEqualOrderByAgeDesc",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "age", Type: code.SimpleType("int")},
|
||||
{Name: "age", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
Operation: spec.FindOperation{
|
||||
Mode: spec.QueryModeMany,
|
||||
|
@ -152,12 +152,12 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
Name: "FindByAgeBetween",
|
||||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "fromAge", Type: code.SimpleType("int")},
|
||||
{Name: "toAge", Type: code.SimpleType("int")},
|
||||
{Name: "fromAge", Type: code.TypeInt},
|
||||
{Name: "toAge", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
Operation: spec.FindOperation{
|
||||
Mode: spec.QueryModeMany,
|
||||
|
@ -173,11 +173,11 @@ func TestGenerateMongoRepository(t *testing.T) {
|
|||
Params: []code.Param{
|
||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||
{Name: "gender", Type: code.SimpleType("Gender")},
|
||||
{Name: "age", Type: code.SimpleType("int")},
|
||||
{Name: "age", Type: code.TypeInt},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||
code.SimpleType("error"),
|
||||
code.TypeError,
|
||||
},
|
||||
Operation: spec.FindOperation{
|
||||
Mode: spec.QueryModeMany,
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
package generator
|
||||
|
||||
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
|
||||
package {{.PackageName}}
|
||||
`
|
||||
|
||||
type baseTemplateData struct {
|
||||
PackageName string
|
||||
}
|
|
@ -2,11 +2,12 @@ package mongo
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/codegen"
|
||||
"github.com/sunboyy/repogen/internal/spec"
|
||||
)
|
||||
|
||||
|
@ -24,55 +25,104 @@ type RepositoryGenerator struct {
|
|||
InterfaceName string
|
||||
}
|
||||
|
||||
// GenerateConstructor generates mongo repository struct implementation and constructor for the struct
|
||||
func (g RepositoryGenerator) GenerateConstructor(buffer io.Writer) error {
|
||||
tmpl, err := template.New("mongo_repository_base").Parse(constructorTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
// Imports returns necessary imports for the mongo repository implementation.
|
||||
func (g RepositoryGenerator) Imports() [][]code.Import {
|
||||
return [][]code.Import{
|
||||
{
|
||||
{Path: "context"},
|
||||
},
|
||||
{
|
||||
{Path: "go.mongodb.org/mongo-driver/bson"},
|
||||
{Path: "go.mongodb.org/mongo-driver/bson/primitive"},
|
||||
{Path: "go.mongodb.org/mongo-driver/mongo"},
|
||||
{Path: "go.mongodb.org/mongo-driver/mongo/options"},
|
||||
},
|
||||
}
|
||||
|
||||
tmplData := mongoConstructorTemplateData{
|
||||
InterfaceName: g.InterfaceName,
|
||||
ImplStructName: g.structName(),
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateMethod generates implementation of from provided method specification
|
||||
func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer io.Writer) error {
|
||||
tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate)
|
||||
// GenerateStruct creates codegen.StructBuilder of mongo repository
|
||||
// implementation struct.
|
||||
func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder {
|
||||
return codegen.StructBuilder{
|
||||
Name: g.repoImplStructName(),
|
||||
Fields: code.StructFields{
|
||||
{
|
||||
Name: "collection",
|
||||
Type: code.PointerType{
|
||||
ContainedType: code.ExternalType{
|
||||
PackageAlias: "mongo",
|
||||
Name: "Collection",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateConstructor creates codegen.FunctionBuilder of a constructor for
|
||||
// mongo repository implementation struct.
|
||||
func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) {
|
||||
tmpl, err := template.New("mongo_constructor_body").Parse(constructorBody)
|
||||
if err != nil {
|
||||
return err
|
||||
return codegen.FunctionBuilder{}, err
|
||||
}
|
||||
|
||||
tmplData := constructorBodyData{
|
||||
ImplStructName: g.repoImplStructName(),
|
||||
}
|
||||
|
||||
buffer := new(bytes.Buffer)
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return codegen.FunctionBuilder{}, err
|
||||
}
|
||||
|
||||
return codegen.FunctionBuilder{
|
||||
Name: "New" + g.InterfaceName,
|
||||
Params: []code.Param{
|
||||
{
|
||||
Name: "collection",
|
||||
Type: code.PointerType{
|
||||
ContainedType: code.ExternalType{
|
||||
PackageAlias: "mongo",
|
||||
Name: "Collection",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Returns: []code.Type{
|
||||
code.SimpleType(g.InterfaceName),
|
||||
},
|
||||
Body: buffer.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateMethod creates codegen.MethodBuilder of repository method from the
|
||||
// provided method specification.
|
||||
func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen.MethodBuilder, error) {
|
||||
var params []code.Param
|
||||
for i, param := range methodSpec.Params {
|
||||
params = append(params, code.Param{
|
||||
Name: fmt.Sprintf("arg%d", i),
|
||||
Type: param.Type,
|
||||
})
|
||||
}
|
||||
|
||||
implementation, err := g.generateMethodImplementation(methodSpec)
|
||||
if err != nil {
|
||||
return err
|
||||
return codegen.MethodBuilder{}, err
|
||||
}
|
||||
|
||||
var paramTypes []code.Type
|
||||
for _, param := range methodSpec.Params {
|
||||
paramTypes = append(paramTypes, param.Type)
|
||||
}
|
||||
|
||||
tmplData := mongoMethodTemplateData{
|
||||
StructName: g.structName(),
|
||||
MethodName: methodSpec.Name,
|
||||
ParamTypes: paramTypes,
|
||||
ReturnTypes: methodSpec.Returns,
|
||||
Implementation: implementation,
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(buffer, tmplData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return codegen.MethodBuilder{
|
||||
Receiver: codegen.MethodReceiver{
|
||||
Name: "r",
|
||||
Type: code.SimpleType(g.repoImplStructName()),
|
||||
Pointer: true,
|
||||
},
|
||||
Name: methodSpec.Name,
|
||||
Params: params,
|
||||
Returns: methodSpec.Returns,
|
||||
Body: implementation,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
|
||||
|
@ -275,7 +325,7 @@ func (g RepositoryGenerator) bsonTagFromField(field code.StructField) (string, e
|
|||
return bsonTag[0], nil
|
||||
}
|
||||
|
||||
func (g RepositoryGenerator) structName() string {
|
||||
func (g RepositoryGenerator) repoImplStructName() string {
|
||||
return g.InterfaceName + "Mongo"
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,77 +1,17 @@
|
|||
package mongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
"github.com/sunboyy/repogen/internal/spec"
|
||||
)
|
||||
|
||||
const constructorTemplate = `
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} {
|
||||
return &{{.ImplStructName}}{
|
||||
const constructorBody = ` return &{{.ImplStructName}}{
|
||||
collection: collection,
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
type {{.ImplStructName}} struct {
|
||||
collection *mongo.Collection
|
||||
}
|
||||
`
|
||||
|
||||
type mongoConstructorTemplateData struct {
|
||||
InterfaceName string
|
||||
type constructorBodyData struct {
|
||||
ImplStructName string
|
||||
}
|
||||
|
||||
const methodTemplate = `
|
||||
func (r *{{.StructName}}) {{.MethodName}}({{.Parameters}}){{.Returns}} {
|
||||
{{.Implementation}}
|
||||
}
|
||||
`
|
||||
|
||||
type mongoMethodTemplateData struct {
|
||||
StructName string
|
||||
MethodName string
|
||||
ParamTypes []code.Type
|
||||
ReturnTypes []code.Type
|
||||
Implementation string
|
||||
}
|
||||
|
||||
func (data mongoMethodTemplateData) Parameters() string {
|
||||
var params []string
|
||||
for i, paramType := range data.ParamTypes {
|
||||
params = append(params, fmt.Sprintf("arg%d %s", i, paramType.Code()))
|
||||
}
|
||||
return strings.Join(params, ", ")
|
||||
}
|
||||
|
||||
func (data mongoMethodTemplateData) Returns() string {
|
||||
if len(data.ReturnTypes) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(data.ReturnTypes) == 1 {
|
||||
return fmt.Sprintf(" %s", data.ReturnTypes[0].Code())
|
||||
}
|
||||
|
||||
var returns []string
|
||||
for _, returnType := range data.ReturnTypes {
|
||||
returns = append(returns, returnType.Code())
|
||||
}
|
||||
return fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
|
||||
}
|
||||
|
||||
const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -44,7 +44,7 @@ func TestError(t *testing.T) {
|
|||
Name: "IncompatibleComparatorError",
|
||||
Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{
|
||||
Name: "Age",
|
||||
Type: code.SimpleType("int"),
|
||||
Type: code.TypeInt,
|
||||
}),
|
||||
ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'",
|
||||
},
|
||||
|
@ -55,7 +55,7 @@ func TestError(t *testing.T) {
|
|||
},
|
||||
{
|
||||
Name: "ArgumentTypeNotMatchedError",
|
||||
Error: spec.NewArgumentTypeNotMatchedError("Age", code.SimpleType("int"), code.SimpleType("float64")),
|
||||
Error: spec.NewArgumentTypeNotMatchedError("Age", code.TypeInt, code.TypeFloat64),
|
||||
ExpectedString: "field 'Age' requires an argument of type 'int' (got 'float64')",
|
||||
},
|
||||
{
|
||||
|
@ -63,7 +63,7 @@ func TestError(t *testing.T) {
|
|||
Error: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{
|
||||
code.StructField{
|
||||
Name: "City",
|
||||
Type: code.SimpleType("string"),
|
||||
Type: code.TypeString,
|
||||
},
|
||||
}),
|
||||
ExpectedString: "cannot use update operator INC with struct field 'City' of type 'string'",
|
||||
|
|
|
@ -85,7 +85,7 @@ func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryM
|
|||
return "", NewOperationReturnCountUnmatchedError(2)
|
||||
}
|
||||
|
||||
if returns[1] != code.SimpleType("error") {
|
||||
if returns[1] != code.TypeError {
|
||||
return "", NewUnsupportedReturnError(returns[1], 1)
|
||||
}
|
||||
|
||||
|
@ -203,7 +203,7 @@ func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (
|
|||
return "", NewOperationReturnCountUnmatchedError(2)
|
||||
}
|
||||
|
||||
if returns[1] != code.SimpleType("error") {
|
||||
if returns[1] != code.TypeError {
|
||||
return "", NewUnsupportedReturnError(returns[1], 1)
|
||||
}
|
||||
|
||||
|
@ -318,11 +318,11 @@ func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error {
|
|||
return NewOperationReturnCountUnmatchedError(2)
|
||||
}
|
||||
|
||||
if returns[0] != code.SimpleType("int") {
|
||||
if returns[0] != code.TypeInt {
|
||||
return NewUnsupportedReturnError(returns[0], 0)
|
||||
}
|
||||
|
||||
if returns[1] != code.SimpleType("error") {
|
||||
if returns[1] != code.TypeError {
|
||||
return NewUnsupportedReturnError(returns[1], 1)
|
||||
}
|
||||
|
||||
|
@ -334,16 +334,16 @@ func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (Que
|
|||
return "", NewOperationReturnCountUnmatchedError(2)
|
||||
}
|
||||
|
||||
if returns[1] != code.SimpleType("error") {
|
||||
if returns[1] != code.TypeError {
|
||||
return "", NewUnsupportedReturnError(returns[1], 1)
|
||||
}
|
||||
|
||||
simpleType, ok := returns[0].(code.SimpleType)
|
||||
if ok {
|
||||
if simpleType == code.SimpleType("bool") {
|
||||
if simpleType == code.TypeBool {
|
||||
return QueryModeOne, nil
|
||||
}
|
||||
if simpleType == code.SimpleType("int") {
|
||||
if simpleType == code.TypeInt {
|
||||
return QueryModeMany, nil
|
||||
}
|
||||
}
|
||||
|
@ -367,7 +367,7 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
|
|||
var currentParamIndex int
|
||||
for _, predicate := range querySpec.Predicates {
|
||||
if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) &&
|
||||
predicate.FieldReference.ReferencedField().Type != code.SimpleType("bool") {
|
||||
predicate.FieldReference.ReferencedField().Type != code.TypeBool {
|
||||
return NewIncompatibleComparatorError(predicate.Comparator,
|
||||
predicate.FieldReference.ReferencedField())
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
3
main.go
3
main.go
|
@ -21,7 +21,8 @@ const usageText = `repogen generates MongoDB repository implementation from repo
|
|||
|
||||
Supported options:`
|
||||
|
||||
const version = "v0.2.1"
|
||||
// version indicates the version of repogen.
|
||||
const version = "v0.3-next"
|
||||
|
||||
func main() {
|
||||
flag.Usage = printUsage
|
||||
|
|
Loading…
Reference in a new issue