Add count operation

This commit is contained in:
sunboyy 2021-02-06 18:05:47 +07:00
parent c752849518
commit 496f9541cb
12 changed files with 858 additions and 261 deletions

View file

@ -17,12 +17,20 @@ jobs:
with:
go-version: 1.15
- name: Install dependencies
run: go get -u golang.org/x/lint/golint
- name: Build
run: go build -v ./...
- name: Test
run: go test -v ./... -covermode=count -coverprofile=cover.out
- name: Vet & Lint
run: |
go vet ./...
golint ./...
- uses: codecov/codecov-action@v1
with:
flags: unittests

View file

@ -7,24 +7,11 @@ on:
branches: [ main ]
jobs:
lint:
golangci-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.15
- name: Install dependencies
run: go get -u golang.org/x/lint/golint
- name: Vet & Lint
run: |
go vet ./...
golint ./...
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:

View file

@ -32,7 +32,7 @@ $ go get github.com/sunboyy/repogen
### Step 2: Write a repository specification
Write repository specification as an interface in the same file as the model struct. There are 4 types of operations that are currently supported and are determined by the first word of the method name. Single-entity and multiple-entity modes are determined be the first return value. More complex queries can also be written.
Write repository specification as an interface in the same file as the model struct. There are 5 types of operations that are currently supported and are determined by the first word of the method name. Single-entity and multiple-entity modes are determined be the first return value. More complex queries can also be written.
```go
// You write this interface specification (comment is optional)
@ -54,6 +54,10 @@ type UserRepository interface {
// the match count. The error will be returned only when error occurs while accessing
// the database. This is a MANY mode because the first return type is an integer.
DeleteByCity(ctx context.Context, city string) (int, error)
// CountByCity returns the number of rows that match the given city parameter. If an error occurs while
// accessing the database, error value will be returned.
CountByCity(ctx context.Context, city string) (int, error)
}
```
@ -77,3 +81,7 @@ $ repogen -src=examples/getting-started/user.go -dest=examples/getting-started/u
```
You can also write the above command in the `go:generate` format inside Go files in order to generate the implementation when `go generate` command is executed.
## License
Licensed under [MIT](https://github.com/sunboyy/repogen/blob/main/LICENSE)

View file

@ -35,4 +35,8 @@ type UserRepository interface {
// The error will be returned only when error occurs while accessing the database. This is a MANY mode
// because the first return type is an integer.
DeleteByCity(ctx context.Context, city string) (int, error)
// CountByCity returns the number of rows that match the given city parameter. If an error occurs while
// accessing the database, error value will be returned.
CountByCity(ctx context.Context, city string) (int, error)
}

View file

@ -60,3 +60,13 @@ func (r *UserRepositoryMongo) DeleteByCity(arg0 context.Context, arg1 string) (i
}
return int(result.DeletedCount), nil
}
func (r *UserRepositoryMongo) CountByCity(arg0 context.Context, arg1 string) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"city": arg1,
})
if err != nil {
return 0, err
}
return int(count), nil
}

View file

@ -85,6 +85,8 @@ func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.Method
return g.generateUpdateImplementation(operation)
case spec.DeleteOperation:
return g.generateDeleteImplementation(operation)
case spec.CountOperation:
return g.generateCountImplementation(operation)
}
return "", OperationNotSupportedError
@ -98,8 +100,6 @@ func (g RepositoryGenerator) generateInsertImplementation(operation spec.InsertO
}
func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
buffer := new(bytes.Buffer)
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil {
return "", err
@ -111,31 +111,12 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera
}
if operation.Mode == spec.QueryModeOne {
tmpl, err := template.New("mongo_repository_findone").Parse(findOneTemplate)
if err != nil {
return "", err
return generateFromTemplate("mongo_repository_findone", findOneTemplate, tmplData)
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
} else {
tmpl, err := template.New("mongo_repository_findmany").Parse(findManyTemplate)
if err != nil {
return "", err
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
}
return buffer.String(), nil
return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData)
}
func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) {
buffer := new(bytes.Buffer)
var fields []updateField
for _, field := range operation.Fields {
bsonTag, err := g.bsonTagFromFieldName(field.Name)
@ -156,31 +137,12 @@ func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateO
}
if operation.Mode == spec.QueryModeOne {
tmpl, err := template.New("mongo_repository_updateone").Parse(updateOneTemplate)
if err != nil {
return "", err
return generateFromTemplate("mongo_repository_updateone", updateOneTemplate, tmplData)
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
} else {
tmpl, err := template.New("mongo_repository_updatemany").Parse(updateManyTemplate)
if err != nil {
return "", err
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
}
return buffer.String(), nil
return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData)
}
func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) {
buffer := new(bytes.Buffer)
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil {
return "", err
@ -191,26 +153,22 @@ func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteO
}
if operation.Mode == spec.QueryModeOne {
tmpl, err := template.New("mongo_repository_deleteone").Parse(deleteOneTemplate)
return generateFromTemplate("mongo_repository_deleteone", deleteOneTemplate, tmplData)
}
return generateFromTemplate("mongo_repository_deletemany", deleteManyTemplate, tmplData)
}
func (g RepositoryGenerator) generateCountImplementation(operation spec.CountOperation) (string, error) {
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil {
return "", err
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
} else {
tmpl, err := template.New("mongo_repository_deletemany").Parse(deleteManyTemplate)
if err != nil {
return "", err
tmplData := mongoCountTemplateData{
QuerySpec: querySpec,
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
}
return buffer.String(), nil
return generateFromTemplate("mongo_repository_count", countTemplate, tmplData)
}
func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, error) {
@ -252,3 +210,17 @@ func (g RepositoryGenerator) bsonTagFromFieldName(fieldName string) (string, err
func (g RepositoryGenerator) structName() string {
return g.InterfaceName + "Mongo"
}
func generateFromTemplate(name string, templateString string, tmplData interface{}) (string, error) {
tmpl, err := template.New(name).Parse(templateString)
if err != nil {
return "", err
}
buffer := new(bytes.Buffer)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
return buffer.String(), nil
}

View file

@ -1089,6 +1089,358 @@ func (r *UserRepositoryMongo) DeleteByGenderIn(arg0 context.Context, arg1 []Gend
}
}
func TestGenerateMethod_Count(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "simple count method",
MethodSpec: spec.MethodSpec{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGender(arg0 context.Context, arg1 Gender) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": arg1,
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with And operator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderAndCity",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Operator: spec.OperatorAnd,
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
{Field: "Age", Comparator: spec.ComparatorEqual, ParamIndex: 2},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGenderAndCity(arg0 context.Context, arg1 Gender, arg2 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": arg1,
"age": arg2,
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with Or operator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderOrCity",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Operator: spec.OperatorOr,
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
{Field: "Age", Comparator: spec.ComparatorEqual, ParamIndex: 2},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGenderOrCity(arg0 context.Context, arg1 Gender, arg2 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"$or": []bson.M{
{"gender": arg1},
{"age": arg2},
},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with Not comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderNot",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorNot, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGenderNot(arg0 context.Context, arg1 Gender) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": bson.M{"$ne": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with LessThan comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeLessThan",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorLessThan, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeLessThan(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$lt": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with LessThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeLessThanEqual",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorLessThanEqual, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeLessThanEqual(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$lte": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with GreaterThan comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeGreaterThan",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorGreaterThan, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeGreaterThan(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$gt": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with GreaterThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeGreaterThanEqual",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorGreaterThanEqual, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeGreaterThanEqual(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$gte": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with Between comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeBetween",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorBetween, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeBetween(arg0 context.Context, arg1 int, arg2 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$gte": arg1, "$lte": arg2},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with In comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeIn",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.ArrayType{ContainedType: code.SimpleType("int")}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorIn, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeIn(arg0 context.Context, arg1 []int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$in": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
buffer := new(bytes.Buffer)
err := generator.GenerateMethod(testCase.MethodSpec, buffer)
if err != nil {
t.Error(err)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil {
t.Error(err)
}
})
}
}
type GenerateMethodInvalidTestCase struct {
Name string
Method spec.MethodSpec

View file

@ -159,3 +159,15 @@ const deleteManyTemplate = ` result, err := r.collection.DeleteMany(arg0, bson.M
return 0, err
}
return int(result.DeletedCount), nil`
type mongoCountTemplateData struct {
QuerySpec querySpec
}
const countTemplate = ` count, err := r.collection.CountDocuments(arg0, bson.M{
{{.QuerySpec.Code}}
})
if err != nil {
return 0, err
}
return int(count), nil`

View file

@ -1,8 +1,6 @@
package spec
import (
"strings"
"github.com/sunboyy/repogen/internal/code"
)
@ -57,91 +55,7 @@ type DeleteOperation struct {
Query QuerySpec
}
// QuerySpec is a set of conditions of querying the database
type QuerySpec struct {
Operator Operator
Predicates []Predicate
}
// NumberOfArguments returns number of arguments required to perform the query
func (q QuerySpec) NumberOfArguments() int {
var totalArgs int
for _, predicate := range q.Predicates {
totalArgs += predicate.Comparator.NumberOfArguments()
}
return totalArgs
}
// Operator is a boolean operator for merging conditions
type Operator string
// boolean operator types
const (
OperatorAnd Operator = "AND"
OperatorOr Operator = "OR"
)
// Comparator is a comparison operator of the condition to query the data
type Comparator string
// comparator types
const (
ComparatorNot Comparator = "NOT"
ComparatorEqual Comparator = "EQUAL"
ComparatorLessThan Comparator = "LESS_THAN"
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN"
)
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
return code.ArrayType{ContainedType: t}
}
return t
}
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
return 2
}
return 1
}
// Predicate is a criteria for querying a field
type Predicate struct {
Field string
Comparator Comparator
ParamIndex int
}
type predicateToken []string
func (t predicateToken) ToPredicate(paramIndex int) Predicate {
if len(t) > 1 && t[len(t)-1] == "Not" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
// CountOperation is a method specification for count operations
type CountOperation struct {
Query QuerySpec
}

View file

@ -31,6 +31,8 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) {
return p.parseUpdateMethod(methodNameTokens[1:])
case "Delete":
return p.parseDeleteMethod(methodNameTokens[1:])
case "Count":
return p.parseCountMethod(methodNameTokens[1:])
}
return MethodSpec{}, UnknownOperationError
}
@ -55,14 +57,9 @@ func (p interfaceMethodParser) parseInsertMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, InvalidParamError
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: InsertOperation{
return p.createMethodSpec(InsertOperation{
Mode: mode,
},
}, nil
}), nil
}
func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) {
@ -99,12 +96,12 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err
return MethodSpec{}, UnsupportedNameError
}
mode, err := p.extractFindReturns(p.Method.Returns)
mode, err := p.extractModelOrSliceReturns(p.Method.Returns)
if err != nil {
return MethodSpec{}, err
}
querySpec, err := p.parseQuery(tokens, 1)
querySpec, err := parseQuery(tokens, 1)
if err != nil {
return MethodSpec{}, err
}
@ -117,18 +114,13 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err
return MethodSpec{}, err
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: FindOperation{
return p.createMethodSpec(FindOperation{
Mode: mode,
Query: querySpec,
},
}, nil
}), nil
}
func (p interfaceMethodParser) extractFindReturns(returns []code.Type) (QueryMode, error) {
func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 {
return "", UnsupportedReturnError
}
@ -166,7 +158,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, UnsupportedNameError
}
mode, err := p.extractCountReturns(p.Method.Returns)
mode, err := p.extractIntOrBoolReturns(p.Method.Returns)
if err != nil {
return MethodSpec{}, err
}
@ -193,7 +185,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
}
fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex})
querySpec, err := p.parseQuery(tokens, 1+len(fields))
querySpec, err := parseQuery(tokens, 1+len(fields))
if err != nil {
return MethodSpec{}, err
}
@ -217,16 +209,11 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, err
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: UpdateOperation{
return p.createMethodSpec(UpdateOperation{
Fields: fields,
Mode: mode,
Query: querySpec,
},
}, nil
}), nil
}
func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, error) {
@ -234,12 +221,12 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, UnsupportedNameError
}
mode, err := p.extractCountReturns(p.Method.Returns)
mode, err := p.extractIntOrBoolReturns(p.Method.Returns)
if err != nil {
return MethodSpec{}, err
}
querySpec, err := p.parseQuery(tokens, 1)
querySpec, err := parseQuery(tokens, 1)
if err != nil {
return MethodSpec{}, err
}
@ -252,18 +239,56 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, err
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: DeleteOperation{
return p.createMethodSpec(DeleteOperation{
Mode: mode,
Query: querySpec,
},
}, nil
}), nil
}
func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMode, error) {
func (p interfaceMethodParser) parseCountMethod(tokens []string) (MethodSpec, error) {
if len(tokens) == 0 {
return MethodSpec{}, UnsupportedNameError
}
if err := p.validateCountReturns(p.Method.Returns); err != nil {
return MethodSpec{}, err
}
querySpec, err := parseQuery(tokens, 1)
if err != nil {
return MethodSpec{}, err
}
if err := p.validateContextParam(); err != nil {
return MethodSpec{}, err
}
if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil {
return MethodSpec{}, err
}
return p.createMethodSpec(CountOperation{
Query: querySpec,
}), nil
}
func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error {
if len(returns) != 2 {
return UnsupportedReturnError
}
if returns[0] != code.SimpleType("int") {
return UnsupportedReturnError
}
if returns[1] != code.SimpleType("error") {
return UnsupportedReturnError
}
return nil
}
func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 {
return "", UnsupportedReturnError
}
@ -285,58 +310,6 @@ func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMo
return "", UnsupportedReturnError
}
func (p interfaceMethodParser) parseQuery(tokens []string, paramIndex int) (QuerySpec, error) {
if len(tokens) == 0 {
return QuerySpec{}, InvalidQueryError
}
if len(tokens) == 1 && tokens[0] == "All" {
return QuerySpec{}, nil
}
if tokens[0] == "One" {
tokens = tokens[1:]
}
if tokens[0] == "By" {
tokens = tokens[1:]
}
if tokens[0] == "And" || tokens[0] == "Or" {
return QuerySpec{}, InvalidQueryError
}
var operator Operator
var predicates []Predicate
var aggregatedToken predicateToken
for _, token := range tokens {
if token != "And" && token != "Or" {
aggregatedToken = append(aggregatedToken, token)
} else if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
} else if token == "And" && operator != OperatorOr {
operator = OperatorAnd
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else if token == "Or" && operator != OperatorAnd {
operator = OperatorOr
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else {
return QuerySpec{}, InvalidQueryError
}
}
if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
}
predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex))
return QuerySpec{Operator: operator, Predicates: predicates}, nil
}
func (p interfaceMethodParser) validateContextParam() error {
contextType := code.ExternalType{PackageAlias: "context", Name: "Context"}
if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType {
@ -368,3 +341,12 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
return nil
}
func (p interfaceMethodParser) createMethodSpec(operation Operation) MethodSpec {
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: operation,
}
}

View file

@ -778,6 +778,67 @@ func TestParseInterfaceMethod_Delete(t *testing.T) {
}
}
func TestParseInterfaceMethod_Count(t *testing.T) {
testTable := []ParseInterfaceMethodTestCase{
{
Name: "CountAll method",
Method: code.Method{
Name: "CountAll",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedOperation: spec.CountOperation{
Query: spec.QuerySpec{},
},
},
{
Name: "CountByArg method",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedOperation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
},
},
},
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
actualSpec, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != nil {
t.Errorf("Error = %s", err)
}
expectedOutput := spec.MethodSpec{
Name: testCase.Method.Name,
Params: testCase.Method.Params,
Returns: testCase.Method.Returns,
Operation: testCase.ExpectedOperation,
}
if !reflect.DeepEqual(actualSpec, expectedOutput) {
t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec)
}
})
}
}
type ParseInterfaceMethodInvalidTestCase struct {
Name string
Method code.Method
@ -1403,3 +1464,142 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) {
})
}
}
func TestParseInterfaceMethod_Count_Invalid(t *testing.T) {
testTable := []ParseInterfaceMethodInvalidTestCase{
{
Name: "unsupported count method name",
Method: code.Method{
Name: "Count",
},
ExpectedError: spec.UnsupportedNameError,
},
{
Name: "invalid number of returns",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid number of returns",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid integer return",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int64"),
code.SimpleType("error"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "error return not provided",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid query",
Method: code.Method{
Name: "CountBy",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidQueryError,
},
{
Name: "context parameter not provided",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.ContextParamRequiredError,
},
{
Name: "mismatched number of parameter",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
{
Name: "mismatched method parameter type",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
{
Name: "struct field not found",
Method: code.Method{
Name: "CountByCountry",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.StructFieldNotFoundError,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
_, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != testCase.ExpectedError {
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err)
}
})
}
}

148
internal/spec/query.go Normal file
View file

@ -0,0 +1,148 @@
package spec
import (
"strings"
"github.com/sunboyy/repogen/internal/code"
)
// QuerySpec is a set of conditions of querying the database
type QuerySpec struct {
Operator Operator
Predicates []Predicate
}
// NumberOfArguments returns number of arguments required to perform the query
func (q QuerySpec) NumberOfArguments() int {
var totalArgs int
for _, predicate := range q.Predicates {
totalArgs += predicate.Comparator.NumberOfArguments()
}
return totalArgs
}
// Operator is a boolean operator for merging conditions
type Operator string
// boolean operator types
const (
OperatorAnd Operator = "AND"
OperatorOr Operator = "OR"
)
// Comparator is a comparison operator of the condition to query the data
type Comparator string
// comparator types
const (
ComparatorNot Comparator = "NOT"
ComparatorEqual Comparator = "EQUAL"
ComparatorLessThan Comparator = "LESS_THAN"
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN"
)
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
return code.ArrayType{ContainedType: t}
}
return t
}
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
return 2
}
return 1
}
// Predicate is a criteria for querying a field
type Predicate struct {
Field string
Comparator Comparator
ParamIndex int
}
type predicateToken []string
func (t predicateToken) ToPredicate(paramIndex int) Predicate {
if len(t) > 1 && t[len(t)-1] == "Not" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
}
func parseQuery(tokens []string, paramIndex int) (QuerySpec, error) {
if len(tokens) == 0 {
return QuerySpec{}, InvalidQueryError
}
if len(tokens) == 1 && tokens[0] == "All" {
return QuerySpec{}, nil
}
if tokens[0] == "One" {
tokens = tokens[1:]
}
if tokens[0] == "By" {
tokens = tokens[1:]
}
if len(tokens) == 0 || tokens[0] == "And" || tokens[0] == "Or" {
return QuerySpec{}, InvalidQueryError
}
var operator Operator
var predicates []Predicate
var aggregatedToken predicateToken
for _, token := range tokens {
if token != "And" && token != "Or" {
aggregatedToken = append(aggregatedToken, token)
} else if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
} else if token == "And" && operator != OperatorOr {
operator = OperatorAnd
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else if token == "Or" && operator != OperatorAnd {
operator = OperatorOr
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else {
return QuerySpec{}, InvalidQueryError
}
}
if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
}
predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex))
return QuerySpec{Operator: operator, Predicates: predicates}, nil
}