diff --git a/README.md b/README.md index 44aee98..651b944 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Repogen is a code generator for database repository in Golang inspired by Spring Repogen is a library that generates MongoDB repository implementation from repository interface by using method name pattern. +- CRUD functionality - Method signature validation - Supports single-entity and multiple-entity operations - Supports many comparison operators diff --git a/codecov.yml b/codecov.yml index 2c28e2f..4b6e28e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,7 +2,7 @@ coverage: status: project: default: - target: 75% + target: 80% threshold: 4% patch: default: diff --git a/internal/code/extractor.go b/internal/code/extractor.go index 93320bc..8ddc438 100644 --- a/internal/code/extractor.go +++ b/internal/code/extractor.go @@ -62,41 +62,23 @@ func ExtractComponents(f *ast.File) File { interfaceType, ok := typeSpec.Type.(*ast.InterfaceType) if ok { - intf := Interface{ + intf := InterfaceType{ Name: typeSpec.Name.Name, } for _, method := range interfaceType.Methods.List { - var meth Method - for _, name := range method.Names { - meth.Name = name.Name - break - } - funcType, ok := method.Type.(*ast.FuncType) if !ok { continue } - for _, param := range funcType.Params.List { - paramType := getType(param.Type) - - if len(param.Names) == 0 { - meth.Params = append(meth.Params, Param{Type: paramType}) - continue - } - - for _, name := range param.Names { - meth.Params = append(meth.Params, Param{ - Name: name.Name, - Type: paramType, - }) - } + var name string + for _, n := range method.Names { + name = n.Name + break } - for _, result := range funcType.Results.List { - meth.Returns = append(meth.Returns, getType(result.Type)) - } + meth := extractFunction(name, funcType) intf.Methods = append(intf.Methods, meth) } @@ -131,6 +113,35 @@ func extractStructTag(tagValue string) map[string][]string { return tags } +func extractFunction(name string, funcType *ast.FuncType) Method { + meth := Method{ + Name: name, + } + for _, param := range funcType.Params.List { + paramType := getType(param.Type) + + if len(param.Names) == 0 { + meth.Params = append(meth.Params, Param{Type: paramType}) + continue + } + + for _, name := range param.Names { + meth.Params = append(meth.Params, Param{ + Name: name.Name, + Type: paramType, + }) + } + } + + if funcType.Results != nil { + for _, result := range funcType.Results.List { + meth.Returns = append(meth.Returns, getType(result.Type)) + } + } + + return meth +} + func getType(expr ast.Expr) Type { identExpr, ok := expr.(*ast.Ident) if ok { @@ -158,5 +169,28 @@ func getType(expr ast.Expr) Type { return ArrayType{containedType} } + intfType, ok := expr.(*ast.InterfaceType) + if ok { + var methods []Method + for _, method := range intfType.Methods.List { + funcType, ok := method.Type.(*ast.FuncType) + if !ok { + continue + } + + var name string + for _, n := range method.Names { + name = n.Name + break + } + + methods = append(methods, extractFunction(name, funcType)) + } + + return InterfaceType{ + Methods: methods, + } + } + return nil } diff --git a/internal/code/extractor_test.go b/internal/code/extractor_test.go index 0cea722..4a1a705 100644 --- a/internal/code/extractor_test.go +++ b/internal/code/extractor_test.go @@ -97,6 +97,12 @@ type UserRepository interface { FindOneByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error) FindAll(context.Context) ([]*UserModel, error) FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error) + InsertOne(ctx context.Context, user *UserModel) (interface{}, error) + CustomMethod(interface { + Run(arg1 int) + }) interface { + Do(arg2 string) + } }`, ExpectedOutput: code.File{ PackageName: "user", @@ -137,6 +143,46 @@ type UserRepository interface { code.SimpleType("error"), }, }, + { + Name: "InsertOne", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "user", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + }, + Returns: []code.Type{ + code.InterfaceType{}, + code.SimpleType("error"), + }, + }, + { + Name: "CustomMethod", + Params: []code.Param{ + { + Type: code.InterfaceType{ + Methods: []code.Method{ + { + Name: "Run", + Params: []code.Param{ + {Name: "arg1", Type: code.SimpleType("int")}, + }, + }, + }, + }, + }, + }, + Returns: []code.Type{ + code.InterfaceType{ + Methods: []code.Method{ + { + Name: "Do", + Params: []code.Param{ + {Name: "arg2", Type: code.SimpleType("string")}, + }, + }, + }, + }, + }, + }, }, }, }, diff --git a/internal/code/models.go b/internal/code/models.go index 4e1bc4d..b321456 100644 --- a/internal/code/models.go +++ b/internal/code/models.go @@ -38,6 +38,11 @@ type Struct struct { Fields StructFields } +// ReferencedType returns a type variable of this struct +func (str Struct) ReferencedType() Type { + return SimpleType(str.Name) +} + // StructFields is a group of the StructField model type StructFields []StructField @@ -59,25 +64,30 @@ type StructField struct { } // Interfaces is a group of Interface model -type Interfaces []Interface +type Interfaces []InterfaceType // ByName return interface by name Another return value shows whether there is an interface // with that name exists. -func (intfs Interfaces) ByName(name string) (Interface, bool) { +func (intfs Interfaces) ByName(name string) (InterfaceType, bool) { for _, intf := range intfs { if intf.Name == name { return intf, true } } - return Interface{}, false + return InterfaceType{}, false } -// Interface is a definition of the interface -type Interface struct { +// InterfaceType is a definition of the interface +type InterfaceType struct { Name string Methods []Method } +// Code returns token string in code format +func (intf InterfaceType) Code() string { + return `interface{}` +} + // Method is a definition of the method inside the interface type Method struct { Name string diff --git a/internal/code/models_test.go b/internal/code/models_test.go index d1e84b9..f1ccb38 100644 --- a/internal/code/models_test.go +++ b/internal/code/models_test.go @@ -63,7 +63,7 @@ func TestStructFieldsByName(t *testing.T) { } func TestInterfacesByName(t *testing.T) { - userRepoIntf := code.Interface{Name: "UserRepository"} + userRepoIntf := code.InterfaceType{Name: "UserRepository"} interfaces := code.Interfaces{userRepoIntf} t.Run("struct field found", func(t *testing.T) { @@ -85,3 +85,44 @@ func TestInterfacesByName(t *testing.T) { } }) } + +type TypeCodeTestCase struct { + Name string + Type code.Type + ExpectedCode string +} + +func TestArrayTypeCode(t *testing.T) { + testTable := []TypeCodeTestCase{ + { + Name: "simple type", + Type: code.SimpleType("UserModel"), + ExpectedCode: "UserModel", + }, + { + Name: "external type", + Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, + ExpectedCode: "context.Context", + }, + { + Name: "pointer type", + Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}, + ExpectedCode: "*UserModel", + }, + { + Name: "array type", + Type: code.ArrayType{ContainedType: code.SimpleType("UserModel")}, + ExpectedCode: "[]UserModel", + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + code := testCase.Type.Code() + + if code != testCase.ExpectedCode { + t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedCode, code) + } + }) + } +} diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index e421c93..d51cba9 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -77,6 +77,8 @@ func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec, buffer i func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) { switch operation := methodSpec.Operation.(type) { + case spec.InsertOperation: + return g.generateInsertImplementation(operation) case spec.FindOperation: return g.generateFindImplementation(operation) case spec.UpdateOperation: @@ -88,6 +90,13 @@ func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.Method return "", OperationNotSupportedError } +func (g RepositoryGenerator) generateInsertImplementation(operation spec.InsertOperation) (string, error) { + if operation.Mode == spec.QueryModeOne { + return insertOneTemplate, nil + } + return insertManyTemplate, nil +} + func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) { buffer := new(bytes.Buffer) diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index bdfc73c..f274221 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -80,6 +80,85 @@ type GenerateMethodTestCase struct { ExpectedCode string } +func TestGenerateMethod_Insert(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "insert one method", + MethodSpec: spec.MethodSpec{ + Name: "InsertOne", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + }, + Returns: []code.Type{ + code.InterfaceType{}, + code.SimpleType("error"), + }, + Operation: spec.InsertOperation{ + Mode: spec.QueryModeOne, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) InsertOne(arg0 context.Context, arg1 *UserModel) (interface{}, error) { + result, err := r.collection.InsertOne(arg0, arg1) + if err != nil { + return nil, err + } + return result.InsertedID, nil +} +`, + }, + { + Name: "insert many method", + MethodSpec: spec.MethodSpec{ + Name: "Insert", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "userModel", Type: code.ArrayType{ + ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}, + }}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.InterfaceType{}}, + code.SimpleType("error"), + }, + Operation: spec.InsertOperation{ + Mode: spec.QueryModeMany, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) Insert(arg0 context.Context, arg1 []*UserModel) ([]interface{}, error) { + var entities []interface{} + for _, model := range arg1 { + entities = append(entities, model) + } + result, err := r.collection.InsertMany(arg0, entities) + if err != nil { + return nil, err + } + return result.InsertedIDs, nil +} +`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + buffer := new(bytes.Buffer) + + err := generator.GenerateMethod(testCase.MethodSpec, buffer) + + if err != nil { + t.Error(err) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil { + t.Error(err) + } + }) + } +} + func TestGenerateMethod_Find(t *testing.T) { testTable := []GenerateMethodTestCase{ { @@ -90,7 +169,10 @@ func TestGenerateMethod_Find(t *testing.T) { {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, }, - Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.SimpleType("error")}, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.SimpleType("error"), + }, Operation: spec.FindOperation{ Mode: spec.QueryModeOne, Query: spec.QuerySpec{ diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go index 7a1028f..ca74a49 100644 --- a/internal/mongo/templates.go +++ b/internal/mongo/templates.go @@ -70,6 +70,22 @@ func (data mongoMethodTemplateData) Returns() string { return fmt.Sprintf(" (%s)", strings.Join(returns, ", ")) } +const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1) + if err != nil { + return nil, err + } + return result.InsertedID, nil` + +const insertManyTemplate = ` var entities []interface{} + for _, model := range arg1 { + entities = append(entities, model) + } + result, err := r.collection.InsertMany(arg0, entities) + if err != nil { + return nil, err + } + return result.InsertedIDs, nil` + type mongoFindTemplateData struct { EntityType string QuerySpec querySpec diff --git a/internal/spec/errors.go b/internal/spec/errors.go index d4d0dd9..5d6ac52 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -9,14 +9,14 @@ func (err ParsingError) Error() string { return "unknown operation" case UnsupportedNameError: return "method name is not supported" + case UnsupportedReturnError: + return "this type of return is not supported" case InvalidQueryError: return "invalid query" case InvalidParamError: return "parameters do not match the query" case InvalidUpdateFieldsError: return "update fields is invalid" - case UnsupportedReturnError: - return "this type of return is not supported" case ContextParamRequiredError: return "context parameter is required" case StructFieldNotFoundError: @@ -29,10 +29,10 @@ func (err ParsingError) Error() string { const ( UnknownOperationError ParsingError = "ERROR_UNKNOWN_OPERATION" UnsupportedNameError ParsingError = "ERROR_UNSUPPORTED" + UnsupportedReturnError ParsingError = "ERROR_UNSUPPORTED_RETURN" InvalidQueryError ParsingError = "ERROR_INVALID_QUERY" InvalidParamError ParsingError = "ERROR_INVALID_PARAM" InvalidUpdateFieldsError ParsingError = "ERROR_INVALID_UPDATE_FIELDS" - UnsupportedReturnError ParsingError = "ERROR_INVALID_RETURN" ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED" StructFieldNotFoundError ParsingError = "ERROR_STRUCT_FIELD_NOT_FOUND" ) diff --git a/internal/spec/models.go b/internal/spec/models.go index b893fc1..89d0b27 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -27,6 +27,11 @@ type MethodSpec struct { type Operation interface { } +// InsertOperation is a method specification for insert operations +type InsertOperation struct { + Mode QueryMode +} + // FindOperation is a method specification for find operations type FindOperation struct { Mode QueryMode diff --git a/internal/spec/parser.go b/internal/spec/parser.go index d39d7be..7a6167c 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -23,6 +23,8 @@ type interfaceMethodParser struct { func (p interfaceMethodParser) Parse() (MethodSpec, error) { methodNameTokens := camelcase.Split(p.Method.Name) switch methodNameTokens[0] { + case "Insert": + return p.parseInsertMethod(methodNameTokens[1:]) case "Find": return p.parseFindMethod(methodNameTokens[1:]) case "Update": @@ -33,6 +35,65 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) { return MethodSpec{}, UnknownOperationError } +func (p interfaceMethodParser) parseInsertMethod(tokens []string) (MethodSpec, error) { + mode, err := p.extractInsertReturns(p.Method.Returns) + if err != nil { + return MethodSpec{}, err + } + + if err := p.validateContextParam(); err != nil { + return MethodSpec{}, err + } + + pointerType := code.PointerType{ContainedType: p.StructModel.ReferencedType()} + if mode == QueryModeOne && p.Method.Params[1].Type != pointerType { + return MethodSpec{}, InvalidParamError + } + + arrayType := code.ArrayType{ContainedType: pointerType} + if mode == QueryModeMany && p.Method.Params[1].Type != arrayType { + return MethodSpec{}, InvalidParamError + } + + return MethodSpec{ + Name: p.Method.Name, + Params: p.Method.Params, + Returns: p.Method.Returns, + Operation: InsertOperation{ + Mode: mode, + }, + }, nil +} + +func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) { + if len(returns) != 2 { + return "", UnsupportedReturnError + } + + if returns[1] != code.SimpleType("error") { + return "", UnsupportedReturnError + } + + interfaceType, ok := returns[0].(code.InterfaceType) + if ok { + if len(interfaceType.Methods) != 0 { + return "", UnsupportedReturnError + } + return QueryModeOne, nil + } + + arrayType, ok := returns[0].(code.ArrayType) + if ok { + interfaceType, ok := arrayType.ContainedType.(code.InterfaceType) + if !ok || len(interfaceType.Methods) != 0 { + return "", UnsupportedReturnError + } + return QueryModeMany, nil + } + + return "", UnsupportedReturnError +} + func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, error) { if len(tokens) == 0 { return MethodSpec{}, UnsupportedNameError diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index f283492..b523e7a 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -40,6 +40,64 @@ type ParseInterfaceMethodTestCase struct { ExpectedOperation spec.Operation } +func TestParseInterfaceMethod_Insert(t *testing.T) { + testTable := []ParseInterfaceMethodTestCase{ + { + Name: "InsertOne method", + Method: code.Method{ + Name: "InsertOne", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + }, + Returns: []code.Type{ + code.InterfaceType{}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.InsertOperation{ + Mode: spec.QueryModeOne, + }, + }, + { + Name: "InsertMany method", + Method: code.Method{ + Name: "InsertMany", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.InterfaceType{}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.InsertOperation{ + Mode: spec.QueryModeMany, + }, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(structModel, testCase.Method) + + if err != nil { + t.Errorf("Error = %s", err) + } + expectedOutput := spec.MethodSpec{ + Name: testCase.Method.Name, + Params: testCase.Method.Params, + Returns: testCase.Method.Returns, + Operation: testCase.ExpectedOperation, + } + if !reflect.DeepEqual(actualSpec, expectedOutput) { + t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + } + }) + } +} + func TestParseInterfaceMethod_Find(t *testing.T) { testTable := []ParseInterfaceMethodTestCase{ { @@ -736,6 +794,114 @@ func TestParseInterfaceMethod_Invalid(t *testing.T) { } } +func TestParseInterfaceMethod_Insert_Invalid(t *testing.T) { + testTable := []ParseInterfaceMethodInvalidTestCase{ + { + Name: "invalid number of returns", + Method: code.Method{ + Name: "Insert", + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.InterfaceType{}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "unsupported return types from insert method", + Method: code.Method{ + Name: "Insert", + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "unempty interface return from insert method", + Method: code.Method{ + Name: "Insert", + Returns: []code.Type{ + code.InterfaceType{ + Methods: []code.Method{ + {Name: "DoSomething"}, + }, + }, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "error return not provided", + Method: code.Method{ + Name: "Insert", + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.InterfaceType{}, + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "no context parameter", + Method: code.Method{ + Name: "Insert", + Params: []code.Param{ + {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + }, + Returns: []code.Type{ + code.InterfaceType{}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.ContextParamRequiredError, + }, + { + Name: "mismatched model parameter for one mode", + Method: code.Method{ + Name: "Insert", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "userModel", Type: code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}}, + }, + Returns: []code.Type{ + code.InterfaceType{}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidParamError, + }, + { + Name: "mismatched model parameter for many mode", + Method: code.Method{ + Name: "Insert", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.InterfaceType{}}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidParamError, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(structModel, testCase.Method) + + if err != testCase.ExpectedError { + t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + } + }) + } +} + func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { testTable := []ParseInterfaceMethodInvalidTestCase{ { @@ -758,7 +924,7 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { ExpectedError: spec.UnsupportedReturnError, }, { - Name: "unsupported return values from find method", + Name: "unsupported return types from find method", Method: code.Method{ Name: "FindOneByID", Returns: []code.Type{ @@ -933,7 +1099,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { ExpectedError: spec.UnsupportedReturnError, }, { - Name: "unsupported return values from find method", + Name: "unsupported return types from update method", Method: code.Method{ Name: "UpdateAgeByID", Returns: []code.Type{ @@ -1085,7 +1251,7 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) { ExpectedError: spec.UnsupportedReturnError, }, { - Name: "unsupported return values from find method", + Name: "unsupported return types from delete method", Method: code.Method{ Name: "DeleteOneByID", Returns: []code.Type{