diff --git a/.golangci.yml b/.golangci.yml index f481709..45af5a6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,6 +2,7 @@ linters: enable: - errname - errorlint + - goerr113 - lll - stylecheck linters-settings: diff --git a/examples/getting-started/main.go b/examples/getting-started/main.go index f892957..7e3d0ef 100644 --- a/examples/getting-started/main.go +++ b/examples/getting-started/main.go @@ -12,7 +12,7 @@ import ( // Replace these values with your own connection option. This connection option is hard-coded for easy // demonstration. Make sure not to hard-code the credentials in the production code. const ( - connectionString = "mongodb://lineman:lineman@localhost:27017" + connectionString = "mongodb://admin:password@localhost:27017" databaseName = "repogen_examples" collectionName = "gettingstarted_user" ) diff --git a/examples/getting-started/user.go b/examples/getting-started/user.go index 18aa0da..a514e9e 100644 --- a/examples/getting-started/user.go +++ b/examples/getting-started/user.go @@ -16,27 +16,32 @@ type UserModel struct { //go:generate repogen -src=user.go -dest=user_repo.go -model=UserModel -repo=UserRepository -// UserRepository is an interface that describes the specification of querying user data in the database +// UserRepository is an interface that describes the specification of querying +// user data in the database. type UserRepository interface { - // InsertOne stores userModel into the database and returns inserted ID if insertion - // succeeds and returns error if insertion fails. + // InsertOne stores userModel into the database and returns inserted ID + // if insertion succeeds and returns error if insertion fails. InsertOne(ctx context.Context, userModel *UserModel) (interface{}, error) - // FindByUsername queries user by username. If a user with specified username exists, - // the user will be returned. Otherwise, error will be returned. + // FindByUsername queries user by username. If a user with specified + // username exists, the user will be returned. Otherwise, error will be + // returned. FindByUsername(ctx context.Context, username string) (*UserModel, error) - // UpdateDisplayNameByID updates a user with the specified ID with a new display name. - // If there is a user matches the query, it will return true. Error will be returned - // only when error occurs while accessing the database. + // UpdateDisplayNameByID updates a user with the specified ID with a new + // display name. If there is a user matches the query, it will return + // true. Error will be returned only when error occurs while accessing + // the database. UpdateDisplayNameByID(ctx context.Context, displayName string, id primitive.ObjectID) (bool, error) - // DeleteByCity deletes users that have `city` value match the parameter and returns - // the match count. The error will be returned only when error occurs while accessing - // the database. This is a MANY mode because the first return type is an integer. + // DeleteByCity deletes users that have `city` value match the parameter + // and returns the match count. The error will be returned only when + // error occurs while accessing the database. This is a MANY mode + // because the first return type is an integer. DeleteByCity(ctx context.Context, city string) (int, error) - // CountByCity returns the number of rows that match the given city parameter. If an - // error occurs while accessing the database, error value will be returned. + // CountByCity returns the number of rows that match the given city + // parameter. If an error occurs while accessing the database, error + // value will be returned. CountByCity(ctx context.Context, city string) (int, error) } diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index cc0323c..cdee726 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -34,6 +34,11 @@ var ( Type: code.SimpleType("NameModel"), Tags: map[string][]string{"bson": {"name"}}, } + referrerField = code.StructField{ + Name: "Referrer", + Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}, + Tags: map[string][]string{"bson": {"referrer"}}, + } consentHistoryField = code.StructField{ Name: "ConsentHistory", Type: code.ArrayType{ContainedType: code.SimpleType("ConsentHistory")}, @@ -68,6 +73,7 @@ var userModel = code.Struct{ genderField, ageField, nameField, + referrerField, consentHistoryField, enabledField, accessTokenField, @@ -882,6 +888,80 @@ func TestGenerateMethod_Find(t *testing.T) { if err := cursor.All(arg0, &entities); err != nil { return nil, err } + return entities, nil`, + }, + { + Name: "find with Exists comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByReferrerExists", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorExists, + FieldReference: spec.FieldReference{referrerField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "referrer": bson.M{"$exists": 1}, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with NotExists comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByReferrerNotExists", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorNotExists, + FieldReference: spec.FieldReference{referrerField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "referrer": bson.M{"$exists": 0}, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } return entities, nil`, }, { diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 1f12a58..5d5ee75 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -109,6 +109,10 @@ func (p predicate) Code() string { return fmt.Sprintf(`"%s": true`, p.Field) case spec.ComparatorFalse: return fmt.Sprintf(`"%s": false`, p.Field) + case spec.ComparatorExists: + return fmt.Sprintf(`"%s": bson.M{"$exists": 1}`, p.Field) + case spec.ComparatorNotExists: + return fmt.Sprintf(`"%s": bson.M{"$exists": 0}`, p.Field) } return "" } diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index e0e8756..4c9d7f4 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -583,6 +583,52 @@ func TestParseInterfaceMethod_Find(t *testing.T) { }}, }, }, + { + Name: "FindByArgExists method", + Method: code.Method{ + Name: "FindByReferrerExists", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{referrerField}, + Comparator: spec.ComparatorExists, + ParamIndex: 1, + }, + }}, + }, + }, + { + Name: "FindByArgNotExists method", + Method: code.Method{ + Name: "FindByReferrerNotExists", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + }, + ExpectedOperation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{referrerField}, + Comparator: spec.ComparatorNotExists, + ParamIndex: 1, + }, + }}, + }, + }, { Name: "FindByArgOrderByArg method", Method: code.Method{ diff --git a/internal/spec/query.go b/internal/spec/query.go index d187bcc..236b6b5 100644 --- a/internal/spec/query.go +++ b/internal/spec/query.go @@ -44,6 +44,8 @@ const ( ComparatorNotIn Comparator = "NOT_IN" ComparatorTrue Comparator = "EQUAL_TRUE" ComparatorFalse Comparator = "EQUAL_FALSE" + ComparatorExists Comparator = "EXISTS" + ComparatorNotExists Comparator = "NOT_EXISTS" ) // ArgumentTypeFromFieldType returns a type of required argument from the given @@ -63,7 +65,7 @@ func (c Comparator) NumberOfArguments() int { switch c { case ComparatorBetween: return 2 - case ComparatorTrue, ComparatorFalse: + case ComparatorTrue, ComparatorFalse, ComparatorExists, ComparatorNotExists: return 0 default: return 1 @@ -82,7 +84,9 @@ type queryParser struct { StructModel code.Struct } -func (p queryParser) parseQuery(rawTokens []string, paramIndex int) (QuerySpec, error) { +func (p queryParser) parseQuery(rawTokens []string, paramIndex int) (QuerySpec, + error) { + if len(rawTokens) == 0 { return QuerySpec{}, ErrQueryRequired } @@ -154,7 +158,9 @@ func (p queryParser) splitPredicateTokens(tokens []string) (Operator, [][]string return operator, predicateTokens, nil } -func (p queryParser) parsePredicate(t []string, paramIndex int) (Predicate, error) { +func (p queryParser) parsePredicate(t []string, paramIndex int) (Predicate, + error) { + if len(t) > 1 && t[len(t)-1] == "Not" { return p.createPredicate(t[:len(t)-1], ComparatorNot, paramIndex) } @@ -173,6 +179,9 @@ func (p queryParser) parsePredicate(t []string, paramIndex int) (Predicate, erro if len(t) > 2 && t[len(t)-2] == "Not" && t[len(t)-1] == "In" { return p.createPredicate(t[:len(t)-2], ComparatorNotIn, paramIndex) } + if len(t) > 2 && t[len(t)-2] == "Not" && t[len(t)-1] == "Exists" { + return p.createPredicate(t[:len(t)-2], ComparatorNotExists, paramIndex) + } if len(t) > 1 && t[len(t)-1] == "In" { return p.createPredicate(t[:len(t)-1], ComparatorIn, paramIndex) } @@ -185,10 +194,15 @@ func (p queryParser) parsePredicate(t []string, paramIndex int) (Predicate, erro if len(t) > 1 && t[len(t)-1] == "False" { return p.createPredicate(t[:len(t)-1], ComparatorFalse, paramIndex) } + if len(t) > 1 && t[len(t)-1] == "Exists" { + return p.createPredicate(t[:len(t)-1], ComparatorExists, paramIndex) + } return p.createPredicate(t, ComparatorEqual, paramIndex) } -func (p queryParser) createPredicate(t []string, comparator Comparator, paramIndex int) (Predicate, error) { +func (p queryParser) createPredicate(t []string, comparator Comparator, + paramIndex int) (Predicate, error) { + fields, ok := p.fieldResolver.ResolveStructField(p.StructModel, t) if !ok { return Predicate{}, NewStructFieldNotFoundError(t) diff --git a/internal/testutils/multilines.go b/internal/testutils/multilines.go index 35c5c7b..8b9b505 100644 --- a/internal/testutils/multilines.go +++ b/internal/testutils/multilines.go @@ -5,6 +5,60 @@ import ( "strings" ) +// lineMismatchedError is an error indicating unmatched line when comparing +// string using ExpectMultiLineString. +type lineMismatchedError struct { + LineNumber int + Expected string + Received string +} + +// NewLineMismatchedError is a constructor for lineMismatchedError. +func NewLineMismatchedError(lineNumber int, expected, received string) error { + return lineMismatchedError{ + LineNumber: lineNumber, + Expected: expected, + Received: received, + } +} + +func (err lineMismatchedError) Error() string { + return fmt.Sprintf("at line %d\nexpected: %v\nreceived: %v", err.LineNumber, + err.Expected, err.Received) +} + +type missingLinesError struct { + MissingLines []string +} + +// NewMissingLinesError is a constructor for missingLinesError. +func NewMissingLinesError(missingLines []string) error { + return missingLinesError{ + MissingLines: missingLines, + } +} + +func (err missingLinesError) Error() string { + return fmt.Sprintf("missing lines:\n%s", + strings.Join(err.MissingLines, "\n")) +} + +type unexpectedLinesError struct { + UnexpectedLines []string +} + +// NewUnexpectedLinesError is a constructor for unexpectedLinesError. +func NewUnexpectedLinesError(unexpectedLines []string) error { + return unexpectedLinesError{ + UnexpectedLines: unexpectedLines, + } +} + +func (err unexpectedLinesError) Error() string { + return fmt.Sprintf("unexpected lines:\n%s", + strings.Join(err.UnexpectedLines, "\n")) +} + // ExpectMultiLineString compares two multi-line strings and report the // difference. func ExpectMultiLineString(expected, actual string) error { @@ -18,14 +72,14 @@ func ExpectMultiLineString(expected, actual string) error { for i := 0; i < numberOfComparableLines; i++ { if expectedLines[i] != actualLines[i] { - return fmt.Errorf("at line %d\nexpected: %v\nreceived: %v", i+1, expectedLines[i], actualLines[i]) + return NewLineMismatchedError(i+1, expectedLines[i], actualLines[i]) } } if len(expectedLines) < len(actualLines) { - return fmt.Errorf("unexpected lines:\n%s", strings.Join(actualLines[len(expectedLines):], "\n")) + return NewUnexpectedLinesError(actualLines[len(expectedLines):]) } else if len(expectedLines) > len(actualLines) { - return fmt.Errorf("missing lines:\n%s", strings.Join(expectedLines[len(actualLines):], "\n")) + return NewMissingLinesError(expectedLines[len(actualLines):]) } return nil diff --git a/main.go b/main.go index fe4202e..dd0bd9e 100644 --- a/main.go +++ b/main.go @@ -96,15 +96,20 @@ func generateFromRequest(pkgDir, structModelName, repositoryInterfaceName string return generateRepository(pkg, structModelName, repositoryInterfaceName) } +var ( + errStructNotFound = errors.New("struct not found") + errInterfaceNotFound = errors.New("interface not found") +) + func generateRepository(pkg code.Package, structModelName, repositoryInterfaceName string) (string, error) { structModel, ok := pkg.Structs[structModelName] if !ok { - return "", errors.New("struct model not found") + return "", errStructNotFound } intf, ok := pkg.Interfaces[repositoryInterfaceName] if !ok { - return "", errors.New("interface model not found") + return "", errInterfaceNotFound } var methodSpecs []spec.MethodSpec