Merge pull request #24 from sunboyy/update-inc
Add increment update operator
This commit is contained in:
commit
1368386584
12 changed files with 330 additions and 42 deletions
|
@ -88,6 +88,11 @@ func (intf InterfaceType) Code() string {
|
||||||
return `interface{}`
|
return `interface{}`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsNumber returns false
|
||||||
|
func (intf InterfaceType) IsNumber() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Method is a definition of the method inside the interface
|
// Method is a definition of the method inside the interface
|
||||||
type Method struct {
|
type Method struct {
|
||||||
Name string
|
Name string
|
||||||
|
@ -104,6 +109,7 @@ type Param struct {
|
||||||
// Type is an interface for value types
|
// Type is an interface for value types
|
||||||
type Type interface {
|
type Type interface {
|
||||||
Code() string
|
Code() string
|
||||||
|
IsNumber() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimpleType is a type that can be called directly
|
// SimpleType is a type that can be called directly
|
||||||
|
@ -114,6 +120,13 @@ func (t SimpleType) Code() string {
|
||||||
return string(t)
|
return string(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsNumber returns true id a SimpleType is integer or float variants.
|
||||||
|
func (t SimpleType) IsNumber() bool {
|
||||||
|
return t == "uint" || t == "uint8" || t == "uint16" || t == "uint32" || t == "uint64" ||
|
||||||
|
t == "int" || t == "int8" || t == "int16" || t == "int32" || t == "int64" ||
|
||||||
|
t == "float32" || t == "float64"
|
||||||
|
}
|
||||||
|
|
||||||
// ExternalType is a type that is called to another package
|
// ExternalType is a type that is called to another package
|
||||||
type ExternalType struct {
|
type ExternalType struct {
|
||||||
PackageAlias string
|
PackageAlias string
|
||||||
|
@ -125,6 +138,11 @@ func (t ExternalType) Code() string {
|
||||||
return fmt.Sprintf("%s.%s", t.PackageAlias, t.Name)
|
return fmt.Sprintf("%s.%s", t.PackageAlias, t.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsNumber returns false
|
||||||
|
func (t ExternalType) IsNumber() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// PointerType is a model of pointer
|
// PointerType is a model of pointer
|
||||||
type PointerType struct {
|
type PointerType struct {
|
||||||
ContainedType Type
|
ContainedType Type
|
||||||
|
@ -135,6 +153,11 @@ func (t PointerType) Code() string {
|
||||||
return fmt.Sprintf("*%s", t.ContainedType.Code())
|
return fmt.Sprintf("*%s", t.ContainedType.Code())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsNumber returns IsNumber of its contained type
|
||||||
|
func (t PointerType) IsNumber() bool {
|
||||||
|
return t.ContainedType.IsNumber()
|
||||||
|
}
|
||||||
|
|
||||||
// ArrayType is a model of array
|
// ArrayType is a model of array
|
||||||
type ArrayType struct {
|
type ArrayType struct {
|
||||||
ContainedType Type
|
ContainedType Type
|
||||||
|
@ -144,3 +167,8 @@ type ArrayType struct {
|
||||||
func (t ArrayType) Code() string {
|
func (t ArrayType) Code() string {
|
||||||
return fmt.Sprintf("[]%s", t.ContainedType.Code())
|
return fmt.Sprintf("[]%s", t.ContainedType.Code())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsNumber returns false
|
||||||
|
func (t ArrayType) IsNumber() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -92,7 +92,7 @@ type TypeCodeTestCase struct {
|
||||||
ExpectedCode string
|
ExpectedCode string
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestArrayTypeCode(t *testing.T) {
|
func TestTypeCode(t *testing.T) {
|
||||||
testTable := []TypeCodeTestCase{
|
testTable := []TypeCodeTestCase{
|
||||||
{
|
{
|
||||||
Name: "simple type",
|
Name: "simple type",
|
||||||
|
@ -126,3 +126,89 @@ func TestArrayTypeCode(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TypeIsNumberTestCase struct {
|
||||||
|
Name string
|
||||||
|
Type code.Type
|
||||||
|
IsNumber bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeIsNumber(t *testing.T) {
|
||||||
|
testTable := []TypeIsNumberTestCase{
|
||||||
|
{
|
||||||
|
Name: "simple type: int",
|
||||||
|
Type: code.SimpleType("int"),
|
||||||
|
IsNumber: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple type: other integer variants",
|
||||||
|
Type: code.SimpleType("int64"),
|
||||||
|
IsNumber: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple type: uint",
|
||||||
|
Type: code.SimpleType("uint"),
|
||||||
|
IsNumber: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple type: other unsigned integer variants",
|
||||||
|
Type: code.SimpleType("uint64"),
|
||||||
|
IsNumber: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple type: float32",
|
||||||
|
Type: code.SimpleType("float32"),
|
||||||
|
IsNumber: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple type: other float variants",
|
||||||
|
Type: code.SimpleType("float64"),
|
||||||
|
IsNumber: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple type: non-number primitive type",
|
||||||
|
Type: code.SimpleType("string"),
|
||||||
|
IsNumber: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple type: non-number custom type",
|
||||||
|
Type: code.SimpleType("UserModel"),
|
||||||
|
IsNumber: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "external type",
|
||||||
|
Type: code.ExternalType{PackageAlias: "context", Name: "Context"},
|
||||||
|
IsNumber: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "pointer type: number",
|
||||||
|
Type: code.PointerType{ContainedType: code.SimpleType("int")},
|
||||||
|
IsNumber: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "pointer type: non-number",
|
||||||
|
Type: code.PointerType{ContainedType: code.SimpleType("string")},
|
||||||
|
IsNumber: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "array type",
|
||||||
|
Type: code.ArrayType{ContainedType: code.SimpleType("int")},
|
||||||
|
IsNumber: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "interface type",
|
||||||
|
Type: code.InterfaceType{},
|
||||||
|
IsNumber: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testTable {
|
||||||
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
|
isNumber := testCase.Type.IsNumber()
|
||||||
|
|
||||||
|
if isNumber != testCase.IsNumber {
|
||||||
|
t.Errorf("Expected = %+v\nReceived = %+v", testCase.IsNumber, isNumber)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -122,8 +122,8 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera
|
||||||
return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData)
|
return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]sort, error) {
|
func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]findSort, error) {
|
||||||
var sorts []sort
|
var sorts []findSort
|
||||||
|
|
||||||
for _, s := range sortSpec {
|
for _, s := range sortSpec {
|
||||||
bsonFieldReference, err := g.bsonFieldReference(s.FieldReference)
|
bsonFieldReference, err := g.bsonFieldReference(s.FieldReference)
|
||||||
|
@ -131,7 +131,7 @@ func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]sort, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sorts = append(sorts, sort{
|
sorts = append(sorts, findSort{
|
||||||
BsonTag: bsonFieldReference,
|
BsonTag: bsonFieldReference,
|
||||||
Ordering: s.Ordering,
|
Ordering: s.Ordering,
|
||||||
})
|
})
|
||||||
|
@ -196,6 +196,8 @@ func getUpdateOperatorKey(operator spec.UpdateOperator) string {
|
||||||
return "$set"
|
return "$set"
|
||||||
case spec.UpdateOperatorPush:
|
case spec.UpdateOperatorPush:
|
||||||
return "$push"
|
return "$push"
|
||||||
|
case spec.UpdateOperatorInc:
|
||||||
|
return "$inc"
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -1063,7 +1063,7 @@ func (r *UserRepositoryMongo) UpdateAgeByGender(arg0 context.Context, arg1 int,
|
||||||
Params: []code.Param{
|
Params: []code.Param{
|
||||||
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
{Name: "consentHistory", Type: code.SimpleType("ConsentHistory")},
|
{Name: "consentHistory", Type: code.SimpleType("ConsentHistory")},
|
||||||
{Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||||
},
|
},
|
||||||
Returns: []code.Type{
|
Returns: []code.Type{
|
||||||
code.SimpleType("bool"),
|
code.SimpleType("bool"),
|
||||||
|
@ -1095,6 +1095,47 @@ func (r *UserRepositoryMongo) UpdateConsentHistoryPushByID(arg0 context.Context,
|
||||||
}
|
}
|
||||||
return result.MatchedCount > 0, err
|
return result.MatchedCount > 0, err
|
||||||
}
|
}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "simple update inc method",
|
||||||
|
MethodSpec: spec.MethodSpec{
|
||||||
|
Name: "UpdateAgeIncByID",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Name: "age", Type: code.SimpleType("int")},
|
||||||
|
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.SimpleType("bool"),
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
Operation: spec.UpdateOperation{
|
||||||
|
Update: spec.UpdateFields{
|
||||||
|
{FieldReference: spec.FieldReference{ageField}, ParamIndex: 1, Operator: spec.UpdateOperatorInc},
|
||||||
|
},
|
||||||
|
Mode: spec.QueryModeOne,
|
||||||
|
Query: spec.QuerySpec{
|
||||||
|
Predicates: []spec.Predicate{
|
||||||
|
{FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedCode: `
|
||||||
|
func (r *UserRepositoryMongo) UpdateAgeIncByID(arg0 context.Context, arg1 int, arg2 primitive.ObjectID) (bool, error) {
|
||||||
|
result, err := r.collection.UpdateOne(arg0, bson.M{
|
||||||
|
"_id": arg2,
|
||||||
|
}, bson.M{
|
||||||
|
"$inc": bson.M{
|
||||||
|
"age": arg1,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result.MatchedCount > 0, err
|
||||||
|
}
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1129,12 +1170,12 @@ func (r *UserRepositoryMongo) UpdateEnabledAndConsentHistoryPushByID(arg0 contex
|
||||||
result, err := r.collection.UpdateOne(arg0, bson.M{
|
result, err := r.collection.UpdateOne(arg0, bson.M{
|
||||||
"_id": arg3,
|
"_id": arg3,
|
||||||
}, bson.M{
|
}, bson.M{
|
||||||
"$set": bson.M{
|
|
||||||
"enabled": arg1,
|
|
||||||
},
|
|
||||||
"$push": bson.M{
|
"$push": bson.M{
|
||||||
"consent_history": arg2,
|
"consent_history": arg2,
|
||||||
},
|
},
|
||||||
|
"$set": bson.M{
|
||||||
|
"enabled": arg1,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|
|
@ -2,6 +2,7 @@ package mongo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sunboyy/repogen/internal/spec"
|
"github.com/sunboyy/repogen/internal/spec"
|
||||||
|
@ -26,11 +27,17 @@ func (u updateModel) Code() string {
|
||||||
type updateFields map[string][]updateField
|
type updateFields map[string][]updateField
|
||||||
|
|
||||||
func (u updateFields) Code() string {
|
func (u updateFields) Code() string {
|
||||||
var lines []string
|
var keys []string
|
||||||
for k, v := range u {
|
for k := range u {
|
||||||
lines = append(lines, fmt.Sprintf(` "%s": bson.M{`, k))
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
for _, field := range v {
|
var lines []string
|
||||||
|
for _, key := range keys {
|
||||||
|
lines = append(lines, fmt.Sprintf(` "%s": bson.M{`, key))
|
||||||
|
|
||||||
|
for _, field := range u[key] {
|
||||||
lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex))
|
lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,15 +91,15 @@ const insertManyTemplate = ` var entities []interface{}
|
||||||
type mongoFindTemplateData struct {
|
type mongoFindTemplateData struct {
|
||||||
EntityType string
|
EntityType string
|
||||||
QuerySpec querySpec
|
QuerySpec querySpec
|
||||||
Sorts []sort
|
Sorts []findSort
|
||||||
}
|
}
|
||||||
|
|
||||||
type sort struct {
|
type findSort struct {
|
||||||
BsonTag string
|
BsonTag string
|
||||||
Ordering spec.Ordering
|
Ordering spec.Ordering
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sort) OrderNum() int {
|
func (s findSort) OrderNum() int {
|
||||||
if s.Ordering == spec.OrderingAscending {
|
if s.Ordering == spec.OrderingAscending {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,8 +22,6 @@ func (err ParsingError) Error() string {
|
||||||
return "update fields are invalid"
|
return "update fields are invalid"
|
||||||
case ContextParamRequiredError:
|
case ContextParamRequiredError:
|
||||||
return "context parameter is required"
|
return "context parameter is required"
|
||||||
case PushNonArrayError:
|
|
||||||
return "cannot use push operation in a non-array type"
|
|
||||||
}
|
}
|
||||||
return string(err)
|
return string(err)
|
||||||
}
|
}
|
||||||
|
@ -35,7 +33,6 @@ const (
|
||||||
InvalidParamError ParsingError = "ERROR_INVALID_PARAM"
|
InvalidParamError ParsingError = "ERROR_INVALID_PARAM"
|
||||||
InvalidUpdateFieldsError ParsingError = "ERROR_INVALID_UPDATE_FIELDS"
|
InvalidUpdateFieldsError ParsingError = "ERROR_INVALID_UPDATE_FIELDS"
|
||||||
ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED"
|
ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED"
|
||||||
PushNonArrayError ParsingError = "ERROR_PUSH_NON_ARRAY"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewInvalidQueryError creates invalidQueryError
|
// NewInvalidQueryError creates invalidQueryError
|
||||||
|
@ -64,6 +61,26 @@ func (err invalidSortError) Error() string {
|
||||||
return fmt.Sprintf("invalid sort '%s'", err.SortString)
|
return fmt.Sprintf("invalid sort '%s'", err.SortString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewArgumentTypeNotMatchedError creates argumentTypeNotMatchedError
|
||||||
|
func NewArgumentTypeNotMatchedError(fieldName string, requiredType code.Type, givenType code.Type) error {
|
||||||
|
return argumentTypeNotMatchedError{
|
||||||
|
FieldName: fieldName,
|
||||||
|
RequiredType: requiredType,
|
||||||
|
GivenType: givenType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type argumentTypeNotMatchedError struct {
|
||||||
|
FieldName string
|
||||||
|
RequiredType code.Type
|
||||||
|
GivenType code.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err argumentTypeNotMatchedError) Error() string {
|
||||||
|
return fmt.Sprintf("field '%s' requires an argument of type '%s' (got '%s')",
|
||||||
|
err.FieldName, err.RequiredType.Code(), err.GivenType.Code())
|
||||||
|
}
|
||||||
|
|
||||||
// NewUnknownOperationError creates unknownOperationError
|
// NewUnknownOperationError creates unknownOperationError
|
||||||
func NewUnknownOperationError(operationName string) error {
|
func NewUnknownOperationError(operationName string) error {
|
||||||
return unknownOperationError{OperationName: operationName}
|
return unknownOperationError{OperationName: operationName}
|
||||||
|
@ -107,3 +124,23 @@ func (err incompatibleComparatorError) Error() string {
|
||||||
return fmt.Sprintf("cannot use comparator %s with struct field '%s' of type '%s'",
|
return fmt.Sprintf("cannot use comparator %s with struct field '%s' of type '%s'",
|
||||||
err.Comparator, err.Field.Name, err.Field.Type.Code())
|
err.Comparator, err.Field.Name, err.Field.Type.Code())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewIncompatibleUpdateOperatorError creates incompatibleUpdateOperatorError
|
||||||
|
func NewIncompatibleUpdateOperatorError(updateOperator UpdateOperator, fieldReference FieldReference) error {
|
||||||
|
return incompatibleUpdateOperatorError{
|
||||||
|
UpdateOperator: updateOperator,
|
||||||
|
ReferencingCode: fieldReference.ReferencingCode(),
|
||||||
|
ReferencedType: fieldReference.ReferencedField().Type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type incompatibleUpdateOperatorError struct {
|
||||||
|
UpdateOperator UpdateOperator
|
||||||
|
ReferencingCode string
|
||||||
|
ReferencedType code.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err incompatibleUpdateOperatorError) Error() string {
|
||||||
|
return fmt.Sprintf("cannot use update operator %s with struct field '%s' of type '%s'",
|
||||||
|
err.UpdateOperator, err.ReferencingCode, err.ReferencedType.Code())
|
||||||
|
}
|
||||||
|
|
|
@ -43,6 +43,19 @@ func TestError(t *testing.T) {
|
||||||
Error: spec.NewInvalidSortError([]string{"Order", "By"}),
|
Error: spec.NewInvalidSortError([]string{"Order", "By"}),
|
||||||
ExpectedString: "invalid sort 'OrderBy'",
|
ExpectedString: "invalid sort 'OrderBy'",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "ArgumentTypeNotMatchedError",
|
||||||
|
Error: spec.NewArgumentTypeNotMatchedError("Age", code.SimpleType("int"), code.SimpleType("float64")),
|
||||||
|
ExpectedString: "field 'Age' requires an argument of type 'int' (got 'float64')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "IncompatibleUpdateOperatorError",
|
||||||
|
Error: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{{
|
||||||
|
Name: "City",
|
||||||
|
Type: code.SimpleType("string"),
|
||||||
|
}}),
|
||||||
|
ExpectedString: "cannot use update operator INC with struct field 'City' of type 'string'",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range testTable {
|
for _, testCase := range testTable {
|
||||||
|
|
|
@ -14,6 +14,15 @@ func (r FieldReference) ReferencedField() code.StructField {
|
||||||
return r[len(r)-1]
|
return r[len(r)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReferencingCode returns a string containing name of the referenced fields concatenating with period (.).
|
||||||
|
func (r FieldReference) ReferencingCode() string {
|
||||||
|
var fieldNames []string
|
||||||
|
for _, field := range r {
|
||||||
|
fieldNames = append(fieldNames, field.Name)
|
||||||
|
}
|
||||||
|
return strings.Join(fieldNames, ".")
|
||||||
|
}
|
||||||
|
|
||||||
type fieldResolver struct {
|
type fieldResolver struct {
|
||||||
Structs code.Structs
|
Structs code.Structs
|
||||||
}
|
}
|
||||||
|
|
|
@ -381,9 +381,11 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ {
|
for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ {
|
||||||
if params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType(
|
requiredType := predicate.Comparator.ArgumentTypeFromFieldType(predicate.FieldReference.ReferencedField().Type)
|
||||||
predicate.FieldReference.ReferencedField().Type) {
|
|
||||||
return InvalidParamError
|
if params[currentParamIndex].Type != requiredType {
|
||||||
|
return NewArgumentTypeNotMatchedError(predicate.FieldReference.ReferencingCode(), requiredType,
|
||||||
|
params[currentParamIndex].Type)
|
||||||
}
|
}
|
||||||
currentParamIndex++
|
currentParamIndex++
|
||||||
}
|
}
|
||||||
|
|
|
@ -811,6 +811,30 @@ func TestParseInterfaceMethod_Update(t *testing.T) {
|
||||||
}},
|
}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "UpdateArgPushByArg method",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "UpdateAgeIncByID",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Type: code.SimpleType("int")},
|
||||||
|
{Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.SimpleType("bool"),
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedOperation: spec.UpdateOperation{
|
||||||
|
Update: spec.UpdateFields{
|
||||||
|
{FieldReference: spec.FieldReference{ageField}, ParamIndex: 1, Operator: spec.UpdateOperatorInc},
|
||||||
|
},
|
||||||
|
Mode: spec.QueryModeOne,
|
||||||
|
Query: spec.QuerySpec{Predicates: []spec.Predicate{
|
||||||
|
{FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "UpdateArgAndArgPushByArg method",
|
Name: "UpdateArgAndArgPushByArg method",
|
||||||
Method: code.Method{
|
Method: code.Method{
|
||||||
|
@ -1564,7 +1588,7 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidParamError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError(genderField.Name, genderField.Type, code.SimpleType("string")),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "mismatched method parameter type for special case",
|
Name: "mismatched method parameter type for special case",
|
||||||
|
@ -1579,7 +1603,8 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidParamError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError(cityField.Name,
|
||||||
|
code.ArrayType{ContainedType: code.SimpleType("string")}, code.SimpleType("string")),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "misplaced operator token (leftmost)",
|
Name: "misplaced operator token (leftmost)",
|
||||||
|
@ -1716,7 +1741,29 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.PushNonArrayError,
|
ExpectedError: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorPush, spec.FieldReference{{
|
||||||
|
Name: "Gender",
|
||||||
|
Type: code.SimpleType("Gender"),
|
||||||
|
}}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "inc operator in non-number field",
|
||||||
|
Method: code.Method{
|
||||||
|
Name: "UpdateCityIncByID",
|
||||||
|
Params: []code.Param{
|
||||||
|
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
|
||||||
|
{Type: code.SimpleType("Gender")},
|
||||||
|
{Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||||
|
},
|
||||||
|
Returns: []code.Type{
|
||||||
|
code.SimpleType("bool"),
|
||||||
|
code.SimpleType("error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedError: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{{
|
||||||
|
Name: "City",
|
||||||
|
Type: code.SimpleType("string"),
|
||||||
|
}}),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "update method without query",
|
Name: "update method without query",
|
||||||
|
@ -1762,7 +1809,8 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidUpdateFieldsError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError(consentHistoryField.Name, code.SimpleType("ConsentHistoryItem"),
|
||||||
|
code.ArrayType{ContainedType: code.SimpleType("ConsentHistoryItem")}),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "insufficient function parameters",
|
Name: "insufficient function parameters",
|
||||||
|
@ -1839,7 +1887,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidUpdateFieldsError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError(ageField.Name, ageField.Type, code.SimpleType("float64")),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "struct field does not match parameter in query",
|
Name: "struct field does not match parameter in query",
|
||||||
|
@ -1855,7 +1903,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidParamError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError(genderField.Name, genderField.Type, code.SimpleType("string")),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2019,7 +2067,7 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidParamError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError("Gender", code.SimpleType("Gender"), code.SimpleType("string")),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "mismatched method parameter type for special case",
|
Name: "mismatched method parameter type for special case",
|
||||||
|
@ -2034,7 +2082,8 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidParamError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError("City",
|
||||||
|
code.ArrayType{ContainedType: code.SimpleType("string")}, code.SimpleType("string")),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2150,7 +2199,7 @@ func TestParseInterfaceMethod_Count_Invalid(t *testing.T) {
|
||||||
code.SimpleType("error"),
|
code.SimpleType("error"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ExpectedError: spec.InvalidParamError,
|
ExpectedError: spec.NewArgumentTypeNotMatchedError("Gender", code.SimpleType("Gender"), code.SimpleType("string")),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "struct field not found",
|
Name: "struct field not found",
|
||||||
|
|
|
@ -61,6 +61,7 @@ type UpdateOperator string
|
||||||
const (
|
const (
|
||||||
UpdateOperatorSet UpdateOperator = "SET"
|
UpdateOperatorSet UpdateOperator = "SET"
|
||||||
UpdateOperatorPush UpdateOperator = "PUSH"
|
UpdateOperatorPush UpdateOperator = "PUSH"
|
||||||
|
UpdateOperatorInc UpdateOperator = "INC"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NumberOfArguments returns number of arguments required to perform an update operation
|
// NumberOfArguments returns number of arguments required to perform an update operation
|
||||||
|
@ -69,16 +70,13 @@ func (o UpdateOperator) NumberOfArguments() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ArgumentType returns type that is required for function parameter
|
// ArgumentType returns type that is required for function parameter
|
||||||
func (o UpdateOperator) ArgumentType(fieldType code.Type) (code.Type, error) {
|
func (o UpdateOperator) ArgumentType(fieldType code.Type) code.Type {
|
||||||
switch o {
|
switch o {
|
||||||
case UpdateOperatorPush:
|
case UpdateOperatorPush:
|
||||||
arrayType, ok := fieldType.(code.ArrayType)
|
arrayType := fieldType.(code.ArrayType)
|
||||||
if !ok {
|
return arrayType.ContainedType
|
||||||
return nil, PushNonArrayError
|
|
||||||
}
|
|
||||||
return arrayType.ContainedType, nil
|
|
||||||
default:
|
default:
|
||||||
return fieldType, nil
|
return fieldType
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,14 +145,12 @@ func (p interfaceMethodParser) parseUpdate(tokens []string) (Update, error) {
|
||||||
return nil, InvalidUpdateFieldsError
|
return nil, InvalidUpdateFieldsError
|
||||||
}
|
}
|
||||||
|
|
||||||
requiredType, err := field.Operator.ArgumentType(field.FieldReference.ReferencedField().Type)
|
requiredType := field.Operator.ArgumentType(field.FieldReference.ReferencedField().Type)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < field.Operator.NumberOfArguments(); i++ {
|
for i := 0; i < field.Operator.NumberOfArguments(); i++ {
|
||||||
if requiredType != p.Method.Params[field.ParamIndex+i].Type {
|
if requiredType != p.Method.Params[field.ParamIndex+i].Type {
|
||||||
return nil, InvalidUpdateFieldsError
|
return nil, NewArgumentTypeNotMatchedError(field.FieldReference.ReferencingCode(), requiredType,
|
||||||
|
p.Method.Params[field.ParamIndex+i].Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -166,6 +162,9 @@ func (p interfaceMethodParser) parseUpdateField(t []string, paramIndex int) (Upd
|
||||||
if len(t) > 1 && t[len(t)-1] == "Push" {
|
if len(t) > 1 && t[len(t)-1] == "Push" {
|
||||||
return p.createUpdateField(t[:len(t)-1], UpdateOperatorPush, paramIndex)
|
return p.createUpdateField(t[:len(t)-1], UpdateOperatorPush, paramIndex)
|
||||||
}
|
}
|
||||||
|
if len(t) > 1 && t[len(t)-1] == "Inc" {
|
||||||
|
return p.createUpdateField(t[:len(t)-1], UpdateOperatorInc, paramIndex)
|
||||||
|
}
|
||||||
return p.createUpdateField(t, UpdateOperatorSet, paramIndex)
|
return p.createUpdateField(t, UpdateOperatorSet, paramIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,9 +174,24 @@ func (p interfaceMethodParser) createUpdateField(t []string, operator UpdateOper
|
||||||
return UpdateField{}, NewStructFieldNotFoundError(t)
|
return UpdateField{}, NewStructFieldNotFoundError(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !p.validateUpdateOperator(fieldReference.ReferencedField().Type, operator) {
|
||||||
|
return UpdateField{}, NewIncompatibleUpdateOperatorError(operator, fieldReference)
|
||||||
|
}
|
||||||
|
|
||||||
return UpdateField{
|
return UpdateField{
|
||||||
FieldReference: fieldReference,
|
FieldReference: fieldReference,
|
||||||
ParamIndex: paramIndex,
|
ParamIndex: paramIndex,
|
||||||
Operator: operator,
|
Operator: operator,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p interfaceMethodParser) validateUpdateOperator(referencedType code.Type, operator UpdateOperator) bool {
|
||||||
|
switch operator {
|
||||||
|
case UpdateOperatorPush:
|
||||||
|
_, ok := referencedType.(code.ArrayType)
|
||||||
|
return ok
|
||||||
|
case UpdateOperatorInc:
|
||||||
|
return referencedType.IsNumber()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue