Add insert operation

This commit is contained in:
sunboyy 2021-02-01 21:39:20 +07:00
parent 7f07de08af
commit 25e1b2aa85
13 changed files with 509 additions and 38 deletions

View file

@ -13,6 +13,7 @@ Repogen is a code generator for database repository in Golang inspired by Spring
Repogen is a library that generates MongoDB repository implementation from repository interface by using method name pattern.
- CRUD functionality
- Method signature validation
- Supports single-entity and multiple-entity operations
- Supports many comparison operators

View file

@ -2,7 +2,7 @@ coverage:
status:
project:
default:
target: 75%
target: 80%
threshold: 4%
patch:
default:

View file

@ -62,41 +62,23 @@ func ExtractComponents(f *ast.File) File {
interfaceType, ok := typeSpec.Type.(*ast.InterfaceType)
if ok {
intf := Interface{
intf := InterfaceType{
Name: typeSpec.Name.Name,
}
for _, method := range interfaceType.Methods.List {
var meth Method
for _, name := range method.Names {
meth.Name = name.Name
break
}
funcType, ok := method.Type.(*ast.FuncType)
if !ok {
continue
}
for _, param := range funcType.Params.List {
paramType := getType(param.Type)
if len(param.Names) == 0 {
meth.Params = append(meth.Params, Param{Type: paramType})
continue
var name string
for _, n := range method.Names {
name = n.Name
break
}
for _, name := range param.Names {
meth.Params = append(meth.Params, Param{
Name: name.Name,
Type: paramType,
})
}
}
for _, result := range funcType.Results.List {
meth.Returns = append(meth.Returns, getType(result.Type))
}
meth := extractFunction(name, funcType)
intf.Methods = append(intf.Methods, meth)
}
@ -131,6 +113,35 @@ func extractStructTag(tagValue string) map[string][]string {
return tags
}
func extractFunction(name string, funcType *ast.FuncType) Method {
meth := Method{
Name: name,
}
for _, param := range funcType.Params.List {
paramType := getType(param.Type)
if len(param.Names) == 0 {
meth.Params = append(meth.Params, Param{Type: paramType})
continue
}
for _, name := range param.Names {
meth.Params = append(meth.Params, Param{
Name: name.Name,
Type: paramType,
})
}
}
if funcType.Results != nil {
for _, result := range funcType.Results.List {
meth.Returns = append(meth.Returns, getType(result.Type))
}
}
return meth
}
func getType(expr ast.Expr) Type {
identExpr, ok := expr.(*ast.Ident)
if ok {
@ -158,5 +169,28 @@ func getType(expr ast.Expr) Type {
return ArrayType{containedType}
}
intfType, ok := expr.(*ast.InterfaceType)
if ok {
var methods []Method
for _, method := range intfType.Methods.List {
funcType, ok := method.Type.(*ast.FuncType)
if !ok {
continue
}
var name string
for _, n := range method.Names {
name = n.Name
break
}
methods = append(methods, extractFunction(name, funcType))
}
return InterfaceType{
Methods: methods,
}
}
return nil
}

View file

@ -97,6 +97,12 @@ type UserRepository interface {
FindOneByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error)
FindAll(context.Context) ([]*UserModel, error)
FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error)
InsertOne(ctx context.Context, user *UserModel) (interface{}, error)
CustomMethod(interface {
Run(arg1 int)
}) interface {
Do(arg2 string)
}
}`,
ExpectedOutput: code.File{
PackageName: "user",
@ -137,6 +143,46 @@ type UserRepository interface {
code.SimpleType("error"),
},
},
{
Name: "InsertOne",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "user", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
},
Returns: []code.Type{
code.InterfaceType{},
code.SimpleType("error"),
},
},
{
Name: "CustomMethod",
Params: []code.Param{
{
Type: code.InterfaceType{
Methods: []code.Method{
{
Name: "Run",
Params: []code.Param{
{Name: "arg1", Type: code.SimpleType("int")},
},
},
},
},
},
},
Returns: []code.Type{
code.InterfaceType{
Methods: []code.Method{
{
Name: "Do",
Params: []code.Param{
{Name: "arg2", Type: code.SimpleType("string")},
},
},
},
},
},
},
},
},
},

View file

@ -38,6 +38,11 @@ type Struct struct {
Fields StructFields
}
// ReferencedType returns a type variable of this struct
func (str Struct) ReferencedType() Type {
return SimpleType(str.Name)
}
// StructFields is a group of the StructField model
type StructFields []StructField
@ -59,25 +64,30 @@ type StructField struct {
}
// Interfaces is a group of Interface model
type Interfaces []Interface
type Interfaces []InterfaceType
// ByName return interface by name Another return value shows whether there is an interface
// with that name exists.
func (intfs Interfaces) ByName(name string) (Interface, bool) {
func (intfs Interfaces) ByName(name string) (InterfaceType, bool) {
for _, intf := range intfs {
if intf.Name == name {
return intf, true
}
}
return Interface{}, false
return InterfaceType{}, false
}
// Interface is a definition of the interface
type Interface struct {
// InterfaceType is a definition of the interface
type InterfaceType struct {
Name string
Methods []Method
}
// Code returns token string in code format
func (intf InterfaceType) Code() string {
return `interface{}`
}
// Method is a definition of the method inside the interface
type Method struct {
Name string

View file

@ -63,7 +63,7 @@ func TestStructFieldsByName(t *testing.T) {
}
func TestInterfacesByName(t *testing.T) {
userRepoIntf := code.Interface{Name: "UserRepository"}
userRepoIntf := code.InterfaceType{Name: "UserRepository"}
interfaces := code.Interfaces{userRepoIntf}
t.Run("struct field found", func(t *testing.T) {
@ -85,3 +85,44 @@ func TestInterfacesByName(t *testing.T) {
}
})
}
type TypeCodeTestCase struct {
Name string
Type code.Type
ExpectedCode string
}
func TestArrayTypeCode(t *testing.T) {
testTable := []TypeCodeTestCase{
{
Name: "simple type",
Type: code.SimpleType("UserModel"),
ExpectedCode: "UserModel",
},
{
Name: "external type",
Type: code.ExternalType{PackageAlias: "context", Name: "Context"},
ExpectedCode: "context.Context",
},
{
Name: "pointer type",
Type: code.PointerType{ContainedType: code.SimpleType("UserModel")},
ExpectedCode: "*UserModel",
},
{
Name: "array type",
Type: code.ArrayType{ContainedType: code.SimpleType("UserModel")},
ExpectedCode: "[]UserModel",
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
code := testCase.Type.Code()
if code != testCase.ExpectedCode {
t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedCode, code)
}
})
}
}

View file

@ -77,6 +77,8 @@ func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer i
func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) {
switch operation := methodSpec.Operation.(type) {
case spec.InsertOperation:
return g.generateInsertImplementation(operation)
case spec.FindOperation:
return g.generateFindImplementation(operation)
case spec.UpdateOperation:
@ -88,6 +90,13 @@ func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.Method
return "", OperationNotSupportedError
}
func (g RepositoryGenerator) generateInsertImplementation(operation spec.InsertOperation) (string, error) {
if operation.Mode == spec.QueryModeOne {
return insertOneTemplate, nil
}
return insertManyTemplate, nil
}
func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
buffer := new(bytes.Buffer)

View file

@ -80,6 +80,85 @@ type GenerateMethodTestCase struct {
ExpectedCode string
}
func TestGenerateMethod_Insert(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "insert one method",
MethodSpec: spec.MethodSpec{
Name: "InsertOne",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
},
Returns: []code.Type{
code.InterfaceType{},
code.SimpleType("error"),
},
Operation: spec.InsertOperation{
Mode: spec.QueryModeOne,
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) InsertOne(arg0 context.Context, arg1 *UserModel) (interface{}, error) {
result, err := r.collection.InsertOne(arg0, arg1)
if err != nil {
return nil, err
}
return result.InsertedID, nil
}
`,
},
{
Name: "insert many method",
MethodSpec: spec.MethodSpec{
Name: "Insert",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "userModel", Type: code.ArrayType{
ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")},
}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.InterfaceType{}},
code.SimpleType("error"),
},
Operation: spec.InsertOperation{
Mode: spec.QueryModeMany,
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) Insert(arg0 context.Context, arg1 []*UserModel) ([]interface{}, error) {
var entities []interface{}
for _, model := range arg1 {
entities = append(entities, model)
}
result, err := r.collection.InsertMany(arg0, entities)
if err != nil {
return nil, err
}
return result.InsertedIDs, nil
}
`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
buffer := new(bytes.Buffer)
err := generator.GenerateMethod(testCase.MethodSpec, buffer)
if err != nil {
t.Error(err)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil {
t.Error(err)
}
})
}
}
func TestGenerateMethod_Find(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
@ -90,7 +169,10 @@ func TestGenerateMethod_Find(t *testing.T) {
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
},
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{

View file

@ -70,6 +70,22 @@ func (data mongoMethodTemplateData) Returns() string {
return fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
}
const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1)
if err != nil {
return nil, err
}
return result.InsertedID, nil`
const insertManyTemplate = ` var entities []interface{}
for _, model := range arg1 {
entities = append(entities, model)
}
result, err := r.collection.InsertMany(arg0, entities)
if err != nil {
return nil, err
}
return result.InsertedIDs, nil`
type mongoFindTemplateData struct {
EntityType string
QuerySpec querySpec

View file

@ -9,14 +9,14 @@ func (err ParsingError) Error() string {
return "unknown operation"
case UnsupportedNameError:
return "method name is not supported"
case UnsupportedReturnError:
return "this type of return is not supported"
case InvalidQueryError:
return "invalid query"
case InvalidParamError:
return "parameters do not match the query"
case InvalidUpdateFieldsError:
return "update fields is invalid"
case UnsupportedReturnError:
return "this type of return is not supported"
case ContextParamRequiredError:
return "context parameter is required"
case StructFieldNotFoundError:
@ -29,10 +29,10 @@ func (err ParsingError) Error() string {
const (
UnknownOperationError ParsingError = "ERROR_UNKNOWN_OPERATION"
UnsupportedNameError ParsingError = "ERROR_UNSUPPORTED"
UnsupportedReturnError ParsingError = "ERROR_UNSUPPORTED_RETURN"
InvalidQueryError ParsingError = "ERROR_INVALID_QUERY"
InvalidParamError ParsingError = "ERROR_INVALID_PARAM"
InvalidUpdateFieldsError ParsingError = "ERROR_INVALID_UPDATE_FIELDS"
UnsupportedReturnError ParsingError = "ERROR_INVALID_RETURN"
ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED"
StructFieldNotFoundError ParsingError = "ERROR_STRUCT_FIELD_NOT_FOUND"
)

View file

@ -27,6 +27,11 @@ type MethodSpec struct {
type Operation interface {
}
// InsertOperation is a method specification for insert operations
type InsertOperation struct {
Mode QueryMode
}
// FindOperation is a method specification for find operations
type FindOperation struct {
Mode QueryMode

View file

@ -23,6 +23,8 @@ type interfaceMethodParser struct {
func (p interfaceMethodParser) Parse() (MethodSpec, error) {
methodNameTokens := camelcase.Split(p.Method.Name)
switch methodNameTokens[0] {
case "Insert":
return p.parseInsertMethod(methodNameTokens[1:])
case "Find":
return p.parseFindMethod(methodNameTokens[1:])
case "Update":
@ -33,6 +35,65 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) {
return MethodSpec{}, UnknownOperationError
}
func (p interfaceMethodParser) parseInsertMethod(tokens []string) (MethodSpec, error) {
mode, err := p.extractInsertReturns(p.Method.Returns)
if err != nil {
return MethodSpec{}, err
}
if err := p.validateContextParam(); err != nil {
return MethodSpec{}, err
}
pointerType := code.PointerType{ContainedType: p.StructModel.ReferencedType()}
if mode == QueryModeOne && p.Method.Params[1].Type != pointerType {
return MethodSpec{}, InvalidParamError
}
arrayType := code.ArrayType{ContainedType: pointerType}
if mode == QueryModeMany && p.Method.Params[1].Type != arrayType {
return MethodSpec{}, InvalidParamError
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: InsertOperation{
Mode: mode,
},
}, nil
}
func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 {
return "", UnsupportedReturnError
}
if returns[1] != code.SimpleType("error") {
return "", UnsupportedReturnError
}
interfaceType, ok := returns[0].(code.InterfaceType)
if ok {
if len(interfaceType.Methods) != 0 {
return "", UnsupportedReturnError
}
return QueryModeOne, nil
}
arrayType, ok := returns[0].(code.ArrayType)
if ok {
interfaceType, ok := arrayType.ContainedType.(code.InterfaceType)
if !ok || len(interfaceType.Methods) != 0 {
return "", UnsupportedReturnError
}
return QueryModeMany, nil
}
return "", UnsupportedReturnError
}
func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, error) {
if len(tokens) == 0 {
return MethodSpec{}, UnsupportedNameError

View file

@ -40,6 +40,64 @@ type ParseInterfaceMethodTestCase struct {
ExpectedOperation spec.Operation
}
func TestParseInterfaceMethod_Insert(t *testing.T) {
testTable := []ParseInterfaceMethodTestCase{
{
Name: "InsertOne method",
Method: code.Method{
Name: "InsertOne",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
},
Returns: []code.Type{
code.InterfaceType{},
code.SimpleType("error"),
},
},
ExpectedOperation: spec.InsertOperation{
Mode: spec.QueryModeOne,
},
},
{
Name: "InsertMany method",
Method: code.Method{
Name: "InsertMany",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.InterfaceType{}},
code.SimpleType("error"),
},
},
ExpectedOperation: spec.InsertOperation{
Mode: spec.QueryModeMany,
},
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
actualSpec, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != nil {
t.Errorf("Error = %s", err)
}
expectedOutput := spec.MethodSpec{
Name: testCase.Method.Name,
Params: testCase.Method.Params,
Returns: testCase.Method.Returns,
Operation: testCase.ExpectedOperation,
}
if !reflect.DeepEqual(actualSpec, expectedOutput) {
t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec)
}
})
}
}
func TestParseInterfaceMethod_Find(t *testing.T) {
testTable := []ParseInterfaceMethodTestCase{
{
@ -736,6 +794,114 @@ func TestParseInterfaceMethod_Invalid(t *testing.T) {
}
}
func TestParseInterfaceMethod_Insert_Invalid(t *testing.T) {
testTable := []ParseInterfaceMethodInvalidTestCase{
{
Name: "invalid number of returns",
Method: code.Method{
Name: "Insert",
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.InterfaceType{},
code.SimpleType("error"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "unsupported return types from insert method",
Method: code.Method{
Name: "Insert",
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.SimpleType("error"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "unempty interface return from insert method",
Method: code.Method{
Name: "Insert",
Returns: []code.Type{
code.InterfaceType{
Methods: []code.Method{
{Name: "DoSomething"},
},
},
code.SimpleType("error"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "error return not provided",
Method: code.Method{
Name: "Insert",
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.InterfaceType{},
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "no context parameter",
Method: code.Method{
Name: "Insert",
Params: []code.Param{
{Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
},
Returns: []code.Type{
code.InterfaceType{},
code.SimpleType("error"),
},
},
ExpectedError: spec.ContextParamRequiredError,
},
{
Name: "mismatched model parameter for one mode",
Method: code.Method{
Name: "Insert",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "userModel", Type: code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}},
},
Returns: []code.Type{
code.InterfaceType{},
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
{
Name: "mismatched model parameter for many mode",
Method: code.Method{
Name: "Insert",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.InterfaceType{}},
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
_, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != testCase.ExpectedError {
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err)
}
})
}
}
func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
testTable := []ParseInterfaceMethodInvalidTestCase{
{
@ -758,7 +924,7 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "unsupported return values from find method",
Name: "unsupported return types from find method",
Method: code.Method{
Name: "FindOneByID",
Returns: []code.Type{
@ -933,7 +1099,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "unsupported return values from find method",
Name: "unsupported return types from update method",
Method: code.Method{
Name: "UpdateAgeByID",
Returns: []code.Type{
@ -1085,7 +1251,7 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) {
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "unsupported return values from find method",
Name: "unsupported return types from delete method",
Method: code.Method{
Name: "DeleteOneByID",
Returns: []code.Type{