Add entrypoint to the program

This commit is contained in:
sunboyy 2021-01-23 20:03:16 +07:00
parent f677158d69
commit 489edbd9e2
10 changed files with 927 additions and 266 deletions

View file

@ -2,7 +2,7 @@ coverage:
status:
project:
default:
target: 70%
target: 75%
threshold: 5%
patch:
default:

View file

@ -0,0 +1,75 @@
package generator
import (
"bytes"
"html/template"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"golang.org/x/tools/imports"
)
// GenerateRepository generates repository implementation from repository interface specification
func GenerateRepository(packageName string, structModel code.Struct, interfaceName string,
methodSpecs []spec.MethodSpec) (string, error) {
repositoryGenerator := repositoryGenerator{
PackageName: packageName,
StructModel: structModel,
InterfaceName: interfaceName,
MethodSpecs: methodSpecs,
Generator: mongo.NewGenerator(structModel, interfaceName),
}
return repositoryGenerator.Generate()
}
type repositoryGenerator struct {
PackageName string
StructModel code.Struct
InterfaceName string
MethodSpecs []spec.MethodSpec
Generator mongo.RepositoryGenerator
}
func (g repositoryGenerator) Generate() (string, error) {
buffer := new(bytes.Buffer)
if err := g.generateBase(buffer); err != nil {
return "", err
}
if err := g.Generator.GenerateConstructor(buffer); err != nil {
return "", err
}
for _, method := range g.MethodSpecs {
if err := g.Generator.GenerateMethod(method, buffer); err != nil {
return "", err
}
}
formattedCode, err := imports.Process("", buffer.Bytes(), nil)
if err != nil {
return "", err
}
return string(formattedCode), nil
}
func (g repositoryGenerator) generateBase(buffer *bytes.Buffer) error {
tmpl, err := template.New("file_base").Parse(baseTemplate)
if err != nil {
return err
}
tmplData := baseTemplateData{
PackageName: g.PackageName,
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,278 @@
package generator_test
import (
"strings"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/generator"
"github.com/sunboyy/repogen/internal/spec"
)
func TestGenerateMongoRepository(t *testing.T) {
userModel := code.Struct{
Name: "UserModel",
Fields: code.StructFields{
{
Name: "ID",
Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"},
Tags: map[string][]string{"bson": {"_id", "omitempty"}},
},
{
Name: "Username",
Type: code.SimpleType("string"),
Tags: map[string][]string{"bson": {"username"}},
},
{
Name: "Gender",
Type: code.SimpleType("Gender"),
Tags: map[string][]string{"bson": {"gender"}},
},
{
Name: "Age",
Type: code.SimpleType("int"),
Tags: map[string][]string{"bson": {"age"}},
},
},
}
methods := []spec.MethodSpec{
// test find: One mode
{
Name: "FindByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")},
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "ID", Comparator: spec.ComparatorEqual},
},
},
},
},
// test find: Many mode, And operator, NOT and LessThan comparator
{
Name: "FindByGenderNotAndAgeLessThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorAnd,
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorNot},
{Field: "Age", Comparator: spec.ComparatorLessThan},
},
},
},
},
{
Name: "FindByAgeLessThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorLessThanEqual},
},
},
},
},
{
Name: "FindByAgeGreaterThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorGreaterThan},
},
},
},
},
{
Name: "FindByAgeGreaterThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorGreaterThanEqual},
},
},
},
},
{
Name: "FindByGenderOrAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorOr,
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual},
{Field: "Age", Comparator: spec.ComparatorEqual},
},
},
},
},
}
code, err := generator.GenerateRepository("user", userModel, "UserRepository", methods)
if err != nil {
t.Error(err)
}
expectedCode := `// Code generated by repogen. DO NOT EDIT.
package user
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
func NewUserRepository(collection *mongo.Collection) UserRepository {
return &UserRepositoryMongo{
collection: collection,
}
}
type UserRepositoryMongo struct {
collection *mongo.Collection
}
func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) {
var entity UserModel
if err := r.collection.FindOne(ctx, bson.M{
"_id": arg0,
}).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil
}
func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(ctx context.Context, arg0 Gender, arg1 int) (*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"gender": bson.M{"$ne": arg0},
"age": bson.M{"$lt": arg1},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$lte": arg0},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$gt": arg0},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$gte": arg0},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"$or": []bson.M{
{"gender": arg0},
{"age": arg1},
},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
`
expectedCodeLines := strings.Split(expectedCode, "\n")
actualCodeLines := strings.Split(code, "\n")
for i, line := range expectedCodeLines {
if line != actualCodeLines[i] {
t.Errorf("On line %d\nExpected = %v\nActual = %v", i, line, actualCodeLines[i])
}
}
}

View file

@ -0,0 +1,9 @@
package generator
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
package {{.PackageName}}
`
type baseTemplateData struct {
PackageName string
}

View file

@ -4,76 +4,37 @@ import (
"bytes"
"errors"
"fmt"
"io"
"text/template"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/spec"
"golang.org/x/tools/imports"
)
// GenerateMongoRepository generates mongodb repository
func GenerateMongoRepository(packageName string, structModel code.Struct, intf code.Interface) (string, error) {
var methodSpecs []spec.MethodSpec
for _, method := range intf.Methods {
methodSpec, err := spec.ParseInterfaceMethod(structModel, method)
if err != nil {
return "", err
}
methodSpecs = append(methodSpecs, methodSpec)
}
generator := mongoRepositoryGenerator{
PackageName: packageName,
// NewGenerator creates a new instance of MongoDB repository generator
func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGenerator {
return RepositoryGenerator{
StructModel: structModel,
InterfaceName: intf.Name,
MethodSpecs: methodSpecs,
InterfaceName: interfaceName,
}
output, err := generator.Generate()
if err != nil {
return "", err
}
return output, nil
}
type mongoRepositoryGenerator struct {
PackageName string
// RepositoryGenerator provides repository constructor and method generation from provided specification
type RepositoryGenerator struct {
StructModel code.Struct
InterfaceName string
MethodSpecs []spec.MethodSpec
}
func (g mongoRepositoryGenerator) Generate() (string, error) {
buffer := new(bytes.Buffer)
if err := g.generateBaseContent(buffer); err != nil {
return "", err
}
for _, method := range g.MethodSpecs {
if err := g.generateMethod(buffer, method); err != nil {
return "", err
}
}
newOutput, err := imports.Process("", buffer.Bytes(), nil)
if err != nil {
return "", err
}
return string(newOutput), nil
}
func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) error {
tmpl, err := template.New("mongo_repository_base").Parse(baseTemplate)
// GenerateConstructor generates mongo repository struct implementation and constructor for the struct
func (g RepositoryGenerator) GenerateConstructor(buffer io.Writer) error {
tmpl, err := template.New("mongo_repository_base").Parse(constructorTemplate)
if err != nil {
return err
}
tmplData := mongoBaseTemplateData{
PackageName: g.PackageName,
InterfaceName: g.InterfaceName,
StructName: g.structName(),
tmplData := mongoConstructorTemplateData{
InterfaceName: g.InterfaceName,
ImplStructName: g.structName(),
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
@ -83,27 +44,28 @@ func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) erro
return nil
}
func (g mongoRepositoryGenerator) generateMethod(buffer *bytes.Buffer, method spec.MethodSpec) error {
// GenerateMethod generates implementation of from provided method specification
func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer io.Writer) error {
tmpl, err := template.New("mongo_repository_method").Parse(methodTemplate)
if err != nil {
return err
}
implementation, err := g.generateMethodImplementation(method)
implementation, err := g.generateMethodImplementation(methodSpec)
if err != nil {
return err
}
var paramTypes []code.Type
for _, param := range method.Params[1:] {
for _, param := range methodSpec.Params[1:] {
paramTypes = append(paramTypes, param.Type)
}
tmplData := mongoMethodTemplateData{
StructName: g.structName(),
MethodName: method.Name,
MethodName: methodSpec.Name,
ParamTypes: paramTypes,
ReturnTypes: method.Returns,
ReturnTypes: methodSpec.Returns,
Implementation: implementation,
}
@ -114,7 +76,7 @@ func (g mongoRepositoryGenerator) generateMethod(buffer *bytes.Buffer, method sp
return nil
}
func (g mongoRepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
switch operation := methodSpec.Operation.(type) {
case spec.FindOperation:
return g.generateFindImplementation(operation)
@ -123,7 +85,7 @@ func (g mongoRepositoryGenerator) generateMethodImplementation(methodSpec spec.M
return "", errors.New("method spec not supported")
}
func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
buffer := new(bytes.Buffer)
var predicates []predicate
@ -172,6 +134,6 @@ func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.Find
return buffer.String(), nil
}
func (g mongoRepositoryGenerator) structName() string {
func (g RepositoryGenerator) structName() string {
return g.InterfaceName + "Mongo"
}

View file

@ -1,14 +1,36 @@
package mongo_test
import (
"strings"
"bytes"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestGenerateMongoRepository(t *testing.T) {
const expectedConstructorResult = `
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
func NewUserRepository(collection *mongo.Collection) UserRepository {
return &UserRepositoryMongo{
collection: collection,
}
}
type UserRepositoryMongo struct {
collection *mongo.Collection
}
`
func TestGenerateConstructor(t *testing.T) {
userModel := code.Struct{
Name: "UserModel",
Fields: code.StructFields{
@ -34,158 +56,46 @@ func TestGenerateMongoRepository(t *testing.T) {
},
},
}
intf := code.Interface{
Name: "UserRepository",
Methods: []code.Method{
{
generator := mongo.NewGenerator(userModel, "UserRepository")
buffer := new(bytes.Buffer)
err := generator.GenerateConstructor(buffer)
if err != nil {
t.Error(err)
}
if err := testutils.ExpectMultiLineString(expectedConstructorResult, buffer.String()); err != nil {
t.Error(err)
}
}
type GenerateMethodTestCase struct {
Name string
MethodSpec spec.MethodSpec
ExpectedCode string
}
func TestGenerateMethod(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "simple find one method",
MethodSpec: spec.MethodSpec{
Name: "FindByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")},
},
{
Name: "FindOneByUsername",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "username", Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorEqual, Field: "ID"},
},
},
},
},
{
Name: "FindByUsername",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "username", Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByIDAndUsername",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
{Name: "username", Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
},
},
{
Name: "FindByGenderNot",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByAgeLessThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByAgeLessThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByAgeGreaterThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByAgeGreaterThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByGenderOrAgeLessThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
{
Name: "FindByGenderIn",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
},
}
code, err := mongo.GenerateMongoRepository("user", userModel, intf)
if err != nil {
t.Error(err)
}
expectedCode := `// Code generated by repogen. DO NOT EDIT.
package user
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
func NewUserRepository(collection *mongo.Collection) UserRepository {
return &UserRepositoryMongo{
collection: collection,
}
}
type UserRepositoryMongo struct {
collection *mongo.Collection
}
ExpectedCode: `
func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.ObjectID) (*UserModel, error) {
var entity UserModel
if err := r.collection.FindOne(ctx, bson.M{
@ -195,20 +105,33 @@ func (r *UserRepositoryMongo) FindByID(ctx context.Context, arg0 primitive.Objec
}
return &entity, nil
}
func (r *UserRepositoryMongo) FindOneByUsername(ctx context.Context, arg0 string) (*UserModel, error) {
var entity UserModel
if err := r.collection.FindOne(ctx, bson.M{
"username": arg0,
}).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil
}
func (r *UserRepositoryMongo) FindByUsername(ctx context.Context, arg0 string) ([]*UserModel, error) {
`,
},
{
Name: "simple find many method",
MethodSpec: spec.MethodSpec{
Name: "FindByGender",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorEqual, Field: "Gender"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByGender(ctx context.Context, arg0 Gender) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"username": arg0,
"gender": arg0,
})
if err != nil {
return nil, err
@ -219,18 +142,114 @@ func (r *UserRepositoryMongo) FindByUsername(ctx context.Context, arg0 string) (
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByIDAndUsername(ctx context.Context, arg0 primitive.ObjectID, arg1 string) (*UserModel, error) {
var entity UserModel
if err := r.collection.FindOne(ctx, bson.M{
"_id": arg0,
"username": arg1,
}).Decode(&entity); err != nil {
`,
},
{
Name: "find with And operator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderAndAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorAnd,
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorEqual, Field: "Gender"},
{Comparator: spec.ComparatorEqual, Field: "Age"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByGenderAndAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"gender": arg0,
"age": arg1,
})
if err != nil {
return nil, err
}
return &entity, nil
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
`,
},
{
Name: "find with Or operator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderOrAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorOr,
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorEqual, Field: "Gender"},
{Comparator: spec.ComparatorEqual, Field: "Age"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByGenderOrAge(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"$or": []bson.M{
{"gender": arg0},
{"age": arg1},
},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
`,
},
{
Name: "find with Not comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderNot",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorNot, Field: "Gender"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 Gender) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"gender": bson.M{"$ne": arg0},
@ -244,7 +263,30 @@ func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 Gender)
}
return entities, nil
}
`,
},
{
Name: "find with LessThan comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeLessThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorLessThan, Field: "Age"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByAgeLessThan(ctx context.Context, arg0 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$lt": arg0},
@ -258,7 +300,30 @@ func (r *UserRepositoryMongo) FindByAgeLessThan(ctx context.Context, arg0 int) (
}
return entities, nil
}
`,
},
{
Name: "find with LessThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeLessThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorLessThanEqual, Field: "Age"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$lte": arg0},
@ -272,7 +337,30 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqual(ctx context.Context, arg0 i
}
return entities, nil
}
`,
},
{
Name: "find with GreaterThan comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeGreaterThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorGreaterThan, Field: "Age"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$gt": arg0},
@ -286,7 +374,30 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThan(ctx context.Context, arg0 int
}
return entities, nil
}
`,
},
{
Name: "find with GreaterThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeGreaterThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorGreaterThanEqual, Field: "Age"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg0 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"age": bson.M{"$gte": arg0},
@ -300,24 +411,30 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqual(ctx context.Context, arg
}
return entities, nil
}
func (r *UserRepositoryMongo) FindByGenderOrAgeLessThan(ctx context.Context, arg0 Gender, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"$or": []bson.M{
{"gender": arg0},
{"age": bson.M{"$lt": arg1}},
`,
},
})
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil
}
{
Name: "find with In comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderIn",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Comparator: spec.ComparatorIn, Field: "Gender"},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{
"gender": bson.M{"$in": arg0},
@ -331,13 +448,47 @@ func (r *UserRepositoryMongo) FindByGenderIn(ctx context.Context, arg0 []Gender)
}
return entities, nil
}
`
expectedCodeLines := strings.Split(expectedCode, "\n")
actualCodeLines := strings.Split(code, "\n")
`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
userModel := code.Struct{
Name: "UserModel",
Fields: code.StructFields{
{
Name: "ID",
Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"},
Tags: map[string][]string{"bson": {"_id", "omitempty"}},
},
{
Name: "Username",
Type: code.SimpleType("string"),
Tags: map[string][]string{"bson": {"username"}},
},
{
Name: "Gender",
Type: code.SimpleType("Gender"),
Tags: map[string][]string{"bson": {"gender"}},
},
{
Name: "Age",
Type: code.SimpleType("int"),
Tags: map[string][]string{"bson": {"age"}},
},
},
}
generator := mongo.NewGenerator(userModel, "UserRepository")
buffer := new(bytes.Buffer)
for i, line := range expectedCodeLines {
if line != actualCodeLines[i] {
t.Errorf("On line %d\nExpected = %v\nActual = %v", i, line, actualCodeLines[i])
}
err := generator.GenerateMethod(testCase.MethodSpec, buffer)
if err != nil {
t.Error(err)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil {
t.Error(err)
}
})
}
}

View file

@ -7,9 +7,7 @@ import (
"github.com/sunboyy/repogen/internal/code"
)
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
package {{.PackageName}}
const constructorTemplate = `
import (
"context"
@ -19,20 +17,19 @@ import (
)
func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} {
return &{{.StructName}}{
return &{{.ImplStructName}}{
collection: collection,
}
}
type {{.StructName}} struct {
type {{.ImplStructName}} struct {
collection *mongo.Collection
}
`
type mongoBaseTemplateData struct {
PackageName string
InterfaceName string
StructName string
type mongoConstructorTemplateData struct {
InterfaceName string
ImplStructName string
}
const methodTemplate = `

View file

@ -0,0 +1,31 @@
package testutils
import (
"fmt"
"strings"
)
// ExpectMultiLineString compares two multi-line strings and report the difference
func ExpectMultiLineString(expected, actual string) error {
expectedLines := strings.Split(expected, "\n")
actualLines := strings.Split(actual, "\n")
numberOfComparableLines := len(expectedLines)
if len(actualLines) < numberOfComparableLines {
numberOfComparableLines = len(actualLines)
}
for i := 0; i < numberOfComparableLines; i++ {
if expectedLines[i] != actualLines[i] {
return fmt.Errorf("On line %d\nExpected: %v\nReceived: %v", i+1, expectedLines[i], actualLines[i])
}
}
if len(expectedLines) < len(actualLines) {
return fmt.Errorf("Unexpected lines:\n%s", strings.Join(actualLines[len(expectedLines):], "\n"))
} else if len(expectedLines) > len(actualLines) {
return fmt.Errorf("Missing lines:\n%s", strings.Join(expectedLines[len(actualLines):], "\n"))
}
return nil
}

View file

@ -0,0 +1,72 @@
package testutils_test
import (
"testing"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestExpectMultiLineString(t *testing.T) {
t.Run("same string should return nil", func(t *testing.T) {
text := ` Hello world
this is a test text `
err := testutils.ExpectMultiLineString(text, text)
if err != nil {
t.Errorf("Expected = <nil>\nReceived = %s", err.Error())
}
})
t.Run("different string with same number of lines", func(t *testing.T) {
expectedText := ` Hello world
this is an expected text
how are you?`
actualText := ` Hello world
this is a real text
How are you?`
err := testutils.ExpectMultiLineString(expectedText, actualText)
expectedError := "On line 2\nExpected: this is an expected text\nReceived: this is a real text"
if err == nil || err.Error() != expectedError {
t.Errorf("Expected = %s\nReceived = %s", expectedError, err.Error())
}
})
t.Run("expected text longer than actual text", func(t *testing.T) {
expectedText := ` Hello world
this is an expected text
how are you?
I'm fine...
Thank you...`
actualText := ` Hello world
this is an expected text
how are you?`
err := testutils.ExpectMultiLineString(expectedText, actualText)
expectedError := "Missing lines:\nI'm fine...\nThank you..."
if err == nil || err.Error() != expectedError {
t.Errorf("Expected = %s\nReceived = %s", expectedError, err.Error())
}
})
t.Run("actual text longer than expected text", func(t *testing.T) {
expectedText := ` Hello world
this is an expected text
how are you?`
actualText := ` Hello world
this is an expected text
how are you?
I'm fine...
Thank you...`
err := testutils.ExpectMultiLineString(expectedText, actualText)
expectedError := "Unexpected lines:\nI'm fine...\nThank you..."
if err == nil || err.Error() != expectedError {
t.Errorf("Expected = %s\nReceived = %s", expectedError, err.Error())
}
})
}

86
main.go Normal file
View file

@ -0,0 +1,86 @@
package main
import (
"errors"
"flag"
"go/parser"
"go/token"
"os"
"path/filepath"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/generator"
"github.com/sunboyy/repogen/internal/spec"
)
func main() {
sourcePtr := flag.String("src", "", "source file")
destPtr := flag.String("dest", "", "destination file")
modelPtr := flag.String("model", "", "model struct name")
repoPtr := flag.String("repo", "", "repository interface name")
flag.Parse()
if *sourcePtr == "" {
panic("-source flag required")
}
if *modelPtr == "" {
panic("-model flag required")
}
if *repoPtr == "" {
panic("-repo flag required")
}
dest := os.Stdout
if *destPtr != "" {
if err := os.MkdirAll(filepath.Dir(*destPtr), os.ModePerm); err != nil {
panic(err)
}
file, err := os.Create(*destPtr)
if err != nil {
panic(err)
}
defer file.Close()
dest = file
}
code, err := generateFromRequest(*sourcePtr, *modelPtr, *repoPtr)
if err != nil {
panic(err)
}
if _, err := dest.WriteString(code); err != nil {
panic(err)
}
}
func generateFromRequest(fileName, structModelName, repositoryInterfaceName string) (string, error) {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, fileName, nil, parser.ParseComments)
if err != nil {
panic(err)
}
file := code.ExtractComponents(f)
structModel, ok := file.Structs.ByName(structModelName)
if !ok {
return "", errors.New("struct model not found")
}
intf, ok := file.Interfaces.ByName(repositoryInterfaceName)
if !ok {
return "", errors.New("interface model not found")
}
var methodSpecs []spec.MethodSpec
for _, method := range intf.Methods {
methodSpec, err := spec.ParseInterfaceMethod(structModel, method)
if err != nil {
return "", err
}
methodSpecs = append(methodSpecs, methodSpec)
}
return generator.GenerateRepository(file.PackageName, structModel, intf.Name, methodSpecs)
}