Add count operation

This commit is contained in:
sunboyy 2021-02-06 18:05:47 +07:00
parent c752849518
commit 496f9541cb
12 changed files with 858 additions and 261 deletions

View file

@ -1,8 +1,6 @@
package spec
import (
"strings"
"github.com/sunboyy/repogen/internal/code"
)
@ -57,91 +55,7 @@ type DeleteOperation struct {
Query QuerySpec
}
// QuerySpec is a set of conditions of querying the database
type QuerySpec struct {
Operator Operator
Predicates []Predicate
}
// NumberOfArguments returns number of arguments required to perform the query
func (q QuerySpec) NumberOfArguments() int {
var totalArgs int
for _, predicate := range q.Predicates {
totalArgs += predicate.Comparator.NumberOfArguments()
}
return totalArgs
}
// Operator is a boolean operator for merging conditions
type Operator string
// boolean operator types
const (
OperatorAnd Operator = "AND"
OperatorOr Operator = "OR"
)
// Comparator is a comparison operator of the condition to query the data
type Comparator string
// comparator types
const (
ComparatorNot Comparator = "NOT"
ComparatorEqual Comparator = "EQUAL"
ComparatorLessThan Comparator = "LESS_THAN"
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN"
)
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
return code.ArrayType{ContainedType: t}
}
return t
}
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
return 2
}
return 1
}
// Predicate is a criteria for querying a field
type Predicate struct {
Field string
Comparator Comparator
ParamIndex int
}
type predicateToken []string
func (t predicateToken) ToPredicate(paramIndex int) Predicate {
if len(t) > 1 && t[len(t)-1] == "Not" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
// CountOperation is a method specification for count operations
type CountOperation struct {
Query QuerySpec
}

View file

@ -31,6 +31,8 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) {
return p.parseUpdateMethod(methodNameTokens[1:])
case "Delete":
return p.parseDeleteMethod(methodNameTokens[1:])
case "Count":
return p.parseCountMethod(methodNameTokens[1:])
}
return MethodSpec{}, UnknownOperationError
}
@ -55,14 +57,9 @@ func (p interfaceMethodParser) parseInsertMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, InvalidParamError
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: InsertOperation{
Mode: mode,
},
}, nil
return p.createMethodSpec(InsertOperation{
Mode: mode,
}), nil
}
func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) {
@ -99,12 +96,12 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err
return MethodSpec{}, UnsupportedNameError
}
mode, err := p.extractFindReturns(p.Method.Returns)
mode, err := p.extractModelOrSliceReturns(p.Method.Returns)
if err != nil {
return MethodSpec{}, err
}
querySpec, err := p.parseQuery(tokens, 1)
querySpec, err := parseQuery(tokens, 1)
if err != nil {
return MethodSpec{}, err
}
@ -117,18 +114,13 @@ func (p interfaceMethodParser) parseFindMethod(tokens []string) (MethodSpec, err
return MethodSpec{}, err
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: FindOperation{
Mode: mode,
Query: querySpec,
},
}, nil
return p.createMethodSpec(FindOperation{
Mode: mode,
Query: querySpec,
}), nil
}
func (p interfaceMethodParser) extractFindReturns(returns []code.Type) (QueryMode, error) {
func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 {
return "", UnsupportedReturnError
}
@ -166,7 +158,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, UnsupportedNameError
}
mode, err := p.extractCountReturns(p.Method.Returns)
mode, err := p.extractIntOrBoolReturns(p.Method.Returns)
if err != nil {
return MethodSpec{}, err
}
@ -193,7 +185,7 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
}
fields = append(fields, UpdateField{Name: aggregatedToken, ParamIndex: paramIndex})
querySpec, err := p.parseQuery(tokens, 1+len(fields))
querySpec, err := parseQuery(tokens, 1+len(fields))
if err != nil {
return MethodSpec{}, err
}
@ -217,16 +209,11 @@ func (p interfaceMethodParser) parseUpdateMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, err
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: UpdateOperation{
Fields: fields,
Mode: mode,
Query: querySpec,
},
}, nil
return p.createMethodSpec(UpdateOperation{
Fields: fields,
Mode: mode,
Query: querySpec,
}), nil
}
func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, error) {
@ -234,12 +221,12 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, UnsupportedNameError
}
mode, err := p.extractCountReturns(p.Method.Returns)
mode, err := p.extractIntOrBoolReturns(p.Method.Returns)
if err != nil {
return MethodSpec{}, err
}
querySpec, err := p.parseQuery(tokens, 1)
querySpec, err := parseQuery(tokens, 1)
if err != nil {
return MethodSpec{}, err
}
@ -252,18 +239,56 @@ func (p interfaceMethodParser) parseDeleteMethod(tokens []string) (MethodSpec, e
return MethodSpec{}, err
}
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: DeleteOperation{
Mode: mode,
Query: querySpec,
},
}, nil
return p.createMethodSpec(DeleteOperation{
Mode: mode,
Query: querySpec,
}), nil
}
func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMode, error) {
func (p interfaceMethodParser) parseCountMethod(tokens []string) (MethodSpec, error) {
if len(tokens) == 0 {
return MethodSpec{}, UnsupportedNameError
}
if err := p.validateCountReturns(p.Method.Returns); err != nil {
return MethodSpec{}, err
}
querySpec, err := parseQuery(tokens, 1)
if err != nil {
return MethodSpec{}, err
}
if err := p.validateContextParam(); err != nil {
return MethodSpec{}, err
}
if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil {
return MethodSpec{}, err
}
return p.createMethodSpec(CountOperation{
Query: querySpec,
}), nil
}
func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error {
if len(returns) != 2 {
return UnsupportedReturnError
}
if returns[0] != code.SimpleType("int") {
return UnsupportedReturnError
}
if returns[1] != code.SimpleType("error") {
return UnsupportedReturnError
}
return nil
}
func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (QueryMode, error) {
if len(returns) != 2 {
return "", UnsupportedReturnError
}
@ -285,58 +310,6 @@ func (p interfaceMethodParser) extractCountReturns(returns []code.Type) (QueryMo
return "", UnsupportedReturnError
}
func (p interfaceMethodParser) parseQuery(tokens []string, paramIndex int) (QuerySpec, error) {
if len(tokens) == 0 {
return QuerySpec{}, InvalidQueryError
}
if len(tokens) == 1 && tokens[0] == "All" {
return QuerySpec{}, nil
}
if tokens[0] == "One" {
tokens = tokens[1:]
}
if tokens[0] == "By" {
tokens = tokens[1:]
}
if tokens[0] == "And" || tokens[0] == "Or" {
return QuerySpec{}, InvalidQueryError
}
var operator Operator
var predicates []Predicate
var aggregatedToken predicateToken
for _, token := range tokens {
if token != "And" && token != "Or" {
aggregatedToken = append(aggregatedToken, token)
} else if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
} else if token == "And" && operator != OperatorOr {
operator = OperatorAnd
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else if token == "Or" && operator != OperatorAnd {
operator = OperatorOr
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else {
return QuerySpec{}, InvalidQueryError
}
}
if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
}
predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex))
return QuerySpec{Operator: operator, Predicates: predicates}, nil
}
func (p interfaceMethodParser) validateContextParam() error {
contextType := code.ExternalType{PackageAlias: "context", Name: "Context"}
if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType {
@ -368,3 +341,12 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer
return nil
}
func (p interfaceMethodParser) createMethodSpec(operation Operation) MethodSpec {
return MethodSpec{
Name: p.Method.Name,
Params: p.Method.Params,
Returns: p.Method.Returns,
Operation: operation,
}
}

View file

@ -778,6 +778,67 @@ func TestParseInterfaceMethod_Delete(t *testing.T) {
}
}
func TestParseInterfaceMethod_Count(t *testing.T) {
testTable := []ParseInterfaceMethodTestCase{
{
Name: "CountAll method",
Method: code.Method{
Name: "CountAll",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedOperation: spec.CountOperation{
Query: spec.QuerySpec{},
},
},
{
Name: "CountByArg method",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedOperation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{Field: "Gender", Comparator: spec.ComparatorEqual, ParamIndex: 1},
},
},
},
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
actualSpec, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != nil {
t.Errorf("Error = %s", err)
}
expectedOutput := spec.MethodSpec{
Name: testCase.Method.Name,
Params: testCase.Method.Params,
Returns: testCase.Method.Returns,
Operation: testCase.ExpectedOperation,
}
if !reflect.DeepEqual(actualSpec, expectedOutput) {
t.Errorf("Expected = %v\nReceived = %v", expectedOutput, actualSpec)
}
})
}
}
type ParseInterfaceMethodInvalidTestCase struct {
Name string
Method code.Method
@ -1403,3 +1464,142 @@ func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) {
})
}
}
func TestParseInterfaceMethod_Count_Invalid(t *testing.T) {
testTable := []ParseInterfaceMethodInvalidTestCase{
{
Name: "unsupported count method name",
Method: code.Method{
Name: "Count",
},
ExpectedError: spec.UnsupportedNameError,
},
{
Name: "invalid number of returns",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid number of returns",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid integer return",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int64"),
code.SimpleType("error"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "error return not provided",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("bool"),
},
},
ExpectedError: spec.UnsupportedReturnError,
},
{
Name: "invalid query",
Method: code.Method{
Name: "CountBy",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidQueryError,
},
{
Name: "context parameter not provided",
Method: code.Method{
Name: "CountAll",
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.ContextParamRequiredError,
},
{
Name: "mismatched number of parameter",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.SimpleType("int")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
{
Name: "mismatched method parameter type",
Method: code.Method{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.InvalidParamError,
},
{
Name: "struct field not found",
Method: code.Method{
Name: "CountByCountry",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("string")},
},
Returns: []code.Type{
code.SimpleType("int"),
code.SimpleType("error"),
},
},
ExpectedError: spec.StructFieldNotFoundError,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
_, err := spec.ParseInterfaceMethod(structModel, testCase.Method)
if err != testCase.ExpectedError {
t.Errorf("\nExpected = %v\nReceived = %v", testCase.ExpectedError, err)
}
})
}
}

148
internal/spec/query.go Normal file
View file

@ -0,0 +1,148 @@
package spec
import (
"strings"
"github.com/sunboyy/repogen/internal/code"
)
// QuerySpec is a set of conditions of querying the database
type QuerySpec struct {
Operator Operator
Predicates []Predicate
}
// NumberOfArguments returns number of arguments required to perform the query
func (q QuerySpec) NumberOfArguments() int {
var totalArgs int
for _, predicate := range q.Predicates {
totalArgs += predicate.Comparator.NumberOfArguments()
}
return totalArgs
}
// Operator is a boolean operator for merging conditions
type Operator string
// boolean operator types
const (
OperatorAnd Operator = "AND"
OperatorOr Operator = "OR"
)
// Comparator is a comparison operator of the condition to query the data
type Comparator string
// comparator types
const (
ComparatorNot Comparator = "NOT"
ComparatorEqual Comparator = "EQUAL"
ComparatorLessThan Comparator = "LESS_THAN"
ComparatorLessThanEqual Comparator = "LESS_THAN_EQUAL"
ComparatorGreaterThan Comparator = "GREATER_THAN"
ComparatorGreaterThanEqual Comparator = "GREATER_THAN_EQUAL"
ComparatorBetween Comparator = "BETWEEN"
ComparatorIn Comparator = "IN"
)
// ArgumentTypeFromFieldType returns a type of required argument from the given struct field type
func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type {
if c == ComparatorIn {
return code.ArrayType{ContainedType: t}
}
return t
}
// NumberOfArguments returns the number of arguments required to perform the comparison
func (c Comparator) NumberOfArguments() int {
if c == ComparatorBetween {
return 2
}
return 1
}
// Predicate is a criteria for querying a field
type Predicate struct {
Field string
Comparator Comparator
ParamIndex int
}
type predicateToken []string
func (t predicateToken) ToPredicate(paramIndex int) Predicate {
if len(t) > 1 && t[len(t)-1] == "Not" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorNot, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Less" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorLessThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Less" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorLessThanEqual, ParamIndex: paramIndex}
}
if len(t) > 2 && t[len(t)-2] == "Greater" && t[len(t)-1] == "Than" {
return Predicate{Field: strings.Join(t[:len(t)-2], ""), Comparator: ComparatorGreaterThan, ParamIndex: paramIndex}
}
if len(t) > 3 && t[len(t)-3] == "Greater" && t[len(t)-2] == "Than" && t[len(t)-1] == "Equal" {
return Predicate{Field: strings.Join(t[:len(t)-3], ""), Comparator: ComparatorGreaterThanEqual, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "In" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorIn, ParamIndex: paramIndex}
}
if len(t) > 1 && t[len(t)-1] == "Between" {
return Predicate{Field: strings.Join(t[:len(t)-1], ""), Comparator: ComparatorBetween, ParamIndex: paramIndex}
}
return Predicate{Field: strings.Join(t, ""), Comparator: ComparatorEqual, ParamIndex: paramIndex}
}
func parseQuery(tokens []string, paramIndex int) (QuerySpec, error) {
if len(tokens) == 0 {
return QuerySpec{}, InvalidQueryError
}
if len(tokens) == 1 && tokens[0] == "All" {
return QuerySpec{}, nil
}
if tokens[0] == "One" {
tokens = tokens[1:]
}
if tokens[0] == "By" {
tokens = tokens[1:]
}
if len(tokens) == 0 || tokens[0] == "And" || tokens[0] == "Or" {
return QuerySpec{}, InvalidQueryError
}
var operator Operator
var predicates []Predicate
var aggregatedToken predicateToken
for _, token := range tokens {
if token != "And" && token != "Or" {
aggregatedToken = append(aggregatedToken, token)
} else if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
} else if token == "And" && operator != OperatorOr {
operator = OperatorAnd
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else if token == "Or" && operator != OperatorAnd {
operator = OperatorOr
predicate := aggregatedToken.ToPredicate(paramIndex)
predicates = append(predicates, predicate)
paramIndex += predicate.Comparator.NumberOfArguments()
aggregatedToken = predicateToken{}
} else {
return QuerySpec{}, InvalidQueryError
}
}
if len(aggregatedToken) == 0 {
return QuerySpec{}, InvalidQueryError
}
predicates = append(predicates, aggregatedToken.ToPredicate(paramIndex))
return QuerySpec{Operator: operator, Predicates: predicates}, nil
}