diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 44ae7f3..2909f37 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,12 +17,20 @@ jobs: with: go-version: 1.15 + - name: Install dependencies + run: go get -u golang.org/x/lint/golint + - name: Build run: go build -v ./... - name: Test run: go test -v ./... -covermode=count -coverprofile=cover.out + - name: Vet & Lint + run: | + go vet ./... + golint ./... + - uses: codecov/codecov-action@v1 with: flags: unittests diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5207bf6..ec1d7b5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,24 +7,11 @@ on: branches: [ main ] jobs: - lint: + golangci-lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Go - uses: actions/setup-go@v2 - with: - go-version: 1.15 - - - name: Install dependencies - run: go get -u golang.org/x/lint/golint - - - name: Vet & Lint - run: | - go vet ./... - golint ./... - - name: golangci-lint uses: golangci/golangci-lint-action@v2 with: diff --git a/README.md b/README.md index 8d485cd..5ce279a 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ $ go get github.com/sunboyy/repogen ### Step 2: Write a repository specification -Write repository specification as an interface in the same file as the model struct. There are 4 types of operations that are currently supported and are determined by the first word of the method name. Single-entity and multiple-entity modes are determined be the first return value. More complex queries can also be written. +Write repository specification as an interface in the same file as the model struct. There are 5 types of operations that are currently supported and are determined by the first word of the method name. Single-entity and multiple-entity modes are determined be the first return value. More complex queries can also be written. ```go // You write this interface specification (comment is optional) @@ -54,6 +54,10 @@ type UserRepository interface { // the match count. The error will be returned only when error occurs while accessing // the database. This is a MANY mode because the first return type is an integer. DeleteByCity(ctx context.Context, city string) (int, error) + + // CountByCity returns the number of rows that match the given city parameter. If an error occurs while + // accessing the database, error value will be returned. + CountByCity(ctx context.Context, city string) (int, error) } ``` @@ -77,3 +81,7 @@ $ repogen -src=examples/getting-started/user.go -dest=examples/getting-started/u ``` You can also write the above command in the `go:generate` format inside Go files in order to generate the implementation when `go generate` command is executed. + +## License + +Licensed under [MIT](https://github.com/sunboyy/repogen/blob/main/LICENSE) diff --git a/examples/getting-started/user.go b/examples/getting-started/user.go index 0f7b297..979b972 100644 --- a/examples/getting-started/user.go +++ b/examples/getting-started/user.go @@ -35,4 +35,8 @@ type UserRepository interface { // The error will be returned only when error occurs while accessing the database. This is a MANY mode // because the first return type is an integer. DeleteByCity(ctx context.Context, city string) (int, error) + + // CountByCity returns the number of rows that match the given city parameter. If an error occurs while + // accessing the database, error value will be returned. + CountByCity(ctx context.Context, city string) (int, error) } diff --git a/examples/getting-started/user_repo.go b/examples/getting-started/user_repo.go index afd6847..d95fa31 100644 --- a/examples/getting-started/user_repo.go +++ b/examples/getting-started/user_repo.go @@ -60,3 +60,13 @@ func (r *UserRepositoryMongo) DeleteByCity(arg0 context.Context, arg1 string) (i } return int(result.DeletedCount), nil } + +func (r *UserRepositoryMongo) CountByCity(arg0 context.Context, arg1 string) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "city": arg1, + }) + if err != nil { + return 0, err + } + return int(count), nil +} diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index d51cba9..2a9526f 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -85,6 +85,8 @@ func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.Method return g.generateUpdateImplementation(operation) case spec.DeleteOperation: return g.generateDeleteImplementation(operation) + case spec.CountOperation: + return g.generateCountImplementation(operation) } return "", OperationNotSupportedError @@ -98,8 +100,6 @@ func (g RepositoryGenerator) generateInsertImplementation(operation spec.InsertO } func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) { - buffer := new(bytes.Buffer) - querySpec, err := g.mongoQuerySpec(operation.Query) if err != nil { return "", err @@ -111,31 +111,12 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera } if operation.Mode == spec.QueryModeOne { - tmpl, err := template.New("mongo_repository_findone").Parse(findOneTemplate) - if err != nil { - return "", err - } - - if err := tmpl.Execute(buffer, tmplData); err != nil { - return "", err - } - } else { - tmpl, err := template.New("mongo_repository_findmany").Parse(findManyTemplate) - if err != nil { - return "", err - } - - if err := tmpl.Execute(buffer, tmplData); err != nil { - return "", err - } + return generateFromTemplate("mongo_repository_findone", findOneTemplate, tmplData) } - - return buffer.String(), nil + return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData) } func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) { - buffer := new(bytes.Buffer) - var fields []updateField for _, field := range operation.Fields { bsonTag, err := g.bsonTagFromFieldName(field.Name) @@ -156,31 +137,12 @@ func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateO } if operation.Mode == spec.QueryModeOne { - tmpl, err := template.New("mongo_repository_updateone").Parse(updateOneTemplate) - if err != nil { - return "", err - } - - if err := tmpl.Execute(buffer, tmplData); err != nil { - return "", err - } - } else { - tmpl, err := template.New("mongo_repository_updatemany").Parse(updateManyTemplate) - if err != nil { - return "", err - } - - if err := tmpl.Execute(buffer, tmplData); err != nil { - return "", err - } + return generateFromTemplate("mongo_repository_updateone", updateOneTemplate, tmplData) } - - return buffer.String(), nil + return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData) } func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) { - buffer := new(bytes.Buffer) - querySpec, err := g.mongoQuerySpec(operation.Query) if err != nil { return "", err @@ -191,26 +153,22 @@ func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteO } if operation.Mode == spec.QueryModeOne { - tmpl, err := template.New("mongo_repository_deleteone").Parse(deleteOneTemplate) - if err != nil { - return "", err - } + return generateFromTemplate("mongo_repository_deleteone", deleteOneTemplate, tmplData) + } + return generateFromTemplate("mongo_repository_deletemany", deleteManyTemplate, tmplData) +} - if err := tmpl.Execute(buffer, tmplData); err != nil { - return "", err - } - } else { - tmpl, err := template.New("mongo_repository_deletemany").Parse(deleteManyTemplate) - if err != nil { - return "", err - } - - if err := tmpl.Execute(buffer, tmplData); err != nil { - return "", err - } +func (g RepositoryGenerator) generateCountImplementation(operation spec.CountOperation) (string, error) { + querySpec, err := g.mongoQuerySpec(operation.Query) + if err != nil { + return "", err } - return buffer.String(), nil + tmplData := mongoCountTemplateData{ + QuerySpec: querySpec, + } + + return generateFromTemplate("mongo_repository_count", countTemplate, tmplData) } func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, error) { @@ -252,3 +210,17 @@ func (g RepositoryGenerator) bsonTagFromFieldName(fieldName string) (string, err func (g RepositoryGenerator) structName() string { return g.InterfaceName + "Mongo" } + +func generateFromTemplate(name string, templateString string, tmplData interface{}) (string, error) { + tmpl, err := template.New(name).Parse(templateString) + if err != nil { + return "", err + } + + buffer := new(bytes.Buffer) + if err := tmpl.Execute(buffer, tmplData); err != nil { + return "", err + } + + return buffer.String(), nil +} diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index f274221..20f238f 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -1089,6 +1089,358 @@ func (r *UserRepositoryMongo) DeleteByGenderIn(arg0 context.Context, arg1 []Gend } } +func TestGenerateMethod_Count(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "simple count method", + MethodSpec: spec.MethodSpec{ + Name: "CountByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByGender(arg0 context.Context, arg1 Gender) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "gender": arg1, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with And operator", + MethodSpec: spec.MethodSpec{ + Name: "CountByGenderAndCity", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + {Field: "Age", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByGenderAndCity(arg0 context.Context, arg1 Gender, arg2 int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "gender": arg1, + "age": arg2, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with Or operator", + MethodSpec: spec.MethodSpec{ + Name: "CountByGenderOrCity", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + {Field: "Age", Comparator: spec.ComparatorEqual, ParamIndex: 2}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByGenderOrCity(arg0 context.Context, arg1 Gender, arg2 int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "$or": []bson.M{ + {"gender": arg1}, + {"age": arg2}, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with Not comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByGenderNot", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorNot, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByGenderNot(arg0 context.Context, arg1 Gender) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "gender": bson.M{"$ne": arg1}, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with LessThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeLessThan", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorLessThan, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByAgeLessThan(arg0 context.Context, arg1 int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{"$lt": arg1}, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with LessThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeLessThanEqual", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorLessThanEqual, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByAgeLessThanEqual(arg0 context.Context, arg1 int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{"$lte": arg1}, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with GreaterThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeGreaterThan", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThan, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByAgeGreaterThan(arg0 context.Context, arg1 int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{"$gt": arg1}, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with GreaterThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeGreaterThanEqual", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorGreaterThanEqual, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByAgeGreaterThanEqual(arg0 context.Context, arg1 int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{"$gte": arg1}, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with Between comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeBetween", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("int")}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorBetween, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByAgeBetween(arg0 context.Context, arg1 int, arg2 int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{"$gte": arg1, "$lte": arg2}, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + { + Name: "count with In comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeIn", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.ArrayType{ContainedType: code.SimpleType("int")}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Age", Comparator: spec.ComparatorIn, ParamIndex: 1}, + }, + }, + }, + }, + ExpectedCode: ` +func (r *UserRepositoryMongo) CountByAgeIn(arg0 context.Context, arg1 []int) (int, error) { + count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{"$in": arg1}, + }) + if err != nil { + return 0, err + } + return int(count), nil +} +`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + buffer := new(bytes.Buffer) + + err := generator.GenerateMethod(testCase.MethodSpec, buffer) + + if err != nil { + t.Error(err) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil { + t.Error(err) + } + }) + } +} + type GenerateMethodInvalidTestCase struct { Name string Method spec.MethodSpec diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go index ca74a49..eaaa1ac 100644 --- a/internal/mongo/templates.go +++ b/internal/mongo/templates.go @@ -159,3 +159,15 @@ const deleteManyTemplate = ` result, err := r.collection.DeleteMany(arg0, bson.M return 0, err } return int(result.DeletedCount), nil` + +type mongoCountTemplateData struct { + QuerySpec querySpec +} + +const countTemplate = ` count, err := r.collection.CountDocuments(arg0, bson.M{ +{{.QuerySpec.Code}} + }) + if err != nil { + return 0, err + } + return int(count), nil` diff --git a/internal/spec/models.go b/internal/spec/models.go index 89d0b27..9e635be 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -1,8 +1,6 @@ package spec import ( - "strings" - "github.com/sunboyy/repogen/internal/code" ) @@ -57,91 +55,7 @@ type DeleteOperation struct { Query QuerySpec } -// QuerySpec is a set of conditions of querying the database -type QuerySpec struct { - Operator Operator - Predicates []Predicate -} - -// NumberOfArguments returns number of arguments required to perform the query -func (q QuerySpec) NumberOfArguments() int { - var totalArgs int - for _, predicate := range q.Predicates { - totalArgs += predicate.Comparator.NumberOfArguments() - } - return totalArgs -} - -// Operator is a boolean operator for merging conditions -type Operator string - -// boolean operator types -const ( - OperatorAnd Operator = "AND" - OperatorOr Operator = "OR" -) - -// Comparator is a comparison operator of the condition to query the data -type Comparator string - -// comparator types -const ( - ComparatorNot Comparator = "NOT" - ComparatorEqual Comparator = "EQUAL" - ComparatorLessThan Comparator = "LESS_THAN" - ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL" - ComparatorGreaterThan Comparator = "GREATER_THAN" - ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL" - ComparatorBetween Comparator = "BETWEEN" - ComparatorIn Comparator = "IN" -) - -// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type -func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type { - if c == ComparatorIn { - return code.ArrayType{ContainedType: t} - } - return t -} - -// NumberOfArguments returns the number of arguments required to perform the comparison -func (c Comparator) NumberOfArguments() int { - if c == ComparatorBetween { - return 2 - } - return 1 -} - -// Predicate is a criteria for querying a field -type Predicate struct { - Field string - Comparator Comparator - ParamIndex int -} - -type predicateToken []string - -func (t predicateToken) ToPredicate(paramIndex int) Predicate { - if len(t) > 1 && t[len(t)-1] == "Not" { - return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex} - } - if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" { - return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex} - } - if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { - return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex} - } - if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" { - return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex} - } - if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { - return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex} - } - if len(t) > 1 && t[len(t)-1] == "In" { - return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex} - } - if len(t) > 1 && t[len(t)-1] == "Between" { - return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex} - } - return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex} +// CountOperation is a method specification for count operations +type CountOperation struct { + Query QuerySpec } diff --git a/internal/spec/parser.go b/internal/spec/parser.go index 7a6167c..5c0230f 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -31,6 +31,8 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) { return p.parseUpdateMethod(methodNameTokens[1:]) case "Delete": return p.parseDeleteMethod(methodNameTokens[1:]) + case "Count": + return p.parseCountMethod(methodNameTokens[1:]) } return MethodSpec{}, UnknownOperationError } @@ -55,14 +57,9 @@ func (p interfaceMethodParser) parseInsertMethod(tokens []string) (MethodSpec, e return MethodSpec{}, InvalidParamError } - return MethodSpec{ - Name: p.Method.Name, - Params: p.Method.Params, - Returns: p.Method.Returns, - Operation: InsertOperation{ - Mode: mode, - }, - }, nil + return p.createMethodSpec(InsertOperation{ + Mode: mode, + }), nil } func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) { @@ -99,12 +96,12 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err return MethodSpec{}, UnsupportedNameError } - mode, err := p.extractFindReturns(p.Method.Returns) + mode, err := p.extractModelOrSliceReturns(p.Method.Returns) if err != nil { return MethodSpec{}, err } - querySpec, err := p.parseQuery(tokens, 1) + querySpec, err := parseQuery(tokens, 1) if err != nil { return MethodSpec{}, err } @@ -117,18 +114,13 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err return MethodSpec{}, err } - return MethodSpec{ - Name: p.Method.Name, - Params: p.Method.Params, - Returns: p.Method.Returns, - Operation: FindOperation{ - Mode: mode, - Query: querySpec, - }, - }, nil + return p.createMethodSpec(FindOperation{ + Mode: mode, + Query: querySpec, + }), nil } -func (p interfaceMethodParser) extractFindReturns(returns []code.Type) (QueryMode, error) { +func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (QueryMode, error) { if len(returns) != 2 { return "", UnsupportedReturnError } @@ -166,7 +158,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e return MethodSpec{}, UnsupportedNameError } - mode, err := p.extractCountReturns(p.Method.Returns) + mode, err := p.extractIntOrBoolReturns(p.Method.Returns) if err != nil { return MethodSpec{}, err } @@ -193,7 +185,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e } fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) - querySpec, err := p.parseQuery(tokens, 1+len(fields)) + querySpec, err := parseQuery(tokens, 1+len(fields)) if err != nil { return MethodSpec{}, err } @@ -217,16 +209,11 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e return MethodSpec{}, err } - return MethodSpec{ - Name: p.Method.Name, - Params: p.Method.Params, - Returns: p.Method.Returns, - Operation: UpdateOperation{ - Fields: fields, - Mode: mode, - Query: querySpec, - }, - }, nil + return p.createMethodSpec(UpdateOperation{ + Fields: fields, + Mode: mode, + Query: querySpec, + }), nil } func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, error) { @@ -234,12 +221,12 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e return MethodSpec{}, UnsupportedNameError } - mode, err := p.extractCountReturns(p.Method.Returns) + mode, err := p.extractIntOrBoolReturns(p.Method.Returns) if err != nil { return MethodSpec{}, err } - querySpec, err := p.parseQuery(tokens, 1) + querySpec, err := parseQuery(tokens, 1) if err != nil { return MethodSpec{}, err } @@ -252,18 +239,56 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e return MethodSpec{}, err } - return MethodSpec{ - Name: p.Method.Name, - Params: p.Method.Params, - Returns: p.Method.Returns, - Operation: DeleteOperation{ - Mode: mode, - Query: querySpec, - }, - }, nil + return p.createMethodSpec(DeleteOperation{ + Mode: mode, + Query: querySpec, + }), nil } -func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMode, error) { +func (p interfaceMethodParser) parseCountMethod(tokens []string) (MethodSpec, error) { + if len(tokens) == 0 { + return MethodSpec{}, UnsupportedNameError + } + + if err := p.validateCountReturns(p.Method.Returns); err != nil { + return MethodSpec{}, err + } + + querySpec, err := parseQuery(tokens, 1) + if err != nil { + return MethodSpec{}, err + } + + if err := p.validateContextParam(); err != nil { + return MethodSpec{}, err + } + + if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil { + return MethodSpec{}, err + } + + return p.createMethodSpec(CountOperation{ + Query: querySpec, + }), nil +} + +func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error { + if len(returns) != 2 { + return UnsupportedReturnError + } + + if returns[0] != code.SimpleType("int") { + return UnsupportedReturnError + } + + if returns[1] != code.SimpleType("error") { + return UnsupportedReturnError + } + + return nil +} + +func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (QueryMode, error) { if len(returns) != 2 { return "", UnsupportedReturnError } @@ -285,58 +310,6 @@ func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMo return "", UnsupportedReturnError } -func (p interfaceMethodParser) parseQuery(tokens []string, paramIndex int) (QuerySpec, error) { - if len(tokens) == 0 { - return QuerySpec{}, InvalidQueryError - } - - if len(tokens) == 1 && tokens[0] == "All" { - return QuerySpec{}, nil - } - - if tokens[0] == "One" { - tokens = tokens[1:] - } - if tokens[0] == "By" { - tokens = tokens[1:] - } - - if tokens[0] == "And" || tokens[0] == "Or" { - return QuerySpec{}, InvalidQueryError - } - - var operator Operator - var predicates []Predicate - var aggregatedToken predicateToken - for _, token := range tokens { - if token != "And" && token != "Or" { - aggregatedToken = append(aggregatedToken, token) - } else if len(aggregatedToken) == 0 { - return QuerySpec{}, InvalidQueryError - } else if token == "And" && operator != OperatorOr { - operator = OperatorAnd - predicate := aggregatedToken.ToPredicate(paramIndex) - predicates = append(predicates, predicate) - paramIndex += predicate.Comparator.NumberOfArguments() - aggregatedToken = predicateToken{} - } else if token == "Or" && operator != OperatorAnd { - operator = OperatorOr - predicate := aggregatedToken.ToPredicate(paramIndex) - predicates = append(predicates, predicate) - paramIndex += predicate.Comparator.NumberOfArguments() - aggregatedToken = predicateToken{} - } else { - return QuerySpec{}, InvalidQueryError - } - } - if len(aggregatedToken) == 0 { - return QuerySpec{}, InvalidQueryError - } - predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex)) - - return QuerySpec{Operator: operator, Predicates: predicates}, nil -} - func (p interfaceMethodParser) validateContextParam() error { contextType := code.ExternalType{PackageAlias: "context", Name: "Context"} if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType { @@ -368,3 +341,12 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer return nil } + +func (p interfaceMethodParser) createMethodSpec(operation Operation) MethodSpec { + return MethodSpec{ + Name: p.Method.Name, + Params: p.Method.Params, + Returns: p.Method.Returns, + Operation: operation, + } +} diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index b523e7a..9f7618a 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -778,6 +778,67 @@ func TestParseInterfaceMethod_Delete(t *testing.T) { } } +func TestParseInterfaceMethod_Count(t *testing.T) { + testTable := []ParseInterfaceMethodTestCase{ + { + Name: "CountAll method", + Method: code.Method{ + Name: "CountAll", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.CountOperation{ + Query: spec.QuerySpec{}, + }, + }, + { + Name: "CountByArg method", + Method: code.Method{ + Name: "CountByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedOperation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + {Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1}, + }, + }, + }, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(structModel, testCase.Method) + + if err != nil { + t.Errorf("Error = %s", err) + } + expectedOutput := spec.MethodSpec{ + Name: testCase.Method.Name, + Params: testCase.Method.Params, + Returns: testCase.Method.Returns, + Operation: testCase.ExpectedOperation, + } + if !reflect.DeepEqual(actualSpec, expectedOutput) { + t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec) + } + }) + } +} + type ParseInterfaceMethodInvalidTestCase struct { Name string Method code.Method @@ -1403,3 +1464,142 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) { }) } } + +func TestParseInterfaceMethod_Count_Invalid(t *testing.T) { + testTable := []ParseInterfaceMethodInvalidTestCase{ + { + Name: "unsupported count method name", + Method: code.Method{ + Name: "Count", + }, + ExpectedError: spec.UnsupportedNameError, + }, + { + Name: "invalid number of returns", + Method: code.Method{ + Name: "CountAll", + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + code.SimpleType("bool"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "invalid number of returns", + Method: code.Method{ + Name: "CountAll", + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + code.SimpleType("bool"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "invalid integer return", + Method: code.Method{ + Name: "CountAll", + Returns: []code.Type{ + code.SimpleType("int64"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "error return not provided", + Method: code.Method{ + Name: "CountAll", + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("bool"), + }, + }, + ExpectedError: spec.UnsupportedReturnError, + }, + { + Name: "invalid query", + Method: code.Method{ + Name: "CountBy", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidQueryError, + }, + { + Name: "context parameter not provided", + Method: code.Method{ + Name: "CountAll", + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.ContextParamRequiredError, + }, + { + Name: "mismatched number of parameter", + Method: code.Method{ + Name: "CountByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.SimpleType("int")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidParamError, + }, + { + Name: "mismatched method parameter type", + Method: code.Method{ + Name: "CountByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.InvalidParamError, + }, + { + Name: "struct field not found", + Method: code.Method{ + Name: "CountByCountry", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("string")}, + }, + Returns: []code.Type{ + code.SimpleType("int"), + code.SimpleType("error"), + }, + }, + ExpectedError: spec.StructFieldNotFoundError, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(structModel, testCase.Method) + + if err != testCase.ExpectedError { + t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err) + } + }) + } +} diff --git a/internal/spec/query.go b/internal/spec/query.go new file mode 100644 index 0000000..8e5f7da --- /dev/null +++ b/internal/spec/query.go @@ -0,0 +1,148 @@ +package spec + +import ( + "strings" + + "github.com/sunboyy/repogen/internal/code" +) + +// QuerySpec is a set of conditions of querying the database +type QuerySpec struct { + Operator Operator + Predicates []Predicate +} + +// NumberOfArguments returns number of arguments required to perform the query +func (q QuerySpec) NumberOfArguments() int { + var totalArgs int + for _, predicate := range q.Predicates { + totalArgs += predicate.Comparator.NumberOfArguments() + } + return totalArgs +} + +// Operator is a boolean operator for merging conditions +type Operator string + +// boolean operator types +const ( + OperatorAnd Operator = "AND" + OperatorOr Operator = "OR" +) + +// Comparator is a comparison operator of the condition to query the data +type Comparator string + +// comparator types +const ( + ComparatorNot Comparator = "NOT" + ComparatorEqual Comparator = "EQUAL" + ComparatorLessThan Comparator = "LESS_THAN" + ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL" + ComparatorGreaterThan Comparator = "GREATER_THAN" + ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL" + ComparatorBetween Comparator = "BETWEEN" + ComparatorIn Comparator = "IN" +) + +// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type +func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type { + if c == ComparatorIn { + return code.ArrayType{ContainedType: t} + } + return t +} + +// NumberOfArguments returns the number of arguments required to perform the comparison +func (c Comparator) NumberOfArguments() int { + if c == ComparatorBetween { + return 2 + } + return 1 +} + +// Predicate is a criteria for querying a field +type Predicate struct { + Field string + Comparator Comparator + ParamIndex int +} + +type predicateToken []string + +func (t predicateToken) ToPredicate(paramIndex int) Predicate { + if len(t) > 1 && t[len(t)-1] == "Not" { + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex} + } + if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" { + return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex} + } + if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { + return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex} + } + if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" { + return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex} + } + if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" { + return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex} + } + if len(t) > 1 && t[len(t)-1] == "In" { + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex} + } + if len(t) > 1 && t[len(t)-1] == "Between" { + return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex} + } + return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex} +} + +func parseQuery(tokens []string, paramIndex int) (QuerySpec, error) { + if len(tokens) == 0 { + return QuerySpec{}, InvalidQueryError + } + + if len(tokens) == 1 && tokens[0] == "All" { + return QuerySpec{}, nil + } + + if tokens[0] == "One" { + tokens = tokens[1:] + } + if tokens[0] == "By" { + tokens = tokens[1:] + } + + if len(tokens) == 0 || tokens[0] == "And" || tokens[0] == "Or" { + return QuerySpec{}, InvalidQueryError + } + + var operator Operator + var predicates []Predicate + var aggregatedToken predicateToken + for _, token := range tokens { + if token != "And" && token != "Or" { + aggregatedToken = append(aggregatedToken, token) + } else if len(aggregatedToken) == 0 { + return QuerySpec{}, InvalidQueryError + } else if token == "And" && operator != OperatorOr { + operator = OperatorAnd + predicate := aggregatedToken.ToPredicate(paramIndex) + predicates = append(predicates, predicate) + paramIndex += predicate.Comparator.NumberOfArguments() + aggregatedToken = predicateToken{} + } else if token == "Or" && operator != OperatorAnd { + operator = OperatorOr + predicate := aggregatedToken.ToPredicate(paramIndex) + predicates = append(predicates, predicate) + paramIndex += predicate.Comparator.NumberOfArguments() + aggregatedToken = predicateToken{} + } else { + return QuerySpec{}, InvalidQueryError + } + } + if len(aggregatedToken) == 0 { + return QuerySpec{}, InvalidQueryError + } + predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex)) + + return QuerySpec{Operator: operator, Predicates: predicates}, nil +}