From 2383e5da862949bb99501e94ae7ff96335bfd2a3 Mon Sep 17 00:00:00 2001 From: sunboyy Date: Tue, 23 Feb 2021 19:10:25 +0700 Subject: [PATCH] Add NotIn, True and False comparators --- internal/mongo/generator_test.go | 114 +++++++++++++++++++++++++++++++ internal/mongo/models.go | 6 ++ internal/spec/errors.go | 20 ++++++ internal/spec/errors_test.go | 9 +++ internal/spec/parser.go | 5 ++ internal/spec/parser_test.go | 100 ++++++++++++++++++++++++++- internal/spec/query.go | 26 +++++-- 7 files changed, 274 insertions(+), 6 deletions(-) diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index 9ed9a58..d44a389 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -33,6 +33,11 @@ var userModel = code.Struct{ Type: code.SimpleType("int"), Tags: map[string][]string{"bson": {"age"}}, }, + { + Name: "Enabled", + Type: code.SimpleType("bool"), + Tags: map[string][]string{"bson": {"enabled"}}, + }, { Name: "AccessToken", Type: code.SimpleType("string"), @@ -575,6 +580,115 @@ func (r *UserRepositoryMongo) FindByGenderIn(arg0 context.Context, arg1 []Gender } return entities, nil } +`, + }, + { + Name: "find with NotIn comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderNotIn", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorNotIn, Field: "Gender", ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindByGenderNotIn(arg0 context.Context, arg1 []Gender) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "gender": bson.M{"$nin": arg1}, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil +} +`, + }, + { + Name: "find with True comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByEnabledTrue", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorTrue, Field: "Enabled", ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindByEnabledTrue(arg0 context.Context) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "enabled": true, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil +} +`, + }, + { + Name: "find with False comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByEnabledFalse", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Comparator: spec.ComparatorFalse, Field: "Enabled", ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) FindByEnabledFalse(arg0 context.Context) ([]*UserModel, error) { + cursor, err := r.collection.Find(arg0, bson.M{ + "enabled": false, + }) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil +} `, }, } diff --git a/internal/mongo/models.go b/internal/mongo/models.go index db2e3bf..849d092 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -69,6 +69,12 @@ func (p predicate) Code() string { return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, p.ParamIndex, p.ParamIndex+1) case spec.ComparatorIn: return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, p.ParamIndex) + case spec.ComparatorNotIn: + return fmt.Sprintf(`"%s": bson.M{"$nin": arg%d}`, p.Field, p.ParamIndex) + case spec.ComparatorTrue: + return fmt.Sprintf(`"%s": true`, p.Field) + case spec.ComparatorFalse: + return fmt.Sprintf(`"%s": false`, p.Field) } return "" } diff --git a/internal/spec/errors.go b/internal/spec/errors.go index 072329c..a25f07a 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -3,6 +3,8 @@ package spec import ( "fmt" "strings" + + "github.com/sunboyy/repogen/internal/code" ) // ParsingError is an error from parsing interface methods @@ -71,3 +73,21 @@ type structFieldNotFoundError struct { func (err structFieldNotFoundError) Error() string { return fmt.Sprintf("struct field '%s' not found", err.FieldName) } + +// NewIncompatibleComparatorError creates incompatibleComparatorError +func NewIncompatibleComparatorError(comparator Comparator, field code.StructField) error { + return incompatibleComparatorError{ + Comparator: comparator, + Field: field, + } +} + +type incompatibleComparatorError struct { + Comparator Comparator + Field code.StructField +} + +func (err incompatibleComparatorError) Error() string { + return fmt.Sprintf("cannot use comparator %s with struct field '%s' of type '%s'", + err.Comparator, err.Field.Name, err.Field.Type.Code()) +} diff --git a/internal/spec/errors_test.go b/internal/spec/errors_test.go index e3dfe0e..28f4b6c 100644 --- a/internal/spec/errors_test.go +++ b/internal/spec/errors_test.go @@ -3,6 +3,7 @@ package spec_test import ( "testing" + "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/spec" ) @@ -29,6 +30,14 @@ func TestError(t *testing.T) { Error: spec.NewInvalidQueryError([]string{"By", "And"}), ExpectedString: "invalid query 'ByAnd'", }, + { + Name: "IncompatibleComparatorError", + Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{ + Name: "Age", + Type: code.SimpleType("int"), + }), + ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'", + }, } for _, testCase := range testTable { diff --git a/internal/spec/parser.go b/internal/spec/parser.go index 9e70bde..23e0334 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -343,6 +343,11 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer return NewStructFieldNotFoundError(predicate.Field) } + if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) && + structField.Type != code.SimpleType("bool") { + return NewIncompatibleComparatorError(predicate.Comparator, structField) + } + for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ { if params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( structField.Type) { diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 06bbeb2..5b109f2 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -31,6 +31,10 @@ var structModel = code.Struct{ Name: "Age", Type: code.SimpleType("int"), }, + { + Name: "Enabled", + Type: code.SimpleType("bool"), + }, }, } @@ -367,6 +371,64 @@ func TestParseInterfaceMethod_Find(t *testing.T) { }}, }, }, + { + Name: "FindByArgNotIn method", + Method: code.Method{ + Name: "FindByCityNotIn", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.ArrayType{ContainedType: code.SimpleType("string")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "City", Comparator: spec.ComparatorNotIn, ParamIndex: 1}, + }}, + }, + }, + { + Name: "FindByArgTrue method", + Method: code.Method{ + Name: "FindByEnabledTrue", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Enabled", Comparator: spec.ComparatorTrue, ParamIndex: 1}, + }}, + }, + }, + { + Name: "FindByArgFalse method", + Method: code.Method{ + Name: "FindByEnabledFalse", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "Enabled", Comparator: spec.ComparatorFalse, ParamIndex: 1}, + }}, + }, + }, } for _, testCase := range testTable { @@ -1100,6 +1162,40 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { }, ExpectedError: spec.NewStructFieldNotFoundError("Country"), }, + { + Name: "incompatible struct field for True comparator", + Method: code.Method{ + Name: "FindByGenderTrue", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{ + Name: "Gender", + Type: code.SimpleType("Gender"), + }), + }, + { + Name: "incompatible struct field for False comparator", + Method: code.Method{ + Name: "FindByGenderFalse", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorFalse, code.StructField{ + Name: "Gender", + Type: code.SimpleType("Gender"), + }), + }, { Name: "mismatched method parameter type", Method: code.Method{ @@ -1136,8 +1232,8 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { _, err := spec.ParseInterfaceMethod(structModel, testCase.Method) - if err != testCase.ExpectedError { - t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + if err.Error() != testCase.ExpectedError.Error() { + t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError.Error(), err.Error()) } }) } diff --git a/internal/spec/query.go b/internal/spec/query.go index b8eb1f0..6a740a5 100644 --- a/internal/spec/query.go +++ b/internal/spec/query.go @@ -43,22 +43,31 @@ const ( ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL" ComparatorBetween Comparator = "BETWEEN" ComparatorIn Comparator = "IN" + ComparatorNotIn Comparator = "NOT_IN" + ComparatorTrue Comparator = "EQUAL_TRUE" + ComparatorFalse Comparator = "EQUAL_FALSE" ) // ArgumentTypeFromFieldType returns a type of required argument from the given struct field type func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type { - if c == ComparatorIn { + switch c { + case ComparatorIn, ComparatorNotIn: return code.ArrayType{ContainedType: t} + default: + return t } - return t } // NumberOfArguments returns the number of arguments required to perform the comparison func (c Comparator) NumberOfArguments() int { - if c == ComparatorBetween { + switch c { + case ComparatorBetween: return 2 + case ComparatorTrue, ComparatorFalse: + return 0 + default: + return 1 } - return 1 } // Predicate is a criteria for querying a field @@ -86,12 +95,21 @@ func (t predicateToken) ToPredicate(paramIndex int) Predicate { if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex} } + if len(t) > 2 && t[len(t)-2] == "Not" && t[len(t)-1] == "In" { + return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorNotIn, ParamIndex: paramIndex} + } if len(t) > 1 && t[len(t)-1] == "In" { return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex} } if len(t) > 1 && t[len(t)-1] == "Between" { return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex} } + if len(t) > 1 && t[len(t)-1] == "True" { + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorTrue, ParamIndex: paramIndex} + } + if len(t) > 1 && t[len(t)-1] == "False" { + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorFalse, ParamIndex: paramIndex} + } return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex} }