From 295f457e64c95a2ef6e10ee5ae50c9bae6a611bc Mon Sep 17 00:00:00 2001 From: sunboyy Date: Wed, 31 Mar 2021 18:44:31 +0700 Subject: [PATCH] Add push update operator --- README.md | 16 ++-- internal/code/extractor_test.go | 2 +- internal/code/models_test.go | 8 +- internal/mongo/errors.go | 13 +++ internal/mongo/errors_test.go | 8 +- internal/mongo/generator.go | 24 +++++- internal/mongo/generator_test.go | 130 ++++++++++++++++++++++++++++-- internal/mongo/models.go | 17 ++-- internal/spec/errors.go | 5 +- internal/spec/errors_test.go | 2 +- internal/spec/models_test.go | 2 +- internal/spec/parser_test.go | 134 +++++++++++++++++++++++++++---- internal/spec/update.go | 76 +++++++++++++++--- internal/spec/update_test.go | 2 +- 14 files changed, 382 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index ea6d240..ab6cf96 100644 --- a/README.md +++ b/README.md @@ -200,13 +200,15 @@ Assuming that the `City` field in the `UserModel` struct is of type `string` and When you specify the query like `ByAge`, it finds documents that contains age value **equal to** the provided parameter value. However, there are other types of comparators provided for you to use as follows. -- `Not`: The value in the document is not equal to the provided parameter value. -- `LessThan`: The value in the document is less than the provided parameter value. -- `LessThanEqual`: The value in the document is less than or equal to the provided parameter value. -- `GreaterThan`: The value in the document is greater than the provided parameter value. -- `GreaterThanEqual`: The value in the document is greater than of equal to the provided parameter value. -- `Between`: The value in the document is between two provided parameter values (inclusive). -- `In`: The value in the document is in the provided parameter value (which is a slice). +| Keyword | Meaning | Sample | +|--------------------|-----------------|--------------------------------------| +| - | == $1 | `FindByUsername(ctx, $1)` | +| `LessThan` | < $1 | `FindByAgeLessThan(ctx, $1)` | +| `LessThanEqual` | <= $1 | `FindByAgeLessThanEqual(ctx, $1)` | +| `GreaterThan` | > $1 | `FindByAgeGreaterThan(ctx, $1)` | +| `GreaterThanEqual` | >= $1 | `FindByAgeGreaterThanEqual(ctx, $1)` | +| `Between` | >= $1 and <= $2 | `FindByAgeBetween(ctx, $1, $2)` | +| `In` | in slice $1 | `FindByCityIn(ctx, $1)` | To apply these comparators to the query, place these words after the field name such as `ByAgeGreaterThan`. You can also use comparators along with `And` and `Or` operators. For example, `ByGenderNotOrAgeLessThan` will apply `Not` comparator to the `Gender` field and `LessThan` comparator to the `Age` field. diff --git a/internal/code/extractor_test.go b/internal/code/extractor_test.go index 4a1a705..86ced4d 100644 --- a/internal/code/extractor_test.go +++ b/internal/code/extractor_test.go @@ -277,7 +277,7 @@ type UserRepository interface { file := code.ExtractComponents(f) if !reflect.DeepEqual(file, testCase.ExpectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedOutput, file) + t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedOutput, file) } }) } diff --git a/internal/code/models_test.go b/internal/code/models_test.go index f1ccb38..8ad7ab8 100644 --- a/internal/code/models_test.go +++ b/internal/code/models_test.go @@ -24,7 +24,7 @@ func TestStructsByName(t *testing.T) { t.Fail() } if !reflect.DeepEqual(structModel, userStruct) { - t.Errorf("Expected = %v\nReceived = %v", userStruct, structModel) + t.Errorf("Expected = %+v\nReceived = %+v", userStruct, structModel) } }) @@ -49,7 +49,7 @@ func TestStructFieldsByName(t *testing.T) { t.Fail() } if !reflect.DeepEqual(field, usernameField) { - t.Errorf("Expected = %v\nReceived = %v", usernameField, field) + t.Errorf("Expected = %+v\nReceived = %+v", usernameField, field) } }) @@ -73,7 +73,7 @@ func TestInterfacesByName(t *testing.T) { t.Fail() } if !reflect.DeepEqual(intf, userRepoIntf) { - t.Errorf("Expected = %v\nReceived = %v", userRepoIntf, intf) + t.Errorf("Expected = %+v\nReceived = %+v", userRepoIntf, intf) } }) @@ -121,7 +121,7 @@ func TestArrayTypeCode(t *testing.T) { code := testCase.Type.Code() if code != testCase.ExpectedCode { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedCode, code) + t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedCode, code) } }) } diff --git a/internal/mongo/errors.go b/internal/mongo/errors.go index cbfcd81..327f406 100644 --- a/internal/mongo/errors.go +++ b/internal/mongo/errors.go @@ -44,3 +44,16 @@ type updateTypeNotSupportedError struct { func (err updateTypeNotSupportedError) Error() string { return fmt.Sprintf("update type %s not supported", err.Update.Name()) } + +// NewUpdateOperatorNotSupportedError creates updateOperatorNotSupportedError +func NewUpdateOperatorNotSupportedError(operator spec.UpdateOperator) error { + return updateOperatorNotSupportedError{Operator: operator} +} + +type updateOperatorNotSupportedError struct { + Operator spec.UpdateOperator +} + +func (err updateOperatorNotSupportedError) Error() string { + return fmt.Sprintf("update operator %s not supported", err.Operator) +} diff --git a/internal/mongo/errors_test.go b/internal/mongo/errors_test.go index 7d60a23..8d90bd9 100644 --- a/internal/mongo/errors_test.go +++ b/internal/mongo/errors_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" ) type ErrorTestCase struct { @@ -40,12 +41,17 @@ func TestError(t *testing.T) { Error: mongo.NewUpdateTypeNotSupportedError(StubUpdate{}), ExpectedString: "update type Stub not supported", }, + { + Name: "UpdateOperatorNotSupportedError", + Error: mongo.NewUpdateOperatorNotSupportedError(spec.UpdateOperator("STUB")), + ExpectedString: "update operator STUB not supported", + }, } for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { if testCase.Error.Error() != testCase.ExpectedString { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedString, testCase.Error.Error()) + t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedString, testCase.Error.Error()) } }) } diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index 64c3d05..d128197 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -167,13 +167,22 @@ func (g RepositoryGenerator) getMongoUpdate(updateSpec spec.Update) (update, err case spec.UpdateModel: return updateModel{}, nil case spec.UpdateFields: - var update updateFields + update := make(updateFields) for _, field := range updateSpec { bsonFieldReference, err := g.bsonFieldReference(field.FieldReference) if err != nil { return querySpec{}, err } - update.Fields = append(update.Fields, updateField{BsonTag: bsonFieldReference, ParamIndex: field.ParamIndex}) + + updateKey := getUpdateOperatorKey(field.Operator) + if updateKey == "" { + return querySpec{}, NewUpdateOperatorNotSupportedError(field.Operator) + } + updateField := updateField{ + BsonTag: bsonFieldReference, + ParamIndex: field.ParamIndex, + } + update[updateKey] = append(update[updateKey], updateField) } return update, nil default: @@ -181,6 +190,17 @@ func (g RepositoryGenerator) getMongoUpdate(updateSpec spec.Update) (update, err } } +func getUpdateOperatorKey(operator spec.UpdateOperator) string { + switch operator { + case spec.UpdateOperatorSet: + return "$set" + case spec.UpdateOperatorPush: + return "$push" + default: + return "" + } +} + func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) { querySpec, err := g.mongoQuerySpec(operation.Query) if err != nil { diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index a675ff0..b72d7ae 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -31,6 +31,11 @@ var ( Type: code.SimpleType("NameModel"), Tags: map[string][]string{"bson": {"name"}}, } + consentHistoryField = code.StructField{ + Name: "ConsentHistory", + Type: code.ArrayType{ContainedType: code.SimpleType("ConsentHistory")}, + Tags: map[string][]string{"bson": {"consent_history"}}, + } enabledField = code.StructField{ Name: "Enabled", Type: code.SimpleType("bool"), @@ -60,6 +65,7 @@ var userModel = code.Struct{ genderField, ageField, nameField, + consentHistoryField, enabledField, accessTokenField, }, @@ -983,7 +989,7 @@ func (r *UserRepositoryMongo) UpdateByID(arg0 context.Context, arg1 *UserModel, }, Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{ageField}, ParamIndex: 1}, + {FieldReference: spec.FieldReference{ageField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ @@ -1024,7 +1030,7 @@ func (r *UserRepositoryMongo) UpdateAgeByID(arg0 context.Context, arg1 int, arg2 }, Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{ageField}, ParamIndex: 1}, + {FieldReference: spec.FieldReference{ageField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, }, Mode: spec.QueryModeMany, Query: spec.QuerySpec{ @@ -1048,6 +1054,93 @@ func (r *UserRepositoryMongo) UpdateAgeByGender(arg0 context.Context, arg1 int, } return int(result.MatchedCount), err } +`, + }, + { + Name: "simple update push method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateConsentHistoryPushByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, + {Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + {FieldReference: spec.FieldReference{consentHistoryField}, ParamIndex: 1, Operator: spec.UpdateOperatorPush}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) UpdateConsentHistoryPushByID(arg0 context.Context, arg1 ConsentHistory, arg2 primitive.ObjectID) (bool, error) { + result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$push": bson.M{ + "consent_history": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, err +} +`, + }, + { + Name: "simple update set and push method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateEnabledAndConsentHistoryPushByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "enabled", Type: code.SimpleType("bool")}, + {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, + {Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + {FieldReference: spec.FieldReference{enabledField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, + {FieldReference: spec.FieldReference{consentHistoryField}, ParamIndex: 2, Operator: spec.UpdateOperatorPush}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 3}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) UpdateEnabledAndConsentHistoryPushByID(arg0 context.Context, arg1 bool, arg2 ConsentHistory, arg3 primitive.ObjectID) (bool, error) { + result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg3, + }, bson.M{ + "$set": bson.M{ + "enabled": arg1, + }, + "$push": bson.M{ + "consent_history": arg2, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, err +} `, }, { @@ -1065,7 +1158,7 @@ func (r *UserRepositoryMongo) UpdateAgeByGender(arg0 context.Context, arg1 int, }, Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{nameField, firstNameField}, ParamIndex: 1}, + {FieldReference: spec.FieldReference{nameField, firstNameField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ @@ -1947,7 +2040,7 @@ func TestGenerateMethod_Invalid(t *testing.T) { }, Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{accessTokenField}, ParamIndex: 1}, + {FieldReference: spec.FieldReference{accessTokenField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ @@ -1984,6 +2077,33 @@ func TestGenerateMethod_Invalid(t *testing.T) { }, ExpectedError: mongo.NewUpdateTypeNotSupportedError(StubUpdate{}), }, + { + Name: "update operator not supported", + Method: spec.MethodSpec{ + Name: "UpdateConsentHistoryAppendByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + {FieldReference: spec.FieldReference{consentHistoryField}, ParamIndex: 1, Operator: "APPEND"}, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedError: mongo.NewUpdateOperatorNotSupportedError("APPEND"), + }, } for _, testCase := range testTable { @@ -1994,7 +2114,7 @@ func TestGenerateMethod_Invalid(t *testing.T) { err := generator.GenerateMethod(testCase.Method, buffer) if err != testCase.ExpectedError { - t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) } }) } diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 9825575..807b393 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -23,16 +23,19 @@ func (u updateModel) Code() string { return ` "$set": arg1,` } -type updateFields struct { - Fields []updateField -} +type updateFields map[string][]updateField func (u updateFields) Code() string { - lines := []string{` "$set": bson.M{`} - for _, field := range u.Fields { - lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex)) + var lines []string + for k, v := range u { + lines = append(lines, fmt.Sprintf(` "%s": bson.M{`, k)) + + for _, field := range v { + lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex)) + } + + lines = append(lines, ` },`) } - lines = append(lines, ` },`) return strings.Join(lines, "\n") } diff --git a/internal/spec/errors.go b/internal/spec/errors.go index 996bdc3..40f987f 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -19,9 +19,11 @@ func (err ParsingError) Error() string { case InvalidParamError: return "parameters do not match the query" case InvalidUpdateFieldsError: - return "update fields is invalid" + return "update fields are invalid" case ContextParamRequiredError: return "context parameter is required" + case PushNonArrayError: + return "cannot use push operation in a non-array type" } return string(err) } @@ -33,6 +35,7 @@ const ( InvalidParamError ParsingError = "ERROR_INVALID_PARAM" InvalidUpdateFieldsError ParsingError = "ERROR_INVALID_UPDATE_FIELDS" ContextParamRequiredError ParsingError = "ERROR_CONTEXT_PARAM_REQUIRED" + PushNonArrayError ParsingError = "ERROR_PUSH_NON_ARRAY" ) // NewInvalidQueryError creates invalidQueryError diff --git a/internal/spec/errors_test.go b/internal/spec/errors_test.go index 89884bd..0fe35a4 100644 --- a/internal/spec/errors_test.go +++ b/internal/spec/errors_test.go @@ -48,7 +48,7 @@ func TestError(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { if testCase.Error.Error() != testCase.ExpectedString { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedString, testCase.Error.Error()) + t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedString, testCase.Error.Error()) } }) } diff --git a/internal/spec/models_test.go b/internal/spec/models_test.go index ba56c14..7916885 100644 --- a/internal/spec/models_test.go +++ b/internal/spec/models_test.go @@ -38,7 +38,7 @@ func TestOperationName(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.ExpectedName, func(t *testing.T) { if testCase.Operation.Name() != testCase.ExpectedName { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedName, testCase.Operation.Name()) + t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedName, testCase.Operation.Name()) } }) } diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 77a6e4b..d97eb81 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -49,6 +49,10 @@ var ( Name: "Enabled", Type: code.SimpleType("bool"), } + consentHistoryField = code.StructField{ + Name: "ConsentHistory", + Type: code.ArrayType{ContainedType: code.SimpleType("ConsentHistoryItem")}, + } firstNameField = code.StructField{ Name: "First", @@ -81,6 +85,7 @@ var ( contactField, referrerField, defaultPaymentField, + consentHistoryField, enabledField, }, } @@ -149,7 +154,7 @@ func TestParseInterfaceMethod_Insert(t *testing.T) { Operation: testCase.ExpectedOperation, } if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) } }) } @@ -654,7 +659,7 @@ func TestParseInterfaceMethod_Find(t *testing.T) { Operation: testCase.ExpectedOperation, } if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) } }) } @@ -700,7 +705,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }, ExpectedOperation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{genderField}, ParamIndex: 1}, + {FieldReference: spec.FieldReference{genderField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{Predicates: []spec.Predicate{ @@ -724,7 +729,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }, ExpectedOperation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{genderField}, ParamIndex: 1}, + {FieldReference: spec.FieldReference{genderField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, }, Mode: spec.QueryModeMany, Query: spec.QuerySpec{Predicates: []spec.Predicate{ @@ -748,7 +753,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }, ExpectedOperation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{nameField, firstNameField}, ParamIndex: 1}, + {FieldReference: spec.FieldReference{nameField, firstNameField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{Predicates: []spec.Predicate{ @@ -773,8 +778,58 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }, ExpectedOperation: spec.UpdateOperation{ Update: spec.UpdateFields{ - {FieldReference: spec.FieldReference{genderField}, ParamIndex: 1}, - {FieldReference: spec.FieldReference{cityField}, ParamIndex: 2}, + {FieldReference: spec.FieldReference{genderField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, + {FieldReference: spec.FieldReference{cityField}, ParamIndex: 2, Operator: spec.UpdateOperatorSet}, + }, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 3}, + }}, + }, + }, + { + Name: "UpdateArgPushByArg method", + Method: code.Method{ + Name: "UpdateConsentHistoryPushByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("ConsentHistoryItem")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + {FieldReference: spec.FieldReference{consentHistoryField}, ParamIndex: 1, Operator: spec.UpdateOperatorPush}, + }, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }}, + }, + }, + { + Name: "UpdateArgAndArgPushByArg method", + Method: code.Method{ + Name: "UpdateEnabledAndConsentHistoryPushByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("bool")}, + {Type: code.SimpleType("ConsentHistoryItem")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + {FieldReference: spec.FieldReference{enabledField}, ParamIndex: 1, Operator: spec.UpdateOperatorSet}, + {FieldReference: spec.FieldReference{consentHistoryField}, ParamIndex: 2, Operator: spec.UpdateOperatorPush}, }, Mode: spec.QueryModeMany, Query: spec.QuerySpec{Predicates: []spec.Predicate{ @@ -798,7 +853,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) { Operation: testCase.ExpectedOperation, } if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) } }) } @@ -1089,7 +1144,7 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { Operation: testCase.ExpectedOperation, } if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) } }) } @@ -1150,7 +1205,7 @@ func TestParseInterfaceMethod_Count(t *testing.T) { Operation: testCase.ExpectedOperation, } if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) } }) } @@ -1169,7 +1224,7 @@ func TestParseInterfaceMethod_Invalid(t *testing.T) { expectedError := spec.NewUnknownOperationError("Search") if err != expectedError { - t.Errorf("\nExpected = %v\nReceived = %v", expectedError, err) + t.Errorf("\nExpected = %+v\nReceived = %+v", expectedError, err) } } @@ -1275,7 +1330,7 @@ func TestParseInterfaceMethod_Insert_Invalid(t *testing.T) { _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) if err != testCase.ExpectedError { - t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) } }) } @@ -1577,7 +1632,7 @@ func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) if err.Error() != testCase.ExpectedError.Error() { - t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError.Error(), err.Error()) + t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError.Error(), err.Error()) } }) } @@ -1647,6 +1702,22 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { }, ExpectedError: spec.InvalidUpdateFieldsError, }, + { + Name: "push operator in non-array field", + Method: code.Method{ + Name: "UpdateGenderPushByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.PushNonArrayError, + }, { Name: "update method without query", Method: code.Method{ @@ -1677,6 +1748,37 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { }, ExpectedError: spec.NewInvalidQueryError([]string{"ID", "And", "Username", "Or", "Gender"}), }, + { + Name: "parameters for push operator is not array's contained type", + Method: code.Method{ + Name: "UpdateConsentHistoryPushByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.ArrayType{ContainedType: code.SimpleType("ConsentHistoryItem")}}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidUpdateFieldsError, + }, + { + Name: "insufficient function parameters", + Method: code.Method{ + Name: "UpdateEnabledAll", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + // {Type: code.SimpleType("Enabled")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidUpdateFieldsError, + }, { Name: "update model with invalid parameter", Method: code.Method{ @@ -1762,7 +1864,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) if err != testCase.ExpectedError { - t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) } }) } @@ -1941,7 +2043,7 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) { _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) if err != testCase.ExpectedError { - t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) } }) } @@ -2072,7 +2174,7 @@ func TestParseInterfaceMethod_Count_Invalid(t *testing.T) { _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) if err != testCase.ExpectedError { - t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) } }) } diff --git a/internal/spec/update.go b/internal/spec/update.go index 6be9552..ba9338d 100644 --- a/internal/spec/update.go +++ b/internal/spec/update.go @@ -51,6 +51,35 @@ func (u UpdateFields) NumberOfArguments() int { type UpdateField struct { FieldReference FieldReference ParamIndex int + Operator UpdateOperator +} + +// UpdateOperator is a custom type that declares update operator to be used in an update operation +type UpdateOperator string + +// UpdateOperator constants +const ( + UpdateOperatorSet UpdateOperator = "SET" + UpdateOperatorPush UpdateOperator = "PUSH" +) + +// NumberOfArguments returns number of arguments required to perform an update operation +func (o UpdateOperator) NumberOfArguments() int { + return 1 +} + +// ArgumentType returns type that is required for function parameter +func (o UpdateOperator) ArgumentType(fieldType code.Type) (code.Type, error) { + switch o { + case UpdateOperatorPush: + arrayType, ok := fieldType.(code.ArrayType) + if !ok { + return nil, PushNonArrayError + } + return arrayType.ContainedType, nil + default: + return fieldType, nil + } } func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation, error) { @@ -104,24 +133,51 @@ func (p interfaceMethodParser) parseUpdate(tokens []string) (Update, error) { paramIndex := 1 for _, updateFieldToken := range updateFieldTokens { - updateFieldReference, ok := p.fieldResolver.ResolveStructField(p.StructModel, updateFieldToken) - if !ok { - return nil, NewStructFieldNotFoundError(updateFieldToken) + updateField, err := p.parseUpdateField(updateFieldToken, paramIndex) + if err != nil { + return nil, err } - updateFields = append(updateFields, UpdateField{ - FieldReference: updateFieldReference, - ParamIndex: paramIndex, - }) - paramIndex++ + updateFields = append(updateFields, updateField) + paramIndex += updateField.Operator.NumberOfArguments() } for _, field := range updateFields { - if len(p.Method.Params) <= field.ParamIndex || - field.FieldReference.ReferencedField().Type != p.Method.Params[field.ParamIndex].Type { + if len(p.Method.Params) < field.ParamIndex+field.Operator.NumberOfArguments() { return nil, InvalidUpdateFieldsError } + + requiredType, err := field.Operator.ArgumentType(field.FieldReference.ReferencedField().Type) + if err != nil { + return nil, err + } + + for i := 0; i < field.Operator.NumberOfArguments(); i++ { + if requiredType != p.Method.Params[field.ParamIndex+i].Type { + return nil, InvalidUpdateFieldsError + } + } } return updateFields, nil } + +func (p interfaceMethodParser) parseUpdateField(t []string, paramIndex int) (UpdateField, error) { + if len(t) > 1 && t[len(t)-1] == "Push" { + return p.createUpdateField(t[:len(t)-1], UpdateOperatorPush, paramIndex) + } + return p.createUpdateField(t, UpdateOperatorSet, paramIndex) +} + +func (p interfaceMethodParser) createUpdateField(t []string, operator UpdateOperator, paramIndex int) (UpdateField, error) { + fieldReference, ok := p.fieldResolver.ResolveStructField(p.StructModel, t) + if !ok { + return UpdateField{}, NewStructFieldNotFoundError(t) + } + + return UpdateField{ + FieldReference: fieldReference, + ParamIndex: paramIndex, + Operator: operator, + }, nil +} diff --git a/internal/spec/update_test.go b/internal/spec/update_test.go index fcfab6a..f3b802f 100644 --- a/internal/spec/update_test.go +++ b/internal/spec/update_test.go @@ -26,7 +26,7 @@ func TestUpdateTypeName(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.ExpectedName, func(t *testing.T) { if testCase.Update.Name() != testCase.ExpectedName { - t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedName, testCase.Update.Name()) + t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedName, testCase.Update.Name()) } }) }