Add NotIn, True and False comparators
This commit is contained in:
parent
cbcaf377d2
commit
2383e5da86
7 changed files with 274 additions and 6 deletions
|
@ -33,6 +33,11 @@ var userModel = code.Struct{
|
||||||
Type: code.SimpleType("int"),
|
Type: code.SimpleType("int"),
|
||||||
Tags: map[string][]string{"bson": {"age"}},
|
Tags: map[string][]string{"bson": {"age"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Enabled",
|
||||||
|
Type: code.SimpleType("bool"),
|
||||||
|
Tags: map[string][]string{"bson": {"enabled"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "AccessToken",
|
Name: "AccessToken",
|
||||||
Type: code.SimpleType("string"),
|
Type: code.SimpleType("string"),
|
||||||
|
@ -575,6 +580,115 @@ func (r *UserRepositoryMongo) FindByGenderIn(arg0 context.Context, arg1 []Gender
|
||||||
}
|
}
|
||||||
return entities, nil
|
return entities, nil
|
||||||
}
|
}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "find with NotIn comparator",
|
||||||
|
MethodSpec: spec.MethodSpec{
|
||||||
|
Name: "FindByGenderNotIn",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
Operation: spec.FindOperation{
|
||||||
|
Mode: spec.QueryModeMany,
|
||||||
|
Query: spec.QuerySpec{
|
||||||
|
Predicates: []spec.Predicate{
|
||||||
|
{Comparator: spec.ComparatorNotIn, Field: "Gender", ParamIndex: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedCode: `
|
||||||
|
func (r *UserRepositoryMongo) FindByGenderNotIn(arg0 context.Context, arg1 []Gender) ([]*UserModel, error) {
|
||||||
|
cursor, err := r.collection.Find(arg0, bson.M{
|
||||||
|
"gender": bson.M{"$nin": arg1},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var entities []*UserModel
|
||||||
|
if err := cursor.All(arg0, &entities); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return entities, nil
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "find with True comparator",
|
||||||
|
MethodSpec: spec.MethodSpec{
|
||||||
|
Name: "FindByEnabledTrue",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
Operation: spec.FindOperation{
|
||||||
|
Mode: spec.QueryModeMany,
|
||||||
|
Query: spec.QuerySpec{
|
||||||
|
Predicates: []spec.Predicate{
|
||||||
|
{Comparator: spec.ComparatorTrue, Field: "Enabled", ParamIndex: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedCode: `
|
||||||
|
func (r *UserRepositoryMongo) FindByEnabledTrue(arg0 context.Context) ([]*UserModel, error) {
|
||||||
|
cursor, err := r.collection.Find(arg0, bson.M{
|
||||||
|
"enabled": true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var entities []*UserModel
|
||||||
|
if err := cursor.All(arg0, &entities); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return entities, nil
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "find with False comparator",
|
||||||
|
MethodSpec: spec.MethodSpec{
|
||||||
|
Name: "FindByEnabledFalse",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
Operation: spec.FindOperation{
|
||||||
|
Mode: spec.QueryModeMany,
|
||||||
|
Query: spec.QuerySpec{
|
||||||
|
Predicates: []spec.Predicate{
|
||||||
|
{Comparator: spec.ComparatorFalse, Field: "Enabled", ParamIndex: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedCode: `
|
||||||
|
func (r *UserRepositoryMongo) FindByEnabledFalse(arg0 context.Context) ([]*UserModel, error) {
|
||||||
|
cursor, err := r.collection.Find(arg0, bson.M{
|
||||||
|
"enabled": false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var entities []*UserModel
|
||||||
|
if err := cursor.All(arg0, &entities); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return entities, nil
|
||||||
|
}
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,6 +69,12 @@ func (p predicate) Code() string {
|
||||||
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, p.ParamIndex, p.ParamIndex+1)
|
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, p.ParamIndex, p.ParamIndex+1)
|
||||||
case spec.ComparatorIn:
|
case spec.ComparatorIn:
|
||||||
return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, p.ParamIndex)
|
return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, p.ParamIndex)
|
||||||
|
case spec.ComparatorNotIn:
|
||||||
|
return fmt.Sprintf(`"%s": bson.M{"$nin": arg%d}`, p.Field, p.ParamIndex)
|
||||||
|
case spec.ComparatorTrue:
|
||||||
|
return fmt.Sprintf(`"%s": true`, p.Field)
|
||||||
|
case spec.ComparatorFalse:
|
||||||
|
return fmt.Sprintf(`"%s": false`, p.Field)
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,8 @@ package spec
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sunboyy/repogen/internal/code"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParsingError is an error from parsing interface methods
|
// ParsingError is an error from parsing interface methods
|
||||||
|
@ -71,3 +73,21 @@ type structFieldNotFoundError struct {
|
||||||
func (err structFieldNotFoundError) Error() string {
|
func (err structFieldNotFoundError) Error() string {
|
||||||
return fmt.Sprintf("struct field '%s' not found", err.FieldName)
|
return fmt.Sprintf("struct field '%s' not found", err.FieldName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewIncompatibleComparatorError creates incompatibleComparatorError
|
||||||
|
func NewIncompatibleComparatorError(comparator Comparator, field code.StructField) error {
|
||||||
|
return incompatibleComparatorError{
|
||||||
|
Comparator: comparator,
|
||||||
|
Field: field,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type incompatibleComparatorError struct {
|
||||||
|
Comparator Comparator
|
||||||
|
Field code.StructField
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err incompatibleComparatorError) Error() string {
|
||||||
|
return fmt.Sprintf("cannot use comparator %s with struct field '%s' of type '%s'",
|
||||||
|
err.Comparator, err.Field.Name, err.Field.Type.Code())
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package spec_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sunboyy/repogen/internal/code"
|
||||||
"github.com/sunboyy/repogen/internal/spec"
|
"github.com/sunboyy/repogen/internal/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,6 +30,14 @@ func TestError(t *testing.T) {
|
||||||
Error: spec.NewInvalidQueryError([]string{"By", "And"}),
|
Error: spec.NewInvalidQueryError([]string{"By", "And"}),
|
||||||
ExpectedString: "invalid query 'ByAnd'",
|
ExpectedString: "invalid query 'ByAnd'",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "IncompatibleComparatorError",
|
||||||
|
Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{
|
||||||
|
Name: "Age",
|
||||||
|
Type: code.SimpleType("int"),
|
||||||
|
}),
|
||||||
|
ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range testTable {
|
for _, testCase := range testTable {
|
||||||
|
|
|
@ -343,6 +343,11 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
|
||||||
return NewStructFieldNotFoundError(predicate.Field)
|
return NewStructFieldNotFoundError(predicate.Field)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) &&
|
||||||
|
structField.Type != code.SimpleType("bool") {
|
||||||
|
return NewIncompatibleComparatorError(predicate.Comparator, structField)
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ {
|
for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ {
|
||||||
if params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType(
|
if params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType(
|
||||||
structField.Type) {
|
structField.Type) {
|
||||||
|
|
|
@ -31,6 +31,10 @@ var structModel = code.Struct{
|
||||||
Name: "Age",
|
Name: "Age",
|
||||||
Type: code.SimpleType("int"),
|
Type: code.SimpleType("int"),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Enabled",
|
||||||
|
Type: code.SimpleType("bool"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -367,6 +371,64 @@ func TestParseInterfaceMethod_Find(t *testing.T) {
|
||||||
}},
|
}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "FindByArgNotIn method",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "FindByCityNotIn",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Type: code.ArrayType{ContainedType: code.SimpleType("string")}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedOperation: spec.FindOperation{
|
||||||
|
Mode: spec.QueryModeMany,
|
||||||
|
Query: spec.QuerySpec{Predicates: []spec.Predicate{
|
||||||
|
{Field: "City", Comparator: spec.ComparatorNotIn, ParamIndex: 1},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "FindByArgTrue method",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "FindByEnabledTrue",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedOperation: spec.FindOperation{
|
||||||
|
Mode: spec.QueryModeMany,
|
||||||
|
Query: spec.QuerySpec{Predicates: []spec.Predicate{
|
||||||
|
{Field: "Enabled", Comparator: spec.ComparatorTrue, ParamIndex: 1},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "FindByArgFalse method",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "FindByEnabledFalse",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedOperation: spec.FindOperation{
|
||||||
|
Mode: spec.QueryModeMany,
|
||||||
|
Query: spec.QuerySpec{Predicates: []spec.Predicate{
|
||||||
|
{Field: "Enabled", Comparator: spec.ComparatorFalse, ParamIndex: 1},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range testTable {
|
for _, testCase := range testTable {
|
||||||
|
@ -1100,6 +1162,40 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
|
||||||
},
|
},
|
||||||
ExpectedError: spec.NewStructFieldNotFoundError("Country"),
|
ExpectedError: spec.NewStructFieldNotFoundError("Country"),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "incompatible struct field for True comparator",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "FindByGenderTrue",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{
|
||||||
|
Name: "Gender",
|
||||||
|
Type: code.SimpleType("Gender"),
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "incompatible struct field for False comparator",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "FindByGenderFalse",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorFalse, code.StructField{
|
||||||
|
Name: "Gender",
|
||||||
|
Type: code.SimpleType("Gender"),
|
||||||
|
}),
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "mismatched method parameter type",
|
Name: "mismatched method parameter type",
|
||||||
Method: code.Method{
|
Method: code.Method{
|
||||||
|
@ -1136,8 +1232,8 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
_, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
|
_, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
|
||||||
|
|
||||||
if err != testCase.ExpectedError {
|
if err.Error() != testCase.ExpectedError.Error() {
|
||||||
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err)
|
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError.Error(), err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,22 +43,31 @@ const (
|
||||||
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
|
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
|
||||||
ComparatorBetween Comparator = "BETWEEN"
|
ComparatorBetween Comparator = "BETWEEN"
|
||||||
ComparatorIn Comparator = "IN"
|
ComparatorIn Comparator = "IN"
|
||||||
|
ComparatorNotIn Comparator = "NOT_IN"
|
||||||
|
ComparatorTrue Comparator = "EQUAL_TRUE"
|
||||||
|
ComparatorFalse Comparator = "EQUAL_FALSE"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
|
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
|
||||||
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
|
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
|
||||||
if c == ComparatorIn {
|
switch c {
|
||||||
|
case ComparatorIn, ComparatorNotIn:
|
||||||
return code.ArrayType{ContainedType: t}
|
return code.ArrayType{ContainedType: t}
|
||||||
|
default:
|
||||||
|
return t
|
||||||
}
|
}
|
||||||
return t
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NumberOfArguments returns the number of arguments required to perform the comparison
|
// NumberOfArguments returns the number of arguments required to perform the comparison
|
||||||
func (c Comparator) NumberOfArguments() int {
|
func (c Comparator) NumberOfArguments() int {
|
||||||
if c == ComparatorBetween {
|
switch c {
|
||||||
|
case ComparatorBetween:
|
||||||
return 2
|
return 2
|
||||||
|
case ComparatorTrue, ComparatorFalse:
|
||||||
|
return 0
|
||||||
|
default:
|
||||||
|
return 1
|
||||||
}
|
}
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predicate is a criteria for querying a field
|
// Predicate is a criteria for querying a field
|
||||||
|
@ -86,12 +95,21 @@ func (t predicateToken) ToPredicate(paramIndex int) Predicate {
|
||||||
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
|
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
|
||||||
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
|
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
|
||||||
}
|
}
|
||||||
|
if len(t) > 2 && t[len(t)-2] == "Not" && t[len(t)-1] == "In" {
|
||||||
|
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorNotIn, ParamIndex: paramIndex}
|
||||||
|
}
|
||||||
if len(t) > 1 && t[len(t)-1] == "In" {
|
if len(t) > 1 && t[len(t)-1] == "In" {
|
||||||
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
|
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
|
||||||
}
|
}
|
||||||
if len(t) > 1 && t[len(t)-1] == "Between" {
|
if len(t) > 1 && t[len(t)-1] == "Between" {
|
||||||
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
|
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
|
||||||
}
|
}
|
||||||
|
if len(t) > 1 && t[len(t)-1] == "True" {
|
||||||
|
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorTrue, ParamIndex: paramIndex}
|
||||||
|
}
|
||||||
|
if len(t) > 1 && t[len(t)-1] == "False" {
|
||||||
|
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorFalse, ParamIndex: paramIndex}
|
||||||
|
}
|
||||||
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
|
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue