Add NotIn, True and False comparators

This commit is contained in:
sunboyy 2021-02-23 19:10:25 +07:00
parent cbcaf377d2
commit 2383e5da86
7 changed files with 274 additions and 6 deletions

View file

@ -3,6 +3,8 @@ package spec
import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
)
// ParsingError is an error from parsing interface methods
@ -71,3 +73,21 @@ type structFieldNotFoundError struct {
func (err structFieldNotFoundError) Error() string {
return fmt.Sprintf("struct field '%s' not found", err.FieldName)
}
// NewIncompatibleComparatorError creates incompatibleComparatorError
func NewIncompatibleComparatorError(comparator Comparator, field code.StructField) error {
return incompatibleComparatorError{
Comparator: comparator,
Field: field,
}
}
type incompatibleComparatorError struct {
Comparator Comparator
Field code.StructField
}
func (err incompatibleComparatorError) Error() string {
return fmt.Sprintf("cannot use comparator %s with struct field '%s' of type '%s'",
err.Comparator, err.Field.Name, err.Field.Type.Code())
}

View file

@ -3,6 +3,7 @@ package spec_test
import (
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/spec"
)
@ -29,6 +30,14 @@ func TestError(t *testing.T) {
Error: spec.NewInvalidQueryError([]string{"By", "And"}),
ExpectedString: "invalid query 'ByAnd'",
},
{
Name: "IncompatibleComparatorError",
Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{
Name: "Age",
Type: code.SimpleType("int"),
}),
ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'",
},
}
for _, testCase := range testTable {

View file

@ -343,6 +343,11 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
return NewStructFieldNotFoundError(predicate.Field)
}
if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) &&
structField.Type != code.SimpleType("bool") {
return NewIncompatibleComparatorError(predicate.Comparator, structField)
}
for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ {
if params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType(
structField.Type) {

View file

@ -31,6 +31,10 @@ var structModel = code.Struct{
Name: "Age",
Type: code.SimpleType("int"),
},
{
Name: "Enabled",
Type: code.SimpleType("bool"),
},
},
}
@ -367,6 +371,64 @@ func TestParseInterfaceMethod_Find(t *testing.T) {
}},
},
},
{
Name: "FindByArgNotIn method",
Method: code.Method{
Name: "FindByCityNotIn",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.ArrayType{ContainedType: code.SimpleType("string")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedOperation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{Predicates: []spec.Predicate{
{Field: "City", Comparator: spec.ComparatorNotIn, ParamIndex: 1},
}},
},
},
{
Name: "FindByArgTrue method",
Method: code.Method{
Name: "FindByEnabledTrue",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedOperation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{Predicates: []spec.Predicate{
{Field: "Enabled", Comparator: spec.ComparatorTrue, ParamIndex: 1},
}},
},
},
{
Name: "FindByArgFalse method",
Method: code.Method{
Name: "FindByEnabledFalse",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedOperation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{Predicates: []spec.Predicate{
{Field: "Enabled", Comparator: spec.ComparatorFalse, ParamIndex: 1},
}},
},
},
}
for _, testCase := range testTable {
@ -1100,6 +1162,40 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
},
ExpectedError: spec.NewStructFieldNotFoundError("Country"),
},
{
Name: "incompatible struct field for True comparator",
Method: code.Method{
Name: "FindByGenderTrue",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{
Name: "Gender",
Type: code.SimpleType("Gender"),
}),
},
{
Name: "incompatible struct field for False comparator",
Method: code.Method{
Name: "FindByGenderFalse",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.SimpleType("error"),
},
},
ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorFalse, code.StructField{
Name: "Gender",
Type: code.SimpleType("Gender"),
}),
},
{
Name: "mismatched method parameter type",
Method: code.Method{
@ -1136,8 +1232,8 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
_, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != testCase.ExpectedError {
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err)
if err.Error() != testCase.ExpectedError.Error() {
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError.Error(), err.Error())
}
})
}

View file

@ -43,22 +43,31 @@ const (
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN"
ComparatorNotIn Comparator = "NOT_IN"
ComparatorTrue Comparator = "EQUAL_TRUE"
ComparatorFalse Comparator = "EQUAL_FALSE"
)
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
switch c {
case ComparatorIn, ComparatorNotIn:
return code.ArrayType{ContainedType: t}
default:
return t
}
return t
}
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
switch c {
case ComparatorBetween:
return 2
case ComparatorTrue, ComparatorFalse:
return 0
default:
return 1
}
return 1
}
// Predicate is a criteria for querying a field
@ -86,12 +95,21 @@ func (t predicateToken) ToPredicate(paramIndex int) Predicate {
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Not" && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorNotIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "True" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorTrue, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "False" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorFalse, ParamIndex: paramIndex}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
}