diff --git a/internal/mongo/errors.go b/internal/mongo/errors.go index a5faa9b..cbfcd81 100644 --- a/internal/mongo/errors.go +++ b/internal/mongo/errors.go @@ -2,6 +2,8 @@ package mongo import ( "fmt" + + "github.com/sunboyy/repogen/internal/spec" ) // NewOperationNotSupportedError creates operationNotSupportedError @@ -29,3 +31,16 @@ type bsonTagNotFoundError struct { func (err bsonTagNotFoundError) Error() string { return fmt.Sprintf("bson tag of field '%s' not found", err.FieldName) } + +// NewUpdateTypeNotSupportedError creates updateTypeNotSupportedError +func NewUpdateTypeNotSupportedError(update spec.Update) error { + return updateTypeNotSupportedError{Update: update} +} + +type updateTypeNotSupportedError struct { + Update spec.Update +} + +func (err updateTypeNotSupportedError) Error() string { + return fmt.Sprintf("update type %s not supported", err.Update.Name()) +} diff --git a/internal/mongo/errors_test.go b/internal/mongo/errors_test.go index b3d5a61..7d60a23 100644 --- a/internal/mongo/errors_test.go +++ b/internal/mongo/errors_test.go @@ -12,6 +12,17 @@ type ErrorTestCase struct { ExpectedString string } +type StubUpdate struct { +} + +func (update StubUpdate) Name() string { + return "Stub" +} + +func (update StubUpdate) NumberOfArguments() int { + return 1 +} + func TestError(t *testing.T) { testTable := []ErrorTestCase{ { @@ -24,6 +35,11 @@ func TestError(t *testing.T) { Error: mongo.NewBsonTagNotFoundError("AccessToken"), ExpectedString: "bson tag of field 'AccessToken' not found", }, + { + Name: "UpdateTypeNotSupportedError", + Error: mongo.NewUpdateTypeNotSupportedError(StubUpdate{}), + ExpectedString: "update type Stub not supported", + }, } for _, testCase := range testTable { diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index ded09dc..a70b871 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -117,13 +117,9 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera } func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) { - var fields []updateField - for _, field := range operation.Fields { - bsonTag, err := g.bsonTagFromFieldName(field.Name) - if err != nil { - return "", err - } - fields = append(fields, updateField{BsonTag: bsonTag, ParamIndex: field.ParamIndex}) + update, err := g.getMongoUpdate(operation.Update) + if err != nil { + return "", err } querySpec, err := g.mongoQuerySpec(operation.Query) @@ -132,8 +128,8 @@ func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateO } tmplData := mongoUpdateTemplateData{ - UpdateFields: fields, - QuerySpec: querySpec, + Update: update, + QuerySpec: querySpec, } if operation.Mode == spec.QueryModeOne { @@ -142,6 +138,25 @@ func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateO return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData) } +func (g RepositoryGenerator) getMongoUpdate(updateSpec spec.Update) (update, error) { + switch updateSpec := updateSpec.(type) { + case spec.UpdateModel: + return updateModel{}, nil + case spec.UpdateFields: + var update updateFields + for _, field := range updateSpec { + bsonTag, err := g.bsonTagFromFieldName(field.Name) + if err != nil { + return nil, err + } + update.Fields = append(update.Fields, updateField{BsonTag: bsonTag, ParamIndex: field.ParamIndex}) + } + return update, nil + default: + return nil, NewUpdateTypeNotSupportedError(updateSpec) + } +} + func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) { querySpec, err := g.mongoQuerySpec(operation.Query) if err != nil { diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index d44a389..fe5ae05 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -712,6 +712,43 @@ func (r *UserRepositoryMongo) FindByEnabledFalse(arg0 context.Context) ([]*UserM func TestGenerateMethod_Update(t *testing.T) { testTable := []GenerateMethodTestCase{ + { + Name: "update model method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "model", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateModel{}, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) UpdateByID(arg0 context.Context, arg1 *UserModel, arg2 primitive.ObjectID) (bool, error) { + result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$set": arg1, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, err +} +`, + }, { Name: "simple update one method", MethodSpec: spec.MethodSpec{ @@ -726,7 +763,7 @@ func TestGenerateMethod_Update(t *testing.T) { code.SimpleType("error"), }, Operation: spec.UpdateOperation{ - Fields: []spec.UpdateField{ + Update: spec.UpdateFields{ {Name: "Age", ParamIndex: 1}, }, Mode: spec.QueryModeOne, @@ -767,7 +804,7 @@ func (r *UserRepositoryMongo) UpdateAgeByID(arg0 context.Context, arg1 int, arg2 code.SimpleType("error"), }, Operation: spec.UpdateOperation{ - Fields: []spec.UpdateField{ + Update: spec.UpdateFields{ {Name: "Age", ParamIndex: 1}, }, Mode: spec.QueryModeMany, @@ -1629,7 +1666,7 @@ func TestGenerateMethod_Invalid(t *testing.T) { code.SimpleType("error"), }, Operation: spec.UpdateOperation{ - Fields: []spec.UpdateField{ + Update: spec.UpdateFields{ {Name: "AccessToken", ParamIndex: 1}, }, Mode: spec.QueryModeOne, @@ -1642,6 +1679,31 @@ func TestGenerateMethod_Invalid(t *testing.T) { }, ExpectedError: mongo.NewBsonTagNotFoundError("AccessToken"), }, + { + Name: "update type not supported", + Method: spec.MethodSpec{ + Name: "UpdateAgeByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + Operation: spec.UpdateOperation{ + Update: StubUpdate{}, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedError: mongo.NewUpdateTypeNotSupportedError(StubUpdate{}), + }, } for _, testCase := range testTable { diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 849d092..9825575 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -12,6 +12,30 @@ type updateField struct { ParamIndex int } +type update interface { + Code() string +} + +type updateModel struct { +} + +func (u updateModel) Code() string { + return ` "$set": arg1,` +} + +type updateFields struct { + Fields []updateField +} + +func (u updateFields) Code() string { + lines := []string{` "$set": bson.M{`} + for _, field := range u.Fields { + lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex)) + } + lines = append(lines, ` },`) + return strings.Join(lines, "\n") +} + type querySpec struct { Operator spec.Operator Predicates []predicate diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go index eaaa1ac..a9acff0 100644 --- a/internal/mongo/templates.go +++ b/internal/mongo/templates.go @@ -112,16 +112,14 @@ const findManyTemplate = ` cursor, err := r.collection.Find(arg0, bson.M{ return entities, nil` type mongoUpdateTemplateData struct { - UpdateFields []updateField - QuerySpec querySpec + Update update + QuerySpec querySpec } const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{ {{.QuerySpec.Code}} }, bson.M{ - "$set": bson.M{ -{{range $index, $element := .UpdateFields}} "{{$element.BsonTag}}": arg{{$element.ParamIndex}}, -{{end}} }, +{{.Update.Code}} }) if err != nil { return false, err @@ -131,9 +129,7 @@ const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{ const updateManyTemplate = ` result, err := r.collection.UpdateMany(arg0, bson.M{ {{.QuerySpec.Code}} }, bson.M{ - "$set": bson.M{ -{{range $index, $element := .UpdateFields}} "{{$element.BsonTag}}": arg{{$element.ParamIndex}}, -{{end}} }, +{{.Update.Code}} }) if err != nil { return 0, err diff --git a/internal/spec/models.go b/internal/spec/models.go index 5880620..f8e1720 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -49,11 +49,44 @@ func (o FindOperation) Name() string { // UpdateOperation is a method specification for update operations type UpdateOperation struct { - Fields []UpdateField + Update Update Mode QueryMode Query QuerySpec } +// Update is an interface of update operation type +type Update interface { + Name() string + NumberOfArguments() int +} + +// UpdateFields is a type of update operation that update specific fields +type UpdateFields []UpdateField + +// Name returns UpdateFields name 'Fields' +func (u UpdateFields) Name() string { + return "Fields" +} + +// NumberOfArguments returns number of update fields +func (u UpdateFields) NumberOfArguments() int { + return len(u) +} + +// UpdateModel is a type of update operation that update the whole model +type UpdateModel struct { +} + +// Name returns UpdateModel name 'Model' +func (u UpdateModel) Name() string { + return "Model" +} + +// NumberOfArguments returns 1 +func (u UpdateModel) NumberOfArguments() int { + return 1 +} + // Name returns "Update" operation name func (o UpdateOperation) Name() string { return "Update" diff --git a/internal/spec/models_test.go b/internal/spec/models_test.go index ba56c14..396d419 100644 --- a/internal/spec/models_test.go +++ b/internal/spec/models_test.go @@ -43,3 +43,29 @@ func TestOperationName(t *testing.T) { }) } } + +type UpdateTypeTestCase struct { + Update spec.Update + ExpectedName string +} + +func TestUpdateTypeName(t *testing.T) { + testTable := []UpdateTypeTestCase{ + { + Update: spec.UpdateModel{}, + ExpectedName: "Model", + }, + { + Update: spec.UpdateFields{}, + ExpectedName: "Fields", + }, + } + + for _, testCase := range testTable { + t.Run(testCase.ExpectedName, func(t *testing.T) { + if testCase.Update.Name() != testCase.ExpectedName { + t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedName, testCase.Update.Name()) + } + }) + } +} diff --git a/internal/spec/parser.go b/internal/spec/parser.go index 23e0334..f7f57a8 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -169,18 +169,52 @@ func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation, return nil, err } - updateFieldTokens, queryTokens := p.splitUpdateFieldAndQueryTokens(tokens) + if err := p.validateContextParam(); err != nil { + return nil, err + } + + updateTokens, queryTokens := p.splitUpdateAndQueryTokens(tokens) + + update, err := p.parseUpdate(updateTokens) + if err != nil { + return nil, err + } + + querySpec, err := parseQuery(queryTokens, 1+update.NumberOfArguments()) + if err != nil { + return nil, err + } + + if err := p.validateQueryFromParams(p.Method.Params[update.NumberOfArguments()+1:], querySpec); err != nil { + return nil, err + } + + return UpdateOperation{ + Update: update, + Mode: mode, + Query: querySpec, + }, nil +} + +func (p interfaceMethodParser) parseUpdate(tokens []string) (Update, error) { + if len(tokens) == 0 { + requiredType := code.PointerType{ContainedType: p.StructModel.ReferencedType()} + if len(p.Method.Params) <= 1 || p.Method.Params[1].Type != requiredType { + return nil, InvalidUpdateFieldsError + } + return UpdateModel{}, nil + } paramIndex := 1 - var fields []UpdateField + var update UpdateFields var aggregatedToken string - for _, token := range updateFieldTokens { + for _, token := range tokens { if token != "And" { aggregatedToken += token } else if len(aggregatedToken) == 0 { return nil, InvalidUpdateFieldsError } else { - fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) + update = append(update, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) paramIndex++ aggregatedToken = "" } @@ -188,41 +222,24 @@ func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation, if len(aggregatedToken) == 0 { return nil, InvalidUpdateFieldsError } - fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) + update = append(update, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) - querySpec, err := parseQuery(queryTokens, 1+len(fields)) - if err != nil { - return nil, err - } - - if err := p.validateContextParam(); err != nil { - return nil, err - } - - for _, field := range fields { + for _, field := range update { structField, ok := p.StructModel.Fields.ByName(field.Name) if !ok { return nil, NewStructFieldNotFoundError(field.Name) } - if structField.Type != p.Method.Params[field.ParamIndex].Type { - return nil, InvalidParamError + if len(p.Method.Params) <= field.ParamIndex || structField.Type != p.Method.Params[field.ParamIndex].Type { + return nil, InvalidUpdateFieldsError } } - if err := p.validateQueryFromParams(p.Method.Params[len(fields)+1:], querySpec); err != nil { - return nil, err - } - - return UpdateOperation{ - Fields: fields, - Mode: mode, - Query: querySpec, - }, nil + return update, nil } -func (p interfaceMethodParser) splitUpdateFieldAndQueryTokens(tokens []string) ([]string, []string) { - var updateFieldTokens []string +func (p interfaceMethodParser) splitUpdateAndQueryTokens(tokens []string) ([]string, []string) { + var updateTokens []string var queryTokens []string for i, token := range tokens { @@ -230,11 +247,11 @@ func (p interfaceMethodParser) splitUpdateFieldAndQueryTokens(tokens []string) ( queryTokens = tokens[i:] break } else { - updateFieldTokens = append(updateFieldTokens, token) + updateTokens = append(updateTokens, token) } } - return updateFieldTokens, queryTokens + return updateTokens, queryTokens } func (p interfaceMethodParser) parseDeleteOperation(tokens []string) (Operation, error) { diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 5b109f2..48d53fe 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -453,6 +453,28 @@ func TestParseInterfaceMethod_Find(t *testing.T) { func TestParseInterfaceMethod_Update(t *testing.T) { testTable := []ParseInterfaceMethodTestCase{ + { + Name: "UpdateByArg", + Method: code.Method{ + Name: "UpdateByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.UpdateOperation{ + Update: spec.UpdateModel{}, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + {Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }}, + }, + }, { Name: "UpdateArgByArg one method", Method: code.Method{ @@ -468,7 +490,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }, }, ExpectedOperation: spec.UpdateOperation{ - Fields: []spec.UpdateField{ + Update: spec.UpdateFields{ {Name: "Gender", ParamIndex: 1}, }, Mode: spec.QueryModeOne, @@ -492,7 +514,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }, }, ExpectedOperation: spec.UpdateOperation{ - Fields: []spec.UpdateField{ + Update: spec.UpdateFields{ {Name: "Gender", ParamIndex: 1}, }, Mode: spec.QueryModeMany, @@ -517,7 +539,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) { }, }, ExpectedOperation: spec.UpdateOperation{ - Fields: []spec.UpdateField{ + Update: spec.UpdateFields{ {Name: "Gender", ParamIndex: 1}, {Name: "City", ParamIndex: 2}, }, @@ -1279,6 +1301,9 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { Name: "update with no field provided", Method: code.Method{ Name: "UpdateByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, Returns: []code.Type{ code.SimpleType("bool"), code.SimpleType("error"), @@ -1290,6 +1315,9 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { Name: "misplaced And token in update fields", Method: code.Method{ Name: "UpdateAgeAndAndGenderByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, Returns: []code.Type{ code.SimpleType("bool"), code.SimpleType("error"), @@ -1301,6 +1329,10 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { Name: "update method without query", Method: code.Method{ Name: "UpdateCity", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, Returns: []code.Type{ code.SimpleType("bool"), code.SimpleType("error"), @@ -1312,6 +1344,10 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { Name: "ambiguous query", Method: code.Method{ Name: "UpdateAgeByIDAndUsernameOrGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + }, Returns: []code.Type{ code.SimpleType("int"), code.SimpleType("error"), @@ -1319,6 +1355,21 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { }, ExpectedError: spec.NewInvalidQueryError([]string{"By", "ID", "And", "Username", "Or", "Gender"}), }, + { + Name: "update model with invalid parameter", + Method: code.Method{ + Name: "UpdateByID", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidUpdateFieldsError, + }, { Name: "no context parameter", Method: code.Method{ @@ -1364,7 +1415,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { code.SimpleType("error"), }, }, - ExpectedError: spec.InvalidParamError, + ExpectedError: spec.InvalidUpdateFieldsError, }, { Name: "struct field does not match parameter in query",