Add functionality to update the whole model

This commit is contained in:
sunboyy 2021-02-24 19:02:57 +07:00
parent f7b0df5463
commit 87ecd3704b
10 changed files with 310 additions and 55 deletions

View file

@ -2,6 +2,8 @@ package mongo
import ( import (
"fmt" "fmt"
"github.com/sunboyy/repogen/internal/spec"
) )
// NewOperationNotSupportedError creates operationNotSupportedError // NewOperationNotSupportedError creates operationNotSupportedError
@ -29,3 +31,16 @@ type bsonTagNotFoundError struct {
func (err bsonTagNotFoundError) Error() string { func (err bsonTagNotFoundError) Error() string {
return fmt.Sprintf("bson tag of field '%s' not found", err.FieldName) return fmt.Sprintf("bson tag of field '%s' not found", err.FieldName)
} }
// NewUpdateTypeNotSupportedError creates updateTypeNotSupportedError
func NewUpdateTypeNotSupportedError(update spec.Update) error {
return updateTypeNotSupportedError{Update: update}
}
type updateTypeNotSupportedError struct {
Update spec.Update
}
func (err updateTypeNotSupportedError) Error() string {
return fmt.Sprintf("update type %s not supported", err.Update.Name())
}

View file

@ -12,6 +12,17 @@ type ErrorTestCase struct {
ExpectedString string ExpectedString string
} }
type StubUpdate struct {
}
func (update StubUpdate) Name() string {
return "Stub"
}
func (update StubUpdate) NumberOfArguments() int {
return 1
}
func TestError(t *testing.T) { func TestError(t *testing.T) {
testTable := []ErrorTestCase{ testTable := []ErrorTestCase{
{ {
@ -24,6 +35,11 @@ func TestError(t *testing.T) {
Error: mongo.NewBsonTagNotFoundError("AccessToken"), Error: mongo.NewBsonTagNotFoundError("AccessToken"),
ExpectedString: "bson tag of field 'AccessToken' not found", ExpectedString: "bson tag of field 'AccessToken' not found",
}, },
{
Name: "UpdateTypeNotSupportedError",
Error: mongo.NewUpdateTypeNotSupportedError(StubUpdate{}),
ExpectedString: "update type Stub not supported",
},
} }
for _, testCase := range testTable { for _, testCase := range testTable {

View file

@ -117,13 +117,9 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera
} }
func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) { func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) {
var fields []updateField update, err := g.getMongoUpdate(operation.Update)
for _, field := range operation.Fields { if err != nil {
bsonTag, err := g.bsonTagFromFieldName(field.Name) return "", err
if err != nil {
return "", err
}
fields = append(fields, updateField{BsonTag: bsonTag, ParamIndex: field.ParamIndex})
} }
querySpec, err := g.mongoQuerySpec(operation.Query) querySpec, err := g.mongoQuerySpec(operation.Query)
@ -132,8 +128,8 @@ func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateO
} }
tmplData := mongoUpdateTemplateData{ tmplData := mongoUpdateTemplateData{
UpdateFields: fields, Update: update,
QuerySpec: querySpec, QuerySpec: querySpec,
} }
if operation.Mode == spec.QueryModeOne { if operation.Mode == spec.QueryModeOne {
@ -142,6 +138,25 @@ func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateO
return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData) return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData)
} }
func (g RepositoryGenerator) getMongoUpdate(updateSpec spec.Update) (update, error) {
switch updateSpec := updateSpec.(type) {
case spec.UpdateModel:
return updateModel{}, nil
case spec.UpdateFields:
var update updateFields
for _, field := range updateSpec {
bsonTag, err := g.bsonTagFromFieldName(field.Name)
if err != nil {
return nil, err
}
update.Fields = append(update.Fields, updateField{BsonTag: bsonTag, ParamIndex: field.ParamIndex})
}
return update, nil
default:
return nil, NewUpdateTypeNotSupportedError(updateSpec)
}
}
func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) { func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) {
querySpec, err := g.mongoQuerySpec(operation.Query) querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil { if err != nil {

View file

@ -712,6 +712,43 @@ func (r *UserRepositoryMongo) FindByEnabledFalse(arg0 context.Context) ([]*UserM
func TestGenerateMethod_Update(t *testing.T) { func TestGenerateMethod_Update(t *testing.T) {
testTable := []GenerateMethodTestCase{ testTable := []GenerateMethodTestCase{
{
Name: "update model method",
MethodSpec: spec.MethodSpec{
Name: "UpdateByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "model", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.SimpleType("bool"),
code.SimpleType("error"),
},
Operation: spec.UpdateOperation{
Update: spec.UpdateModel{},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) UpdateByID(arg0 context.Context, arg1 *UserModel, arg2 primitive.ObjectID) (bool, error) {
result, err := r.collection.UpdateOne(arg0, bson.M{
"_id": arg2,
}, bson.M{
"$set": arg1,
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, err
}
`,
},
{ {
Name: "simple update one method", Name: "simple update one method",
MethodSpec: spec.MethodSpec{ MethodSpec: spec.MethodSpec{
@ -726,7 +763,7 @@ func TestGenerateMethod_Update(t *testing.T) {
code.SimpleType("error"), code.SimpleType("error"),
}, },
Operation: spec.UpdateOperation{ Operation: spec.UpdateOperation{
Fields: []spec.UpdateField{ Update: spec.UpdateFields{
{Name: "Age", ParamIndex: 1}, {Name: "Age", ParamIndex: 1},
}, },
Mode: spec.QueryModeOne, Mode: spec.QueryModeOne,
@ -767,7 +804,7 @@ func (r *UserRepositoryMongo) UpdateAgeByID(arg0 context.Context, arg1 int, arg2
code.SimpleType("error"), code.SimpleType("error"),
}, },
Operation: spec.UpdateOperation{ Operation: spec.UpdateOperation{
Fields: []spec.UpdateField{ Update: spec.UpdateFields{
{Name: "Age", ParamIndex: 1}, {Name: "Age", ParamIndex: 1},
}, },
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,
@ -1629,7 +1666,7 @@ func TestGenerateMethod_Invalid(t *testing.T) {
code.SimpleType("error"), code.SimpleType("error"),
}, },
Operation: spec.UpdateOperation{ Operation: spec.UpdateOperation{
Fields: []spec.UpdateField{ Update: spec.UpdateFields{
{Name: "AccessToken", ParamIndex: 1}, {Name: "AccessToken", ParamIndex: 1},
}, },
Mode: spec.QueryModeOne, Mode: spec.QueryModeOne,
@ -1642,6 +1679,31 @@ func TestGenerateMethod_Invalid(t *testing.T) {
}, },
ExpectedError: mongo.NewBsonTagNotFoundError("AccessToken"), ExpectedError: mongo.NewBsonTagNotFoundError("AccessToken"),
}, },
{
Name: "update type not supported",
Method: spec.MethodSpec{
Name: "UpdateAgeByID",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
{Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.SimpleType("bool"),
code.SimpleType("error"),
},
Operation: spec.UpdateOperation{
Update: StubUpdate{},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2},
},
},
},
},
ExpectedError: mongo.NewUpdateTypeNotSupportedError(StubUpdate{}),
},
} }
for _, testCase := range testTable { for _, testCase := range testTable {

View file

@ -12,6 +12,30 @@ type updateField struct {
ParamIndex int ParamIndex int
} }
type update interface {
Code() string
}
type updateModel struct {
}
func (u updateModel) Code() string {
return ` "$set": arg1,`
}
type updateFields struct {
Fields []updateField
}
func (u updateFields) Code() string {
lines := []string{` "$set": bson.M{`}
for _, field := range u.Fields {
lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex))
}
lines = append(lines, ` },`)
return strings.Join(lines, "\n")
}
type querySpec struct { type querySpec struct {
Operator spec.Operator Operator spec.Operator
Predicates []predicate Predicates []predicate

View file

@ -112,16 +112,14 @@ const findManyTemplate = ` cursor, err := r.collection.Find(arg0, bson.M{
return entities, nil` return entities, nil`
type mongoUpdateTemplateData struct { type mongoUpdateTemplateData struct {
UpdateFields []updateField Update update
QuerySpec querySpec QuerySpec querySpec
} }
const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{ const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{
{{.QuerySpec.Code}} {{.QuerySpec.Code}}
}, bson.M{ }, bson.M{
"$set": bson.M{ {{.Update.Code}}
{{range $index, $element := .UpdateFields}} "{{$element.BsonTag}}": arg{{$element.ParamIndex}},
{{end}} },
}) })
if err != nil { if err != nil {
return false, err return false, err
@ -131,9 +129,7 @@ const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{
const updateManyTemplate = ` result, err := r.collection.UpdateMany(arg0, bson.M{ const updateManyTemplate = ` result, err := r.collection.UpdateMany(arg0, bson.M{
{{.QuerySpec.Code}} {{.QuerySpec.Code}}
}, bson.M{ }, bson.M{
"$set": bson.M{ {{.Update.Code}}
{{range $index, $element := .UpdateFields}} "{{$element.BsonTag}}": arg{{$element.ParamIndex}},
{{end}} },
}) })
if err != nil { if err != nil {
return 0, err return 0, err

View file

@ -49,11 +49,44 @@ func (o FindOperation) Name() string {
// UpdateOperation is a method specification for update operations // UpdateOperation is a method specification for update operations
type UpdateOperation struct { type UpdateOperation struct {
Fields []UpdateField Update Update
Mode QueryMode Mode QueryMode
Query QuerySpec Query QuerySpec
} }
// Update is an interface of update operation type
type Update interface {
Name() string
NumberOfArguments() int
}
// UpdateFields is a type of update operation that update specific fields
type UpdateFields []UpdateField
// Name returns UpdateFields name 'Fields'
func (u UpdateFields) Name() string {
return "Fields"
}
// NumberOfArguments returns number of update fields
func (u UpdateFields) NumberOfArguments() int {
return len(u)
}
// UpdateModel is a type of update operation that update the whole model
type UpdateModel struct {
}
// Name returns UpdateModel name 'Model'
func (u UpdateModel) Name() string {
return "Model"
}
// NumberOfArguments returns 1
func (u UpdateModel) NumberOfArguments() int {
return 1
}
// Name returns "Update" operation name // Name returns "Update" operation name
func (o UpdateOperation) Name() string { func (o UpdateOperation) Name() string {
return "Update" return "Update"

View file

@ -43,3 +43,29 @@ func TestOperationName(t *testing.T) {
}) })
} }
} }
type UpdateTypeTestCase struct {
Update spec.Update
ExpectedName string
}
func TestUpdateTypeName(t *testing.T) {
testTable := []UpdateTypeTestCase{
{
Update: spec.UpdateModel{},
ExpectedName: "Model",
},
{
Update: spec.UpdateFields{},
ExpectedName: "Fields",
},
}
for _, testCase := range testTable {
t.Run(testCase.ExpectedName, func(t *testing.T) {
if testCase.Update.Name() != testCase.ExpectedName {
t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedName, testCase.Update.Name())
}
})
}
}

View file

@ -169,18 +169,52 @@ func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation,
return nil, err return nil, err
} }
updateFieldTokens, queryTokens := p.splitUpdateFieldAndQueryTokens(tokens) if err := p.validateContextParam(); err != nil {
return nil, err
}
updateTokens, queryTokens := p.splitUpdateAndQueryTokens(tokens)
update, err := p.parseUpdate(updateTokens)
if err != nil {
return nil, err
}
querySpec, err := parseQuery(queryTokens, 1+update.NumberOfArguments())
if err != nil {
return nil, err
}
if err := p.validateQueryFromParams(p.Method.Params[update.NumberOfArguments()+1:], querySpec); err != nil {
return nil, err
}
return UpdateOperation{
Update: update,
Mode: mode,
Query: querySpec,
}, nil
}
func (p interfaceMethodParser) parseUpdate(tokens []string) (Update, error) {
if len(tokens) == 0 {
requiredType := code.PointerType{ContainedType: p.StructModel.ReferencedType()}
if len(p.Method.Params) <= 1 || p.Method.Params[1].Type != requiredType {
return nil, InvalidUpdateFieldsError
}
return UpdateModel{}, nil
}
paramIndex := 1 paramIndex := 1
var fields []UpdateField var update UpdateFields
var aggregatedToken string var aggregatedToken string
for _, token := range updateFieldTokens { for _, token := range tokens {
if token != "And" { if token != "And" {
aggregatedToken += token aggregatedToken += token
} else if len(aggregatedToken) == 0 { } else if len(aggregatedToken) == 0 {
return nil, InvalidUpdateFieldsError return nil, InvalidUpdateFieldsError
} else { } else {
fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) update = append(update, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex})
paramIndex++ paramIndex++
aggregatedToken = "" aggregatedToken = ""
} }
@ -188,41 +222,24 @@ func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation,
if len(aggregatedToken) == 0 { if len(aggregatedToken) == 0 {
return nil, InvalidUpdateFieldsError return nil, InvalidUpdateFieldsError
} }
fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) update = append(update, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex})
querySpec, err := parseQuery(queryTokens, 1+len(fields)) for _, field := range update {
if err != nil {
return nil, err
}
if err := p.validateContextParam(); err != nil {
return nil, err
}
for _, field := range fields {
structField, ok := p.StructModel.Fields.ByName(field.Name) structField, ok := p.StructModel.Fields.ByName(field.Name)
if !ok { if !ok {
return nil, NewStructFieldNotFoundError(field.Name) return nil, NewStructFieldNotFoundError(field.Name)
} }
if structField.Type != p.Method.Params[field.ParamIndex].Type { if len(p.Method.Params) <= field.ParamIndex || structField.Type != p.Method.Params[field.ParamIndex].Type {
return nil, InvalidParamError return nil, InvalidUpdateFieldsError
} }
} }
if err := p.validateQueryFromParams(p.Method.Params[len(fields)+1:], querySpec); err != nil { return update, nil
return nil, err
}
return UpdateOperation{
Fields: fields,
Mode: mode,
Query: querySpec,
}, nil
} }
func (p interfaceMethodParser) splitUpdateFieldAndQueryTokens(tokens []string) ([]string, []string) { func (p interfaceMethodParser) splitUpdateAndQueryTokens(tokens []string) ([]string, []string) {
var updateFieldTokens []string var updateTokens []string
var queryTokens []string var queryTokens []string
for i, token := range tokens { for i, token := range tokens {
@ -230,11 +247,11 @@ func (p interfaceMethodParser) splitUpdateFieldAndQueryTokens(tokens []string) (
queryTokens = tokens[i:] queryTokens = tokens[i:]
break break
} else { } else {
updateFieldTokens = append(updateFieldTokens, token) updateTokens = append(updateTokens, token)
} }
} }
return updateFieldTokens, queryTokens return updateTokens, queryTokens
} }
func (p interfaceMethodParser) parseDeleteOperation(tokens []string) (Operation, error) { func (p interfaceMethodParser) parseDeleteOperation(tokens []string) (Operation, error) {

View file

@ -453,6 +453,28 @@ func TestParseInterfaceMethod_Find(t *testing.T) {
func TestParseInterfaceMethod_Update(t *testing.T) { func TestParseInterfaceMethod_Update(t *testing.T) {
testTable := []ParseInterfaceMethodTestCase{ testTable := []ParseInterfaceMethodTestCase{
{
Name: "UpdateByArg",
Method: code.Method{
Name: "UpdateByID",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
{Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.SimpleType("bool"),
code.SimpleType("error"),
},
},
ExpectedOperation: spec.UpdateOperation{
Update: spec.UpdateModel{},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{Predicates: []spec.Predicate{
{Field: "ID", Comparator: spec.ComparatorEqual, ParamIndex: 2},
}},
},
},
{ {
Name: "UpdateArgByArg one method", Name: "UpdateArgByArg one method",
Method: code.Method{ Method: code.Method{
@ -468,7 +490,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) {
}, },
}, },
ExpectedOperation: spec.UpdateOperation{ ExpectedOperation: spec.UpdateOperation{
Fields: []spec.UpdateField{ Update: spec.UpdateFields{
{Name: "Gender", ParamIndex: 1}, {Name: "Gender", ParamIndex: 1},
}, },
Mode: spec.QueryModeOne, Mode: spec.QueryModeOne,
@ -492,7 +514,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) {
}, },
}, },
ExpectedOperation: spec.UpdateOperation{ ExpectedOperation: spec.UpdateOperation{
Fields: []spec.UpdateField{ Update: spec.UpdateFields{
{Name: "Gender", ParamIndex: 1}, {Name: "Gender", ParamIndex: 1},
}, },
Mode: spec.QueryModeMany, Mode: spec.QueryModeMany,
@ -517,7 +539,7 @@ func TestParseInterfaceMethod_Update(t *testing.T) {
}, },
}, },
ExpectedOperation: spec.UpdateOperation{ ExpectedOperation: spec.UpdateOperation{
Fields: []spec.UpdateField{ Update: spec.UpdateFields{
{Name: "Gender", ParamIndex: 1}, {Name: "Gender", ParamIndex: 1},
{Name: "City", ParamIndex: 2}, {Name: "City", ParamIndex: 2},
}, },
@ -1279,6 +1301,9 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
Name: "update with no field provided", Name: "update with no field provided",
Method: code.Method{ Method: code.Method{
Name: "UpdateByID", Name: "UpdateByID",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{ Returns: []code.Type{
code.SimpleType("bool"), code.SimpleType("bool"),
code.SimpleType("error"), code.SimpleType("error"),
@ -1290,6 +1315,9 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
Name: "misplaced And token in update fields", Name: "misplaced And token in update fields",
Method: code.Method{ Method: code.Method{
Name: "UpdateAgeAndAndGenderByID", Name: "UpdateAgeAndAndGenderByID",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{ Returns: []code.Type{
code.SimpleType("bool"), code.SimpleType("bool"),
code.SimpleType("error"), code.SimpleType("error"),
@ -1301,6 +1329,10 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
Name: "update method without query", Name: "update method without query",
Method: code.Method{ Method: code.Method{
Name: "UpdateCity", Name: "UpdateCity",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{ Returns: []code.Type{
code.SimpleType("bool"), code.SimpleType("bool"),
code.SimpleType("error"), code.SimpleType("error"),
@ -1312,6 +1344,10 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
Name: "ambiguous query", Name: "ambiguous query",
Method: code.Method{ Method: code.Method{
Name: "UpdateAgeByIDAndUsernameOrGender", Name: "UpdateAgeByIDAndUsernameOrGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{ Returns: []code.Type{
code.SimpleType("int"), code.SimpleType("int"),
code.SimpleType("error"), code.SimpleType("error"),
@ -1319,6 +1355,21 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
}, },
ExpectedError: spec.NewInvalidQueryError([]string{"By", "ID", "And", "Username", "Or", "Gender"}), ExpectedError: spec.NewInvalidQueryError([]string{"By", "ID", "And", "Username", "Or", "Gender"}),
}, },
{
Name: "update model with invalid parameter",
Method: code.Method{
Name: "UpdateByID",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.SimpleType("bool"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidUpdateFieldsError,
},
{ {
Name: "no context parameter", Name: "no context parameter",
Method: code.Method{ Method: code.Method{
@ -1364,7 +1415,7 @@ func TestParseInterfaceMethod_Update_Invalid(t *testing.T) {
code.SimpleType("error"), code.SimpleType("error"),
}, },
}, },
ExpectedError: spec.InvalidParamError, ExpectedError: spec.InvalidUpdateFieldsError,
}, },
{ {
Name: "struct field does not match parameter in query", Name: "struct field does not match parameter in query",