diff --git a/internal/code/models.go b/internal/code/models.go index b321456..094efa0 100644 --- a/internal/code/models.go +++ b/internal/code/models.go @@ -88,6 +88,11 @@ func (intf InterfaceType) Code() string { return `interface{}` } +// IsNumber returns false +func (intf InterfaceType) IsNumber() bool { + return false +} + // Method is a definition of the method inside the interface type Method struct { Name string @@ -104,6 +109,7 @@ type Param struct { // Type is an interface for value types type Type interface { Code() string + IsNumber() bool } // SimpleType is a type that can be called directly @@ -114,6 +120,13 @@ func (t SimpleType) Code() string { return string(t) } +// IsNumber returns true id a SimpleType is integer or float variants. +func (t SimpleType) IsNumber() bool { + return t == "uint" || t == "uint8" || t == "uint16" || t == "uint32" || t == "uint64" || + t == "int" || t == "int8" || t == "int16" || t == "int32" || t == "int64" || + t == "float32" || t == "float64" +} + // ExternalType is a type that is called to another package type ExternalType struct { PackageAlias string @@ -125,6 +138,11 @@ func (t ExternalType) Code() string { return fmt.Sprintf("%s.%s", t.PackageAlias, t.Name) } +// IsNumber returns false +func (t ExternalType) IsNumber() bool { + return false +} + // PointerType is a model of pointer type PointerType struct { ContainedType Type @@ -135,6 +153,11 @@ func (t PointerType) Code() string { return fmt.Sprintf("*%s", t.ContainedType.Code()) } +// IsNumber returns IsNumber of its contained type +func (t PointerType) IsNumber() bool { + return t.ContainedType.IsNumber() +} + // ArrayType is a model of array type ArrayType struct { ContainedType Type @@ -144,3 +167,8 @@ type ArrayType struct { func (t ArrayType) Code() string { return fmt.Sprintf("[]%s", t.ContainedType.Code()) } + +// IsNumber returns false +func (t ArrayType) IsNumber() bool { + return false +} diff --git a/internal/code/models_test.go b/internal/code/models_test.go index 8ad7ab8..66b6f54 100644 --- a/internal/code/models_test.go +++ b/internal/code/models_test.go @@ -92,7 +92,7 @@ type TypeCodeTestCase struct { ExpectedCode string } -func TestArrayTypeCode(t *testing.T) { +func TestTypeCode(t *testing.T) { testTable := []TypeCodeTestCase{ { Name: "simple type", @@ -126,3 +126,89 @@ func TestArrayTypeCode(t *testing.T) { }) } } + +type TypeIsNumberTestCase struct { + Name string + Type code.Type + IsNumber bool +} + +func TestTypeIsNumber(t *testing.T) { + testTable := []TypeIsNumberTestCase{ + { + Name: "simple type: int", + Type: code.SimpleType("int"), + IsNumber: true, + }, + { + Name: "simple type: other integer variants", + Type: code.SimpleType("int64"), + IsNumber: true, + }, + { + Name: "simple type: uint", + Type: code.SimpleType("uint"), + IsNumber: true, + }, + { + Name: "simple type: other unsigned integer variants", + Type: code.SimpleType("uint64"), + IsNumber: true, + }, + { + Name: "simple type: float32", + Type: code.SimpleType("float32"), + IsNumber: true, + }, + { + Name: "simple type: other float variants", + Type: code.SimpleType("float64"), + IsNumber: true, + }, + { + Name: "simple type: non-number primitive type", + Type: code.SimpleType("string"), + IsNumber: false, + }, + { + Name: "simple type: non-number custom type", + Type: code.SimpleType("UserModel"), + IsNumber: false, + }, + { + Name: "external type", + Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, + IsNumber: false, + }, + { + Name: "pointer type: number", + Type: code.PointerType{ContainedType: code.SimpleType("int")}, + IsNumber: true, + }, + { + Name: "pointer type: non-number", + Type: code.PointerType{ContainedType: code.SimpleType("string")}, + IsNumber: false, + }, + { + Name: "array type", + Type: code.ArrayType{ContainedType: code.SimpleType("int")}, + IsNumber: false, + }, + { + Name: "interface type", + Type: code.InterfaceType{}, + IsNumber: false, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + isNumber := testCase.Type.IsNumber() + + if isNumber != testCase.IsNumber { + t.Errorf("Expected = %+v\nReceived = %+v", testCase.IsNumber, isNumber) + } + }) + } +} diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index d128197..bde2b5a 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -122,8 +122,8 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData) } -func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]sort, error) { - var sorts []sort +func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]findSort, error) { + var sorts []findSort for _, s := range sortSpec { bsonFieldReference, err := g.bsonFieldReference(s.FieldReference) @@ -131,7 +131,7 @@ func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]sort, error) { return nil, err } - sorts = append(sorts, sort{ + sorts = append(sorts, findSort{ BsonTag: bsonFieldReference, Ordering: s.Ordering, }) @@ -196,6 +196,8 @@ func getUpdateOperatorKey(operator spec.UpdateOperator) string { return "$set" case spec.UpdateOperatorPush: return "$push" + case spec.UpdateOperatorInc: + return "$inc" default: return "" } diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index b72d7ae..d35a57e 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -1063,7 +1063,7 @@ func (r *UserRepositoryMongo) UpdateAgeByGender(arg0 context.Context, arg1 int, Params: []code.Param{ {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, - {Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, }, Returns: []code.Type{ code.SimpleType("bool"), @@ -1095,6 +1095,47 @@ func (r *UserRepositoryMongo) UpdateConsentHistoryPushByID(arg0 context.Context, } return result.MatchedCount > 0, err } +`, + }, + { + Name: "simple update inc method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateAgeIncByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.SimpleType("int")}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + {FieldReference: spec.FieldReference{ageField}, ParamIndex: 1, Operator: spec.UpdateOperatorInc}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) UpdateAgeIncByID(arg0 context.Context, arg1 int, arg2 primitive.ObjectID) (bool, error) { + result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$inc": bson.M{ + "age": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, err +} `, }, { @@ -1129,12 +1170,12 @@ func (r *UserRepositoryMongo) UpdateEnabledAndConsentHistoryPushByID(arg0 contex result, err := r.collection.UpdateOne(arg0, bson.M{ "_id": arg3, }, bson.M{ - "$set": bson.M{ - "enabled": arg1, - }, "$push": bson.M{ "consent_history": arg2, }, + "$set": bson.M{ + "enabled": arg1, + }, }) if err != nil { return false, err diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 807b393..1f12a58 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -2,6 +2,7 @@ package mongo import ( "fmt" + "sort" "strings" "github.com/sunboyy/repogen/internal/spec" @@ -26,11 +27,17 @@ func (u updateModel) Code() string { type updateFields map[string][]updateField func (u updateFields) Code() string { - var lines []string - for k, v := range u { - lines = append(lines, fmt.Sprintf(` "%s": bson.M{`, k)) + var keys []string + for k := range u { + keys = append(keys, k) + } + sort.Strings(keys) - for _, field := range v { + var lines []string + for _, key := range keys { + lines = append(lines, fmt.Sprintf(` "%s": bson.M{`, key)) + + for _, field := range u[key] { lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex)) } diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go index a00feb8..da3112e 100644 --- a/internal/mongo/templates.go +++ b/internal/mongo/templates.go @@ -91,15 +91,15 @@ const insertManyTemplate = ` var entities []interface{} type mongoFindTemplateData struct { EntityType string QuerySpec querySpec - Sorts []sort + Sorts []findSort } -type sort struct { +type findSort struct { BsonTag string Ordering spec.Ordering } -func (s sort) OrderNum() int { +func (s findSort) OrderNum() int { if s.Ordering == spec.OrderingAscending { return 1 } diff --git a/internal/spec/errors.go b/internal/spec/errors.go index 40f987f..7580991 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -22,8 +22,6 @@ func (err ParsingError) Error() string { return "update fields are invalid" case ContextParamRequiredError: return "context parameter is required" - case PushNonArrayError: - return "cannot use push operation in a non-array type" } return string(err) } @@ -35,7 +33,6 @@ const ( InvalidParamError ParsingError = "ERROR_INVALID_PARAM" InvalidUpdateFieldsError ParsingError = "ERROR_INVALID_UPDATE_FIELDS" ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED" - PushNonArrayError ParsingError = "ERROR_PUSH_NON_ARRAY" ) // NewInvalidQueryError creates invalidQueryError @@ -64,6 +61,26 @@ func (err invalidSortError) Error() string { return fmt.Sprintf("invalid sort '%s'", err.SortString) } +// NewArgumentTypeNotMatchedError creates argumentTypeNotMatchedError +func NewArgumentTypeNotMatchedError(fieldName string, requiredType code.Type, givenType code.Type) error { + return argumentTypeNotMatchedError{ + FieldName: fieldName, + RequiredType: requiredType, + GivenType: givenType, + } +} + +type argumentTypeNotMatchedError struct { + FieldName string + RequiredType code.Type + GivenType code.Type +} + +func (err argumentTypeNotMatchedError) Error() string { + return fmt.Sprintf("field '%s' requires an argument of type '%s' (got '%s')", + err.FieldName, err.RequiredType.Code(), err.GivenType.Code()) +} + // NewUnknownOperationError creates unknownOperationError func NewUnknownOperationError(operationName string) error { return unknownOperationError{OperationName: operationName} @@ -107,3 +124,23 @@ func (err incompatibleComparatorError) Error() string { return fmt.Sprintf("cannot use comparator %s with struct field '%s' of type '%s'", err.Comparator, err.Field.Name, err.Field.Type.Code()) } + +// NewIncompatibleUpdateOperatorError creates incompatibleUpdateOperatorError +func NewIncompatibleUpdateOperatorError(updateOperator UpdateOperator, fieldReference FieldReference) error { + return incompatibleUpdateOperatorError{ + UpdateOperator: updateOperator, + ReferencingCode: fieldReference.ReferencingCode(), + ReferencedType: fieldReference.ReferencedField().Type, + } +} + +type incompatibleUpdateOperatorError struct { + UpdateOperator UpdateOperator + ReferencingCode string + ReferencedType code.Type +} + +func (err incompatibleUpdateOperatorError) Error() string { + return fmt.Sprintf("cannot use update operator %s with struct field '%s' of type '%s'", + err.UpdateOperator, err.ReferencingCode, err.ReferencedType.Code()) +} diff --git a/internal/spec/errors_test.go b/internal/spec/errors_test.go index 0fe35a4..8d7423f 100644 --- a/internal/spec/errors_test.go +++ b/internal/spec/errors_test.go @@ -43,6 +43,19 @@ func TestError(t *testing.T) { Error: spec.NewInvalidSortError([]string{"Order", "By"}), ExpectedString: "invalid sort 'OrderBy'", }, + { + Name: "ArgumentTypeNotMatchedError", + Error: spec.NewArgumentTypeNotMatchedError("Age", code.SimpleType("int"), code.SimpleType("float64")), + ExpectedString: "field 'Age' requires an argument of type 'int' (got 'float64')", + }, + { + Name: "IncompatibleUpdateOperatorError", + Error: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{{ + Name: "City", + Type: code.SimpleType("string"), + }}), + ExpectedString: "cannot use update operator INC with struct field 'City' of type 'string'", + }, } for _, testCase := range testTable { diff --git a/internal/spec/field.go b/internal/spec/field.go index 4689371..2532cba 100644 --- a/internal/spec/field.go +++ b/internal/spec/field.go @@ -14,6 +14,15 @@ func (r FieldReference) ReferencedField() code.StructField { return r[len(r)-1] } +// ReferencingCode returns a string containing name of the referenced fields concatenating with period (.). +func (r FieldReference) ReferencingCode() string { + var fieldNames []string + for _, field := range r { + fieldNames = append(fieldNames, field.Name) + } + return strings.Join(fieldNames, ".") +} + type fieldResolver struct { Structs code.Structs } diff --git a/internal/spec/parser.go b/internal/spec/parser.go index 1d74b1e..45a3b6b 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -381,9 +381,11 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer } for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ { - if params[currentParamIndex].Type != predicate.Comparator.ArgumentTypeFromFieldType( - predicate.FieldReference.ReferencedField().Type) { - return InvalidParamError + requiredType := predicate.Comparator.ArgumentTypeFromFieldType(predicate.FieldReference.ReferencedField().Type) + + if params[currentParamIndex].Type != requiredType { + return NewArgumentTypeNotMatchedError(predicate.FieldReference.ReferencingCode(), requiredType, + params[currentParamIndex].Type) } currentParamIndex++ } diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index d97eb81..3bc84c2 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -811,6 +811,30 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }}, }, }, + { + Name: "UpdateArgPushByArg method", + Method: code.Method{ + Name: "UpdateAgeIncByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + {FieldReference: spec.FieldReference{ageField}, ParamIndex: 1, Operator: spec.UpdateOperatorInc}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }}, + }, + }, { Name: "UpdateArgAndArgPushByArg method", Method: code.Method{ @@ -1564,7 +1588,7 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidParamError, + ExpectedError: spec.NewArgumentTypeNotMatchedError(genderField.Name, genderField.Type, code.SimpleType("string")), }, { Name: "mismatched method parameter type for special case", @@ -1579,7 +1603,8 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidParamError, + ExpectedError: spec.NewArgumentTypeNotMatchedError(cityField.Name, + code.ArrayType{ContainedType: code.SimpleType("string")}, code.SimpleType("string")), }, { Name: "misplaced operator token (leftmost)", @@ -1716,7 +1741,29 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.PushNonArrayError, + ExpectedError: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorPush, spec.FieldReference{{ + Name: "Gender", + Type: code.SimpleType("Gender"), + }}), + }, + { + Name: "inc operator in non-number field", + Method: code.Method{ + Name: "UpdateCityIncByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{{ + Name: "City", + Type: code.SimpleType("string"), + }}), }, { Name: "update method without query", @@ -1762,7 +1809,8 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidUpdateFieldsError, + ExpectedError: spec.NewArgumentTypeNotMatchedError(consentHistoryField.Name, code.SimpleType("ConsentHistoryItem"), + code.ArrayType{ContainedType: code.SimpleType("ConsentHistoryItem")}), }, { Name: "insufficient function parameters", @@ -1839,7 +1887,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidUpdateFieldsError, + ExpectedError: spec.NewArgumentTypeNotMatchedError(ageField.Name, ageField.Type, code.SimpleType("float64")), }, { Name: "struct field does not match parameter in query", @@ -1855,7 +1903,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidParamError, + ExpectedError: spec.NewArgumentTypeNotMatchedError(genderField.Name, genderField.Type, code.SimpleType("string")), }, } @@ -2019,7 +2067,7 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidParamError, + ExpectedError: spec.NewArgumentTypeNotMatchedError("Gender", code.SimpleType("Gender"), code.SimpleType("string")), }, { Name: "mismatched method parameter type for special case", @@ -2034,7 +2082,8 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidParamError, + ExpectedError: spec.NewArgumentTypeNotMatchedError("City", + code.ArrayType{ContainedType: code.SimpleType("string")}, code.SimpleType("string")), }, } @@ -2150,7 +2199,7 @@ func TestParseInterfaceMethod_Count_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidParamError, + ExpectedError: spec.NewArgumentTypeNotMatchedError("Gender", code.SimpleType("Gender"), code.SimpleType("string")), }, { Name: "struct field not found", diff --git a/internal/spec/update.go b/internal/spec/update.go index ba9338d..245807a 100644 --- a/internal/spec/update.go +++ b/internal/spec/update.go @@ -61,6 +61,7 @@ type UpdateOperator string const ( UpdateOperatorSet UpdateOperator = "SET" UpdateOperatorPush UpdateOperator = "PUSH" + UpdateOperatorInc UpdateOperator = "INC" ) // NumberOfArguments returns number of arguments required to perform an update operation @@ -69,16 +70,13 @@ func (o UpdateOperator) NumberOfArguments() int { } // ArgumentType returns type that is required for function parameter -func (o UpdateOperator) ArgumentType(fieldType code.Type) (code.Type, error) { +func (o UpdateOperator) ArgumentType(fieldType code.Type) code.Type { switch o { case UpdateOperatorPush: - arrayType, ok := fieldType.(code.ArrayType) - if !ok { - return nil, PushNonArrayError - } - return arrayType.ContainedType, nil + arrayType := fieldType.(code.ArrayType) + return arrayType.ContainedType default: - return fieldType, nil + return fieldType } } @@ -147,14 +145,12 @@ func (p interfaceMethodParser) parseUpdate(tokens []string) (Update, error) { return nil, InvalidUpdateFieldsError } - requiredType, err := field.Operator.ArgumentType(field.FieldReference.ReferencedField().Type) - if err != nil { - return nil, err - } + requiredType := field.Operator.ArgumentType(field.FieldReference.ReferencedField().Type) for i := 0; i < field.Operator.NumberOfArguments(); i++ { if requiredType != p.Method.Params[field.ParamIndex+i].Type { - return nil, InvalidUpdateFieldsError + return nil, NewArgumentTypeNotMatchedError(field.FieldReference.ReferencingCode(), requiredType, + p.Method.Params[field.ParamIndex+i].Type) } } } @@ -166,6 +162,9 @@ func (p interfaceMethodParser) parseUpdateField(t []string, paramIndex int) (Upd if len(t) > 1 && t[len(t)-1] == "Push" { return p.createUpdateField(t[:len(t)-1], UpdateOperatorPush, paramIndex) } + if len(t) > 1 && t[len(t)-1] == "Inc" { + return p.createUpdateField(t[:len(t)-1], UpdateOperatorInc, paramIndex) + } return p.createUpdateField(t, UpdateOperatorSet, paramIndex) } @@ -175,9 +174,24 @@ func (p interfaceMethodParser) createUpdateField(t []string, operator UpdateOper return UpdateField{}, NewStructFieldNotFoundError(t) } + if !p.validateUpdateOperator(fieldReference.ReferencedField().Type, operator) { + return UpdateField{}, NewIncompatibleUpdateOperatorError(operator, fieldReference) + } + return UpdateField{ FieldReference: fieldReference, ParamIndex: paramIndex, Operator: operator, }, nil } + +func (p interfaceMethodParser) validateUpdateOperator(referencedType code.Type, operator UpdateOperator) bool { + switch operator { + case UpdateOperatorPush: + _, ok := referencedType.(code.ArrayType) + return ok + case UpdateOperatorInc: + return referencedType.IsNumber() + } + return true +}