Merge pull request #14 from sunboyy/operation-count

Add count operation
This commit is contained in:
sunboyy 2021-02-06 18:08:45 +07:00 committed by GitHub
commit 8b19829a3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 858 additions and 261 deletions

View file

@ -17,12 +17,20 @@ jobs:
with: with:
go-version: 1.15 go-version: 1.15
- name: Install dependencies
run: go get -u golang.org/x/lint/golint
- name: Build - name: Build
run: go build -v ./... run: go build -v ./...
- name: Test - name: Test
run: go test -v ./... -covermode=count -coverprofile=cover.out run: go test -v ./... -covermode=count -coverprofile=cover.out
- name: Vet & Lint
run: |
go vet ./...
golint ./...
- uses: codecov/codecov-action@v1 - uses: codecov/codecov-action@v1
with: with:
flags: unittests flags: unittests

View file

@ -7,24 +7,11 @@ on:
branches: [ main ] branches: [ main ]
jobs: jobs:
lint: golangci-lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.15
- name: Install dependencies
run: go get -u golang.org/x/lint/golint
- name: Vet & Lint
run: |
go vet ./...
golint ./...
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:

View file

@ -32,7 +32,7 @@ $ go get github.com/sunboyy/repogen
### Step 2: Write a repository specification ### Step 2: Write a repository specification
Write repository specification as an interface in the same file as the model struct. There are 4 types of operations that are currently supported and are determined by the first word of the method name. Single-entity and multiple-entity modes are determined be the first return value. More complex queries can also be written. Write repository specification as an interface in the same file as the model struct. There are 5 types of operations that are currently supported and are determined by the first word of the method name. Single-entity and multiple-entity modes are determined be the first return value. More complex queries can also be written.
```go ```go
// You write this interface specification (comment is optional) // You write this interface specification (comment is optional)
@ -54,6 +54,10 @@ type UserRepository interface {
// the match count. The error will be returned only when error occurs while accessing // the match count. The error will be returned only when error occurs while accessing
// the database. This is a MANY mode because the first return type is an integer. // the database. This is a MANY mode because the first return type is an integer.
DeleteByCity(ctx context.Context, city string) (int, error) DeleteByCity(ctx context.Context, city string) (int, error)
// CountByCity returns the number of rows that match the given city parameter. If an error occurs while
// accessing the database, error value will be returned.
CountByCity(ctx context.Context, city string) (int, error)
} }
``` ```
@ -77,3 +81,7 @@ $ repogen -src=examples/getting-started/user.go -dest=examples/getting-started/u
``` ```
You can also write the above command in the `go:generate` format inside Go files in order to generate the implementation when `go generate` command is executed. You can also write the above command in the `go:generate` format inside Go files in order to generate the implementation when `go generate` command is executed.
## License
Licensed under [MIT](https://github.com/sunboyy/repogen/blob/main/LICENSE)

View file

@ -35,4 +35,8 @@ type UserRepository interface {
// The error will be returned only when error occurs while accessing the database. This is a MANY mode // The error will be returned only when error occurs while accessing the database. This is a MANY mode
// because the first return type is an integer. // because the first return type is an integer.
DeleteByCity(ctx context.Context, city string) (int, error) DeleteByCity(ctx context.Context, city string) (int, error)
// CountByCity returns the number of rows that match the given city parameter. If an error occurs while
// accessing the database, error value will be returned.
CountByCity(ctx context.Context, city string) (int, error)
} }

View file

@ -60,3 +60,13 @@ func (r *UserRepositoryMongo) DeleteByCity(arg0 context.Context, arg1 string) (i
} }
return int(result.DeletedCount), nil return int(result.DeletedCount), nil
} }
func (r *UserRepositoryMongo) CountByCity(arg0 context.Context, arg1 string) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"city": arg1,
})
if err != nil {
return 0, err
}
return int(count), nil
}

View file

@ -85,6 +85,8 @@ func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.Method
return g.generateUpdateImplementation(operation) return g.generateUpdateImplementation(operation)
case spec.DeleteOperation: case spec.DeleteOperation:
return g.generateDeleteImplementation(operation) return g.generateDeleteImplementation(operation)
case spec.CountOperation:
return g.generateCountImplementation(operation)
} }
return "", OperationNotSupportedError return "", OperationNotSupportedError
@ -98,8 +100,6 @@ func (g RepositoryGenerator) generateInsertImplementation(operation spec.InsertO
} }
func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) { func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
buffer := new(bytes.Buffer)
querySpec, err := g.mongoQuerySpec(operation.Query) querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil { if err != nil {
return "", err return "", err
@ -111,31 +111,12 @@ func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOpera
} }
if operation.Mode == spec.QueryModeOne { if operation.Mode == spec.QueryModeOne {
tmpl, err := template.New("mongo_repository_findone").Parse(findOneTemplate) return generateFromTemplate("mongo_repository_findone", findOneTemplate, tmplData)
if err != nil {
return "", err
} }
return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
} else {
tmpl, err := template.New("mongo_repository_findmany").Parse(findManyTemplate)
if err != nil {
return "", err
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
}
return buffer.String(), nil
} }
func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) { func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) {
buffer := new(bytes.Buffer)
var fields []updateField var fields []updateField
for _, field := range operation.Fields { for _, field := range operation.Fields {
bsonTag, err := g.bsonTagFromFieldName(field.Name) bsonTag, err := g.bsonTagFromFieldName(field.Name)
@ -156,31 +137,12 @@ func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateO
} }
if operation.Mode == spec.QueryModeOne { if operation.Mode == spec.QueryModeOne {
tmpl, err := template.New("mongo_repository_updateone").Parse(updateOneTemplate) return generateFromTemplate("mongo_repository_updateone", updateOneTemplate, tmplData)
if err != nil {
return "", err
} }
return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
} else {
tmpl, err := template.New("mongo_repository_updatemany").Parse(updateManyTemplate)
if err != nil {
return "", err
}
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
}
return buffer.String(), nil
} }
func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) { func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) {
buffer := new(bytes.Buffer)
querySpec, err := g.mongoQuerySpec(operation.Query) querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil { if err != nil {
return "", err return "", err
@ -191,26 +153,22 @@ func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteO
} }
if operation.Mode == spec.QueryModeOne { if operation.Mode == spec.QueryModeOne {
tmpl, err := template.New("mongo_repository_deleteone").Parse(deleteOneTemplate) return generateFromTemplate("mongo_repository_deleteone", deleteOneTemplate, tmplData)
}
return generateFromTemplate("mongo_repository_deletemany", deleteManyTemplate, tmplData)
}
func (g RepositoryGenerator) generateCountImplementation(operation spec.CountOperation) (string, error) {
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil { if err != nil {
return "", err return "", err
} }
if err := tmpl.Execute(buffer, tmplData); err != nil { tmplData := mongoCountTemplateData{
return "", err QuerySpec: querySpec,
}
} else {
tmpl, err := template.New("mongo_repository_deletemany").Parse(deleteManyTemplate)
if err != nil {
return "", err
} }
if err := tmpl.Execute(buffer, tmplData); err != nil { return generateFromTemplate("mongo_repository_count", countTemplate, tmplData)
return "", err
}
}
return buffer.String(), nil
} }
func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, error) { func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, error) {
@ -252,3 +210,17 @@ func (g RepositoryGenerator) bsonTagFromFieldName(fieldName string) (string, err
func (g RepositoryGenerator) structName() string { func (g RepositoryGenerator) structName() string {
return g.InterfaceName + "Mongo" return g.InterfaceName + "Mongo"
} }
func generateFromTemplate(name string, templateString string, tmplData interface{}) (string, error) {
tmpl, err := template.New(name).Parse(templateString)
if err != nil {
return "", err
}
buffer := new(bytes.Buffer)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
return buffer.String(), nil
}

View file

@ -1089,6 +1089,358 @@ func (r *UserRepositoryMongo) DeleteByGenderIn(arg0 context.Context, arg1 []Gend
} }
} }
func TestGenerateMethod_Count(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "simple count method",
MethodSpec: spec.MethodSpec{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGender(arg0 context.Context, arg1 Gender) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": arg1,
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with And operator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderAndCity",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Operator: spec.OperatorAnd,
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
{Field: "Age", Comparator: spec.ComparatorEqual, ParamIndex: 2},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGenderAndCity(arg0 context.Context, arg1 Gender, arg2 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": arg1,
"age": arg2,
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with Or operator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderOrCity",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Operator: spec.OperatorOr,
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
{Field: "Age", Comparator: spec.ComparatorEqual, ParamIndex: 2},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGenderOrCity(arg0 context.Context, arg1 Gender, arg2 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"$or": []bson.M{
{"gender": arg1},
{"age": arg2},
},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with Not comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderNot",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorNot, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByGenderNot(arg0 context.Context, arg1 Gender) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": bson.M{"$ne": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with LessThan comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeLessThan",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorLessThan, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeLessThan(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$lt": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with LessThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeLessThanEqual",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorLessThanEqual, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeLessThanEqual(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$lte": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with GreaterThan comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeGreaterThan",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorGreaterThan, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeGreaterThan(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$gt": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with GreaterThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeGreaterThanEqual",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorGreaterThanEqual, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeGreaterThanEqual(arg0 context.Context, arg1 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$gte": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with Between comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeBetween",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("int")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorBetween, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeBetween(arg0 context.Context, arg1 int, arg2 int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$gte": arg1, "$lte": arg2},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
{
Name: "count with In comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeIn",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.ArrayType{ContainedType: code.SimpleType("int")}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Age", Comparator: spec.ComparatorIn, ParamIndex: 1},
},
},
},
},
ExpectedCode: `
func (r *UserRepositoryMongo) CountByAgeIn(arg0 context.Context, arg1 []int) (int, error) {
count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{"$in": arg1},
})
if err != nil {
return 0, err
}
return int(count), nil
}
`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
buffer := new(bytes.Buffer)
err := generator.GenerateMethod(testCase.MethodSpec, buffer)
if err != nil {
t.Error(err)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedCode, buffer.String()); err != nil {
t.Error(err)
}
})
}
}
type GenerateMethodInvalidTestCase struct { type GenerateMethodInvalidTestCase struct {
Name string Name string
Method spec.MethodSpec Method spec.MethodSpec

View file

@ -159,3 +159,15 @@ const deleteManyTemplate = ` result, err := r.collection.DeleteMany(arg0, bson.M
return 0, err return 0, err
} }
return int(result.DeletedCount), nil` return int(result.DeletedCount), nil`
type mongoCountTemplateData struct {
QuerySpec querySpec
}
const countTemplate = ` count, err := r.collection.CountDocuments(arg0, bson.M{
{{.QuerySpec.Code}}
})
if err != nil {
return 0, err
}
return int(count), nil`

View file

@ -1,8 +1,6 @@
package spec package spec
import ( import (
"strings"
"github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/code"
) )
@ -57,91 +55,7 @@ type DeleteOperation struct {
Query QuerySpec Query QuerySpec
} }
// QuerySpec is a set of conditions of querying the database // CountOperation is a method specification for count operations
type QuerySpec struct { type CountOperation struct {
Operator Operator Query QuerySpec
Predicates []Predicate
}
// NumberOfArguments returns number of arguments required to perform the query
func (q QuerySpec) NumberOfArguments() int {
var totalArgs int
for _, predicate := range q.Predicates {
totalArgs += predicate.Comparator.NumberOfArguments()
}
return totalArgs
}
// Operator is a boolean operator for merging conditions
type Operator string
// boolean operator types
const (
OperatorAnd Operator = "AND"
OperatorOr Operator = "OR"
)
// Comparator is a comparison operator of the condition to query the data
type Comparator string
// comparator types
const (
ComparatorNot Comparator = "NOT"
ComparatorEqual Comparator = "EQUAL"
ComparatorLessThan Comparator = "LESS_THAN"
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN"
)
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
return code.ArrayType{ContainedType: t}
}
return t
}
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
return 2
}
return 1
}
// Predicate is a criteria for querying a field
type Predicate struct {
Field string
Comparator Comparator
ParamIndex int
}
type predicateToken []string
func (t predicateToken) ToPredicate(paramIndex int) Predicate {
if len(t) > 1 && t[len(t)-1] == "Not" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
} }

View file

@ -31,6 +31,8 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) {
return p.parseUpdateMethod(methodNameTokens[1:]) return p.parseUpdateMethod(methodNameTokens[1:])
case "Delete": case "Delete":
return p.parseDeleteMethod(methodNameTokens[1:]) return p.parseDeleteMethod(methodNameTokens[1:])
case "Count":
return p.parseCountMethod(methodNameTokens[1:])
} }
return MethodSpec{}, UnknownOperationError return MethodSpec{}, UnknownOperationError
} }
@ -55,14 +57,9 @@ func (p interfaceMethodParser) parseInsertMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, InvalidParamError return MethodSpec{}, InvalidParamError
} }
return MethodSpec{ return p.createMethodSpec(InsertOperation{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: InsertOperation{
Mode: mode, Mode: mode,
}, }), nil
}, nil
} }
func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) { func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) {
@ -99,12 +96,12 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err
return MethodSpec{}, UnsupportedNameError return MethodSpec{}, UnsupportedNameError
} }
mode, err := p.extractFindReturns(p.Method.Returns) mode, err := p.extractModelOrSliceReturns(p.Method.Returns)
if err != nil { if err != nil {
return MethodSpec{}, err return MethodSpec{}, err
} }
querySpec, err := p.parseQuery(tokens, 1) querySpec, err := parseQuery(tokens, 1)
if err != nil { if err != nil {
return MethodSpec{}, err return MethodSpec{}, err
} }
@ -117,18 +114,13 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err
return MethodSpec{}, err return MethodSpec{}, err
} }
return MethodSpec{ return p.createMethodSpec(FindOperation{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: FindOperation{
Mode: mode, Mode: mode,
Query: querySpec, Query: querySpec,
}, }), nil
}, nil
} }
func (p interfaceMethodParser) extractFindReturns(returns []code.Type) (QueryMode, error) { func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 { if len(returns) != 2 {
return "", UnsupportedReturnError return "", UnsupportedReturnError
} }
@ -166,7 +158,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, UnsupportedNameError return MethodSpec{}, UnsupportedNameError
} }
mode, err := p.extractCountReturns(p.Method.Returns) mode, err := p.extractIntOrBoolReturns(p.Method.Returns)
if err != nil { if err != nil {
return MethodSpec{}, err return MethodSpec{}, err
} }
@ -193,7 +185,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
} }
fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex}) fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex})
querySpec, err := p.parseQuery(tokens, 1+len(fields)) querySpec, err := parseQuery(tokens, 1+len(fields))
if err != nil { if err != nil {
return MethodSpec{}, err return MethodSpec{}, err
} }
@ -217,16 +209,11 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, err return MethodSpec{}, err
} }
return MethodSpec{ return p.createMethodSpec(UpdateOperation{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: UpdateOperation{
Fields: fields, Fields: fields,
Mode: mode, Mode: mode,
Query: querySpec, Query: querySpec,
}, }), nil
}, nil
} }
func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, error) { func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, error) {
@ -234,12 +221,12 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, UnsupportedNameError return MethodSpec{}, UnsupportedNameError
} }
mode, err := p.extractCountReturns(p.Method.Returns) mode, err := p.extractIntOrBoolReturns(p.Method.Returns)
if err != nil { if err != nil {
return MethodSpec{}, err return MethodSpec{}, err
} }
querySpec, err := p.parseQuery(tokens, 1) querySpec, err := parseQuery(tokens, 1)
if err != nil { if err != nil {
return MethodSpec{}, err return MethodSpec{}, err
} }
@ -252,18 +239,56 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, err return MethodSpec{}, err
} }
return MethodSpec{ return p.createMethodSpec(DeleteOperation{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: DeleteOperation{
Mode: mode, Mode: mode,
Query: querySpec, Query: querySpec,
}, }), nil
}, nil
} }
func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMode, error) { func (p interfaceMethodParser) parseCountMethod(tokens []string) (MethodSpec, error) {
if len(tokens) == 0 {
return MethodSpec{}, UnsupportedNameError
}
if err := p.validateCountReturns(p.Method.Returns); err != nil {
return MethodSpec{}, err
}
querySpec, err := parseQuery(tokens, 1)
if err != nil {
return MethodSpec{}, err
}
if err := p.validateContextParam(); err != nil {
return MethodSpec{}, err
}
if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil {
return MethodSpec{}, err
}
return p.createMethodSpec(CountOperation{
Query: querySpec,
}), nil
}
func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error {
if len(returns) != 2 {
return UnsupportedReturnError
}
if returns[0] != code.SimpleType("int") {
return UnsupportedReturnError
}
if returns[1] != code.SimpleType("error") {
return UnsupportedReturnError
}
return nil
}
func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 { if len(returns) != 2 {
return "", UnsupportedReturnError return "", UnsupportedReturnError
} }
@ -285,58 +310,6 @@ func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMo
return "", UnsupportedReturnError return "", UnsupportedReturnError
} }
func (p interfaceMethodParser) parseQuery(tokens []string, paramIndex int) (QuerySpec, error) {
if len(tokens) == 0 {
return QuerySpec{}, InvalidQueryError
}
if len(tokens) == 1 && tokens[0] == "All" {
return QuerySpec{}, nil
}
if tokens[0] == "One" {
tokens = tokens[1:]
}
if tokens[0] == "By" {
tokens = tokens[1:]
}
if tokens[0] == "And" || tokens[0] == "Or" {
return QuerySpec{}, InvalidQueryError
}
var operator Operator
var predicates []Predicate
var aggregatedToken predicateToken
for _, token := range tokens {
if token != "And" && token != "Or" {
aggregatedToken = append(aggregatedToken, token)
} else if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
} else if token == "And" && operator != OperatorOr {
operator = OperatorAnd
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else if token == "Or" && operator != OperatorAnd {
operator = OperatorOr
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else {
return QuerySpec{}, InvalidQueryError
}
}
if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
}
predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex))
return QuerySpec{Operator: operator, Predicates: predicates}, nil
}
func (p interfaceMethodParser) validateContextParam() error { func (p interfaceMethodParser) validateContextParam() error {
contextType := code.ExternalType{PackageAlias: "context", Name: "Context"} contextType := code.ExternalType{PackageAlias: "context", Name: "Context"}
if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType { if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType {
@ -368,3 +341,12 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
return nil return nil
} }
func (p interfaceMethodParser) createMethodSpec(operation Operation) MethodSpec {
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: operation,
}
}

View file

@ -778,6 +778,67 @@ func TestParseInterfaceMethod_Delete(t *testing.T) {
} }
} }
func TestParseInterfaceMethod_Count(t *testing.T) {
testTable := []ParseInterfaceMethodTestCase{
{
Name: "CountAll method",
Method: code.Method{
Name: "CountAll",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedOperation: spec.CountOperation{
Query: spec.QuerySpec{},
},
},
{
Name: "CountByArg method",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedOperation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
},
},
},
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
actualSpec, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != nil {
t.Errorf("Error = %s", err)
}
expectedOutput := spec.MethodSpec{
Name: testCase.Method.Name,
Params: testCase.Method.Params,
Returns: testCase.Method.Returns,
Operation: testCase.ExpectedOperation,
}
if !reflect.DeepEqual(actualSpec, expectedOutput) {
t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec)
}
})
}
}
type ParseInterfaceMethodInvalidTestCase struct { type ParseInterfaceMethodInvalidTestCase struct {
Name string Name string
Method code.Method Method code.Method
@ -1403,3 +1464,142 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) {
}) })
} }
} }
func TestParseInterfaceMethod_Count_Invalid(t *testing.T) {
testTable := []ParseInterfaceMethodInvalidTestCase{
{
Name: "unsupported count method name",
Method: code.Method{
Name: "Count",
},
ExpectedError: spec.UnsupportedNameError,
},
{
Name: "invalid number of returns",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid number of returns",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid integer return",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int64"),
code.SimpleType("error"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "error return not provided",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid query",
Method: code.Method{
Name: "CountBy",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidQueryError,
},
{
Name: "context parameter not provided",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.ContextParamRequiredError,
},
{
Name: "mismatched number of parameter",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
{
Name: "mismatched method parameter type",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
{
Name: "struct field not found",
Method: code.Method{
Name: "CountByCountry",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.StructFieldNotFoundError,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
_, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != testCase.ExpectedError {
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err)
}
})
}
}

148
internal/spec/query.go Normal file
View file

@ -0,0 +1,148 @@
package spec
import (
"strings"
"github.com/sunboyy/repogen/internal/code"
)
// QuerySpec is a set of conditions of querying the database
type QuerySpec struct {
Operator Operator
Predicates []Predicate
}
// NumberOfArguments returns number of arguments required to perform the query
func (q QuerySpec) NumberOfArguments() int {
var totalArgs int
for _, predicate := range q.Predicates {
totalArgs += predicate.Comparator.NumberOfArguments()
}
return totalArgs
}
// Operator is a boolean operator for merging conditions
type Operator string
// boolean operator types
const (
OperatorAnd Operator = "AND"
OperatorOr Operator = "OR"
)
// Comparator is a comparison operator of the condition to query the data
type Comparator string
// comparator types
const (
ComparatorNot Comparator = "NOT"
ComparatorEqual Comparator = "EQUAL"
ComparatorLessThan Comparator = "LESS_THAN"
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN"
)
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
return code.ArrayType{ContainedType: t}
}
return t
}
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
return 2
}
return 1
}
// Predicate is a criteria for querying a field
type Predicate struct {
Field string
Comparator Comparator
ParamIndex int
}
type predicateToken []string
func (t predicateToken) ToPredicate(paramIndex int) Predicate {
if len(t) > 1 && t[len(t)-1] == "Not" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
}
func parseQuery(tokens []string, paramIndex int) (QuerySpec, error) {
if len(tokens) == 0 {
return QuerySpec{}, InvalidQueryError
}
if len(tokens) == 1 && tokens[0] == "All" {
return QuerySpec{}, nil
}
if tokens[0] == "One" {
tokens = tokens[1:]
}
if tokens[0] == "By" {
tokens = tokens[1:]
}
if len(tokens) == 0 || tokens[0] == "And" || tokens[0] == "Or" {
return QuerySpec{}, InvalidQueryError
}
var operator Operator
var predicates []Predicate
var aggregatedToken predicateToken
for _, token := range tokens {
if token != "And" && token != "Or" {
aggregatedToken = append(aggregatedToken, token)
} else if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
} else if token == "And" && operator != OperatorOr {
operator = OperatorAnd
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else if token == "Or" && operator != OperatorAnd {
operator = OperatorOr
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else {
return QuerySpec{}, InvalidQueryError
}
}
if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
}
predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex))
return QuerySpec{Operator: operator, Predicates: predicates}, nil
}