Validate parameter types in method signature

This commit is contained in:
sunboyy 2021-01-21 18:56:30 +07:00
parent 26feb0e0e3
commit f799a284b1
7 changed files with 587 additions and 488 deletions

View file

@ -2,7 +2,7 @@ coverage:
status: status:
project: project:
default: default:
target: 65% target: 70%
threshold: 5% threshold: 5%
patch: patch:
default: default:

View file

@ -13,15 +13,20 @@ import (
// GenerateMongoRepository generates mongodb repository // GenerateMongoRepository generates mongodb repository
func GenerateMongoRepository(packageName string, structModel code.Struct, intf code.Interface) (string, error) { func GenerateMongoRepository(packageName string, structModel code.Struct, intf code.Interface) (string, error) {
repositorySpec, err := spec.ParseRepositoryInterface(structModel, intf) var methodSpecs []spec.MethodSpec
for _, method := range intf.Methods {
methodSpec, err := spec.ParseInterfaceMethod(structModel, method)
if err != nil { if err != nil {
return "", err return "", err
} }
methodSpecs = append(methodSpecs, methodSpec)
}
generator := mongoRepositoryGenerator{ generator := mongoRepositoryGenerator{
PackageName: packageName, PackageName: packageName,
StructModel: structModel, StructModel: structModel,
RepositorySpec: repositorySpec, InterfaceName: intf.Name,
MethodSpecs: methodSpecs,
} }
output, err := generator.Generate() output, err := generator.Generate()
@ -35,7 +40,8 @@ func GenerateMongoRepository(packageName string, structModel code.Struct, intf c
type mongoRepositoryGenerator struct { type mongoRepositoryGenerator struct {
PackageName string PackageName string
StructModel code.Struct StructModel code.Struct
RepositorySpec spec.RepositorySpec InterfaceName string
MethodSpecs []spec.MethodSpec
} }
func (g mongoRepositoryGenerator) Generate() (string, error) { func (g mongoRepositoryGenerator) Generate() (string, error) {
@ -44,7 +50,7 @@ func (g mongoRepositoryGenerator) Generate() (string, error) {
return "", err return "", err
} }
for _, method := range g.RepositorySpec.Methods { for _, method := range g.MethodSpecs {
if err := g.generateMethod(buffer, method); err != nil { if err := g.generateMethod(buffer, method); err != nil {
return "", err return "", err
} }
@ -66,7 +72,7 @@ func (g mongoRepositoryGenerator) generateBaseContent(buffer *bytes.Buffer) erro
tmplData := mongoBaseTemplateData{ tmplData := mongoBaseTemplateData{
PackageName: g.PackageName, PackageName: g.PackageName,
InterfaceName: g.RepositorySpec.InterfaceName, InterfaceName: g.InterfaceName,
StructName: g.structName(), StructName: g.structName(),
} }
@ -167,5 +173,5 @@ func (g mongoRepositoryGenerator) generateFindImplementation(operation spec.Find
} }
func (g mongoRepositoryGenerator) structName() string { func (g mongoRepositoryGenerator) structName() string {
return g.RepositorySpec.InterfaceName + "Mongo" return g.InterfaceName + "Mongo"
} }

View file

@ -83,7 +83,7 @@ func TestGenerateMongoRepository(t *testing.T) {
Name: "FindByGenderNot", Name: "FindByGenderNot",
Params: []code.Param{ Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("int")}, {Name: "gender", Type: code.SimpleType("Gender")},
}, },
Returns: []code.Type{ Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
@ -220,7 +220,7 @@ func (r *UserRepositoryMongo) FindByIDAndUsername(ctx context.Context, arg0 prim
return &entity, nil return &entity, nil
} }
func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 int) ([]*UserModel, error) { func (r *UserRepositoryMongo) FindByGenderNot(ctx context.Context, arg0 Gender) ([]*UserModel, error) {
cursor, err := r.collection.Find(ctx, bson.M{ cursor, err := r.collection.Find(ctx, bson.M{
"gender": bson.M{"$ne": arg0}, "gender": bson.M{"$ne": arg0},
}) })

35
internal/spec/errors.go Normal file
View file

@ -0,0 +1,35 @@
package spec
// ParsingError is an error from parsing interface methods
type ParsingError string
func (err ParsingError) Error() string {
switch err {
case UnknownOperationError:
return "unknown operation"
case UnsupportedNameError:
return "method name is not supported"
case InvalidQueryError:
return "invalid query"
case InvalidParamError:
return "parameters do not match the query"
case UnsupportedReturnError:
return "this type of return is not supported"
case ContextParamRequiredError:
return "context parameter is required"
case StructFieldNotFoundError:
return "struct field not found"
}
return string(err)
}
// parsing error constants
const (
UnknownOperationError ParsingError = "ERROR_UNKNOWN_OPERATION"
UnsupportedNameError ParsingError = "ERROR_UNSUPPORTED"
InvalidQueryError ParsingError = "ERROR_INVALID_QUERY"
InvalidParamError ParsingError = "ERROR_INVALID_PARAM"
UnsupportedReturnError ParsingError = "ERROR_INVALID_RETURN"
ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED"
StructFieldNotFoundError ParsingError = "ERROR_STRUCT_FIELD_NOT_FOUND"
)

View file

@ -15,12 +15,6 @@ const (
QueryModeMany QueryMode = "MANY" QueryModeMany QueryMode = "MANY"
) )
// RepositorySpec is a specification generated from the repository interface
type RepositorySpec struct {
InterfaceName string
Methods []MethodSpec
}
// MethodSpec is a method specification inside repository specification // MethodSpec is a method specification inside repository specification
type MethodSpec struct { type MethodSpec struct {
Name string Name string

View file

@ -1,59 +1,40 @@
package spec package spec
import ( import (
"errors"
"fmt"
"github.com/fatih/camelcase" "github.com/fatih/camelcase"
"github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/code"
) )
// ParseRepositoryInterface returns repository spec from declared repository interface // ParseInterfaceMethod returns repository method spec from declared interface method
func ParseRepositoryInterface(structModel code.Struct, intf code.Interface) (RepositorySpec, error) { func ParseInterfaceMethod(structModel code.Struct, method code.Method) (MethodSpec, error) {
parser := repositoryInterfaceParser{ parser := interfaceMethodParser{
StructModel: structModel, StructModel: structModel,
Interface: intf, Method: method,
} }
return parser.Parse() return parser.Parse()
} }
type repositoryInterfaceParser struct { type interfaceMethodParser struct {
StructModel code.Struct StructModel code.Struct
Interface code.Interface Method code.Method
} }
func (p repositoryInterfaceParser) Parse() (RepositorySpec, error) { func (p interfaceMethodParser) Parse() (MethodSpec, error) {
repositorySpec := RepositorySpec{ methodNameTokens := camelcase.Split(p.Method.Name)
InterfaceName: p.Interface.Name,
}
for _, method := range p.Interface.Methods {
methodSpec, err := p.parseMethod(method)
if err != nil {
return RepositorySpec{}, err
}
repositorySpec.Methods = append(repositorySpec.Methods, methodSpec)
}
return repositorySpec, nil
}
func (p repositoryInterfaceParser) parseMethod(method code.Method) (MethodSpec, error) {
methodNameTokens := camelcase.Split(method.Name)
switch methodNameTokens[0] { switch methodNameTokens[0] {
case "Find": case "Find":
return p.parseFindMethod(method, methodNameTokens[1:]) return p.parseFindMethod(methodNameTokens[1:])
} }
return MethodSpec{}, errors.New("method name not supported") return MethodSpec{}, UnknownOperationError
} }
func (p repositoryInterfaceParser) parseFindMethod(method code.Method, tokens []string) (MethodSpec, error) { func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, error) {
if len(tokens) == 0 { if len(tokens) == 0 {
return MethodSpec{}, errors.New("method name not supported") return MethodSpec{}, UnsupportedNameError
} }
mode, err := p.extractFindReturns(method.Returns) mode, err := p.extractFindReturns(p.Method.Returns)
if err != nil { if err != nil {
return MethodSpec{}, err return MethodSpec{}, err
} }
@ -63,14 +44,14 @@ func (p repositoryInterfaceParser) parseFindMethod(method code.Method, tokens []
return MethodSpec{}, err return MethodSpec{}, err
} }
if querySpec.NumberOfArguments()+1 != len(method.Params) { if err := p.validateMethodSignature(querySpec); err != nil {
return MethodSpec{}, errors.New("method parameter not supported") return MethodSpec{}, err
} }
return MethodSpec{ return MethodSpec{
Name: method.Name, Name: p.Method.Name,
Params: method.Params, Params: p.Method.Params,
Returns: method.Returns, Returns: p.Method.Returns,
Operation: FindOperation{ Operation: FindOperation{
Mode: mode, Mode: mode,
Query: querySpec, Query: querySpec,
@ -78,13 +59,13 @@ func (p repositoryInterfaceParser) parseFindMethod(method code.Method, tokens []
}, nil }, nil
} }
func (p repositoryInterfaceParser) extractFindReturns(returns []code.Type) (QueryMode, error) { func (p interfaceMethodParser) extractFindReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 { if len(returns) != 2 {
return "", errors.New("method return not supported") return "", UnsupportedReturnError
} }
if returns[1] != code.SimpleType("error") { if returns[1] != code.SimpleType("error") {
return "", errors.New("method return not supported") return "", UnsupportedReturnError
} }
pointerType, ok := returns[0].(code.PointerType) pointerType, ok := returns[0].(code.PointerType)
@ -93,7 +74,7 @@ func (p repositoryInterfaceParser) extractFindReturns(returns []code.Type) (Quer
if simpleType == code.SimpleType(p.StructModel.Name) { if simpleType == code.SimpleType(p.StructModel.Name) {
return QueryModeOne, nil return QueryModeOne, nil
} }
return "", fmt.Errorf("invalid return type %s", pointerType.Code()) return "", UnsupportedReturnError
} }
arrayType, ok := returns[0].(code.ArrayType) arrayType, ok := returns[0].(code.ArrayType)
@ -104,16 +85,16 @@ func (p repositoryInterfaceParser) extractFindReturns(returns []code.Type) (Quer
if simpleType == code.SimpleType(p.StructModel.Name) { if simpleType == code.SimpleType(p.StructModel.Name) {
return QueryModeMany, nil return QueryModeMany, nil
} }
return "", fmt.Errorf("invalid return type %s", pointerType.Code()) return "", UnsupportedReturnError
} }
} }
return "", errors.New("method return not supported") return "", UnsupportedReturnError
} }
func (p repositoryInterfaceParser) parseQuery(tokens []string) (QuerySpec, error) { func (p interfaceMethodParser) parseQuery(tokens []string) (QuerySpec, error) {
if len(tokens) == 0 { if len(tokens) == 0 {
return QuerySpec{}, errors.New("method name not supported") return QuerySpec{}, InvalidQueryError
} }
if len(tokens) == 1 && tokens[0] == "All" { if len(tokens) == 1 && tokens[0] == "All" {
@ -128,7 +109,7 @@ func (p repositoryInterfaceParser) parseQuery(tokens []string) (QuerySpec, error
} }
if tokens[0] == "And" || tokens[0] == "Or" { if tokens[0] == "And" || tokens[0] == "Or" {
return QuerySpec{}, errors.New("method name not supported") return QuerySpec{}, InvalidQueryError
} }
var operator Operator var operator Operator
@ -137,6 +118,8 @@ func (p repositoryInterfaceParser) parseQuery(tokens []string) (QuerySpec, error
for _, token := range tokens { for _, token := range tokens {
if token != "And" && token != "Or" { if token != "And" && token != "Or" {
aggregatedToken = append(aggregatedToken, token) aggregatedToken = append(aggregatedToken, token)
} else if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
} else if token == "And" && operator != OperatorOr { } else if token == "And" && operator != OperatorOr {
operator = OperatorAnd operator = OperatorAnd
predicates = append(predicates, aggregatedToken.ToPredicate()) predicates = append(predicates, aggregatedToken.ToPredicate())
@ -146,13 +129,40 @@ func (p repositoryInterfaceParser) parseQuery(tokens []string) (QuerySpec, error
predicates = append(predicates, aggregatedToken.ToPredicate()) predicates = append(predicates, aggregatedToken.ToPredicate())
aggregatedToken = predicateToken{} aggregatedToken = predicateToken{}
} else { } else {
return QuerySpec{}, errors.New("method name contains ambiguous query") return QuerySpec{}, InvalidQueryError
} }
} }
if len(aggregatedToken) == 0 { if len(aggregatedToken) == 0 {
return QuerySpec{}, errors.New("method name not supported") return QuerySpec{}, InvalidQueryError
} }
predicates = append(predicates, aggregatedToken.ToPredicate()) predicates = append(predicates, aggregatedToken.ToPredicate())
return QuerySpec{Operator: operator, Predicates: predicates}, nil return QuerySpec{Operator: operator, Predicates: predicates}, nil
} }
func (p interfaceMethodParser) validateMethodSignature(querySpec QuerySpec) error {
contextType := code.ExternalType{PackageAlias: "context", Name: "Context"}
if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType {
return ContextParamRequiredError
}
if querySpec.NumberOfArguments()+1 != len(p.Method.Params) {
return InvalidParamError
}
currentParamIndex := 1
for _, predicate := range querySpec.Predicates {
structField, ok := p.StructModel.Fields.ByName(predicate.Field)
if !ok {
return StructFieldNotFoundError
}
if structField.Type != p.Method.Params[currentParamIndex].Type {
return InvalidParamError
}
currentParamIndex++
}
return nil
}

File diff suppressed because it is too large Load diff