Move function body generation into codegen package (#35)

* Move function body generation into codegen package
This commit is contained in:
sunboyy 2023-04-18 20:21:46 +07:00 committed by GitHub
parent b00c2ac77a
commit bdb63c8129
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 3795 additions and 2642 deletions

264
internal/codegen/body.go Normal file
View file

@ -0,0 +1,264 @@
package codegen
import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
)
type FunctionBody []Statement
func (b FunctionBody) Code() string {
var lines []string
for _, statement := range b {
stmtLines := statement.CodeLines()
for _, line := range stmtLines {
lines = append(lines, fmt.Sprintf("\t%s", line))
}
}
return strings.Join(lines, "\n")
}
type Statement interface {
CodeLines() []string
}
type RawStatement string
func (stmt RawStatement) CodeLines() []string {
return []string{string(stmt)}
}
type Identifier string
func (id Identifier) CodeLines() []string {
return []string{string(id)}
}
type DeclStatement struct {
Name string
Type code.Type
}
func (stmt DeclStatement) CodeLines() []string {
return []string{fmt.Sprintf("var %s %s", stmt.Name, stmt.Type.Code())}
}
type DeclAssignStatement struct {
Vars []string
Values StatementList
}
func (stmt DeclAssignStatement) CodeLines() []string {
vars := strings.Join(stmt.Vars, ", ")
lines := stmt.Values.CodeLines()
lines[0] = fmt.Sprintf("%s := %s", vars, lines[0])
return lines
}
type AssignStatement struct {
Vars []string
Values StatementList
}
func (stmt AssignStatement) CodeLines() []string {
vars := strings.Join(stmt.Vars, ", ")
lines := stmt.Values.CodeLines()
lines[0] = fmt.Sprintf("%s = %s", vars, lines[0])
return lines
}
type StatementList []Statement
func (l StatementList) CodeLines() []string {
if len(l) == 0 {
return []string{""}
}
return concatenateStatements(", ", []Statement(l))
}
type ReturnStatement StatementList
func (stmt ReturnStatement) CodeLines() []string {
lines := StatementList(stmt).CodeLines()
lines[0] = fmt.Sprintf("return %s", lines[0])
return lines
}
type ChainStatement []Statement
func (stmt ChainStatement) CodeLines() []string {
return concatenateStatements(".", []Statement(stmt))
}
type CallStatement struct {
FuncName string
Params StatementList
}
func (stmt CallStatement) CodeLines() []string {
lines := stmt.Params.CodeLines()
lines[0] = fmt.Sprintf("%s(%s", stmt.FuncName, lines[0])
lines[len(lines)-1] += ")"
return lines
}
type SliceStatement struct {
Type code.Type
Values []Statement
}
func (stmt SliceStatement) CodeLines() []string {
lines := []string{stmt.Type.Code() + "{"}
for _, value := range stmt.Values {
stmtLines := value.CodeLines()
stmtLines[len(stmtLines)-1] += ","
for _, line := range stmtLines {
lines = append(lines, "\t"+line)
}
}
lines = append(lines, "}")
return lines
}
type MapStatement struct {
Type string
Pairs []MapPair
}
func (stmt MapStatement) CodeLines() []string {
return generateCollectionCodeLines(stmt.Type, stmt.Pairs)
}
type MapPair struct {
Key string
Value Statement
}
func (p MapPair) ItemCodeLines() []string {
lines := p.Value.CodeLines()
lines[0] = fmt.Sprintf(`"%s": %s`, p.Key, lines[0])
return lines
}
type StructStatement struct {
Type string
Pairs []StructFieldPair
}
func (stmt StructStatement) CodeLines() []string {
return generateCollectionCodeLines(stmt.Type, stmt.Pairs)
}
type StructFieldPair struct {
Key string
Value Statement
}
func (p StructFieldPair) ItemCodeLines() []string {
lines := p.Value.CodeLines()
lines[0] = fmt.Sprintf(`%s: %s`, p.Key, lines[0])
return lines
}
type collectionItem interface {
ItemCodeLines() []string
}
func generateCollectionCodeLines[T collectionItem](typ string, pairs []T) []string {
lines := []string{fmt.Sprintf("%s{", typ)}
for _, pair := range pairs {
pairLines := pair.ItemCodeLines()
pairLines[len(pairLines)-1] += ","
for _, line := range pairLines {
lines = append(lines, fmt.Sprintf("\t%s", line))
}
}
lines = append(lines, "}")
return lines
}
type RawBlock struct {
Header []string
Statements []Statement
}
func (b RawBlock) CodeLines() []string {
lines := make([]string, len(b.Header))
copy(lines, b.Header)
lines[len(lines)-1] += " {"
for _, stmt := range b.Statements {
stmtLines := stmt.CodeLines()
for _, line := range stmtLines {
lines = append(lines, fmt.Sprintf("\t%s", line))
}
}
lines = append(lines, "}")
return lines
}
type IfBlock struct {
Condition []Statement
Statements []Statement
}
func (b IfBlock) CodeLines() []string {
conditionCode := concatenateStatements("; ", b.Condition)
conditionCode[0] = "if " + conditionCode[0]
return RawBlock{
Header: conditionCode,
Statements: b.Statements,
}.CodeLines()
}
func concatenateStatements(sep string, statements []Statement) []string {
var lines []string
lastLine := ""
for _, stmt := range statements {
stmtLines := stmt.CodeLines()
if lastLine != "" {
lastLine += sep
}
lastLine += stmtLines[0]
if len(stmtLines) > 1 {
lines = append(lines, lastLine)
lines = append(lines, stmtLines[1:len(stmtLines)-1]...)
lastLine = stmtLines[len(stmtLines)-1]
}
}
if lastLine != "" {
lines = append(lines, lastLine)
}
return lines
}
type ChainBuilder []Statement
func NewChainBuilder(object string) ChainBuilder {
return ChainBuilder{
Identifier(object),
}
}
func (b ChainBuilder) Chain(field string) ChainBuilder {
b = append(b, Identifier(field))
return b
}
func (b ChainBuilder) Call(method string, params ...Statement) ChainBuilder {
b = append(b, CallStatement{
FuncName: method,
Params: params,
})
return b
}
func (b ChainBuilder) Build() ChainStatement {
return ChainStatement(b)
}

View file

@ -0,0 +1,305 @@
package codegen_test
import (
"reflect"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
)
func TestIdentifier(t *testing.T) {
identifier := codegen.Identifier("user")
expected := []string{"user"}
actual := identifier.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestDeclStatement(t *testing.T) {
stmt := codegen.DeclStatement{
Name: "arrs",
Type: code.ArrayType{ContainedType: code.SimpleType("int")},
}
expected := []string{"var arrs []int"}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestDeclAssignStatement(t *testing.T) {
stmt := codegen.DeclAssignStatement{
Vars: []string{"value", "err"},
Values: codegen.StatementList{
codegen.Identifier("1"),
codegen.Identifier("nil"),
},
}
expected := []string{"value, err := 1, nil"}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestAssignStatement(t *testing.T) {
stmt := codegen.AssignStatement{
Vars: []string{"value", "err"},
Values: codegen.StatementList{
codegen.Identifier("1"),
codegen.Identifier("nil"),
},
}
expected := []string{"value, err = 1, nil"}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestReturnStatement(t *testing.T) {
stmt := codegen.ReturnStatement{
codegen.Identifier("result"),
codegen.Identifier("nil"),
}
expected := []string{"return result, nil"}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestChainStatement(t *testing.T) {
stmt := codegen.ChainStatement{
codegen.Identifier("r"),
codegen.Identifier("userRepository"),
codegen.CallStatement{
FuncName: "Insert",
Params: codegen.StatementList{
codegen.StructStatement{
Type: "User",
Pairs: []codegen.StructFieldPair{
{
Key: "ID",
Value: codegen.Identifier("arg0"),
},
{
Key: "Name",
Value: codegen.Identifier("arg1"),
},
},
},
},
},
codegen.CallStatement{
FuncName: "Do",
},
}
expected := []string{
"r.userRepository.Insert(User{",
" ID: arg0,",
" Name: arg1,",
"}).Do()",
}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestCallStatement(t *testing.T) {
stmt := codegen.CallStatement{
FuncName: "FindByID",
Params: codegen.StatementList{
codegen.Identifier("ctx"),
codegen.Identifier("user"),
},
}
expected := []string{"FindByID(ctx, user)"}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestSliceStatement(t *testing.T) {
stmt := codegen.SliceStatement{
Type: code.ArrayType{
ContainedType: code.SimpleType("string"),
},
Values: []codegen.Statement{
codegen.Identifier(`"hello"`),
codegen.ChainStatement{
codegen.CallStatement{
FuncName: "GetUser",
Params: codegen.StatementList{
codegen.Identifier("userID"),
},
},
codegen.Identifier("Name"),
},
},
}
expected := []string{
"[]string{",
` "hello",`,
` GetUser(userID).Name,`,
"}",
}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestMapStatement(t *testing.T) {
stmt := codegen.MapStatement{
Type: "map[string]int",
Pairs: []codegen.MapPair{
{
Key: "key1",
Value: codegen.Identifier("value1"),
},
{
Key: "key2",
Value: codegen.Identifier("value2"),
},
},
}
expected := []string{
"map[string]int{",
` "key1": value1,`,
` "key2": value2,`,
"}",
}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestStructStatement(t *testing.T) {
stmt := codegen.StructStatement{
Type: "User",
Pairs: []codegen.StructFieldPair{
{
Key: "ID",
Value: codegen.Identifier("arg0"),
},
{
Key: "Name",
Value: codegen.Identifier("arg1"),
},
},
}
expected := []string{
"User{",
` ID: arg0,`,
` Name: arg1,`,
"}",
}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestIfBlockStatement(t *testing.T) {
stmt := codegen.IfBlock{
Condition: []codegen.Statement{
codegen.DeclAssignStatement{
Vars: []string{"err"},
Values: codegen.StatementList{
codegen.CallStatement{
FuncName: "Insert",
Params: codegen.StatementList{
codegen.Identifier("ctx"),
codegen.StructStatement{
Type: "User",
Pairs: []codegen.StructFieldPair{
{
Key: "ID",
Value: codegen.Identifier("id"),
},
{
Key: "Name",
Value: codegen.Identifier("name"),
},
},
},
},
},
},
},
codegen.RawStatement("err != nil"),
},
Statements: []codegen.Statement{
codegen.ReturnStatement{
codegen.Identifier("nil"),
codegen.Identifier("err"),
},
},
}
expected := []string{
"if err := Insert(ctx, User{",
" ID: id,",
" Name: name,",
"}); err != nil {",
" return nil, err",
"}",
}
actual := stmt.CodeLines()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}
func TestChainBuilder(t *testing.T) {
expected := codegen.ChainStatement{
codegen.Identifier("r"),
codegen.Identifier("repository"),
codegen.CallStatement{
FuncName: "Find",
Params: []codegen.Statement{
codegen.Identifier("ctx"),
},
},
codegen.CallStatement{
FuncName: "Decode",
},
}
actual := codegen.NewChainBuilder("r").
Chain("repository").
Call("Find", codegen.Identifier("ctx")).
Call("Decode").
Build()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected=%+v actual=%+v", expected, actual)
}
}

View file

@ -79,17 +79,43 @@ func TestBuilderBuild(t *testing.T) {
{Name: "username", Type: code.TypeString}, {Name: "username", Type: code.TypeString},
}, },
Returns: []code.Type{code.SimpleType("User")}, Returns: []code.Type{code.SimpleType("User")},
Body: ` return User{ Body: codegen.FunctionBody{
ID: primitive.NewObjectID(), codegen.ReturnStatement{
Username: username, codegen.StructStatement{
}`, Type: "User",
Pairs: []codegen.StructFieldPair{
{
Key: "ID",
Value: codegen.ChainStatement{
codegen.Identifier("primitive"),
codegen.CallStatement{
FuncName: "NewObjectID",
},
},
},
{
Key: "Username",
Value: codegen.Identifier("username"),
},
},
},
},
},
}) })
builder.AddImplementer(codegen.MethodBuilder{ builder.AddImplementer(codegen.MethodBuilder{
Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")}, Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")},
Name: "IDHex", Name: "IDHex",
Params: nil, Params: nil,
Returns: []code.Type{code.TypeString}, Returns: []code.Type{code.TypeString},
Body: " return u.ID.Hex()", Body: codegen.FunctionBody{
codegen.ReturnStatement{
codegen.ChainStatement{
codegen.Identifier("u"),
codegen.Identifier("ID"),
codegen.CallStatement{FuncName: "Hex"},
},
},
},
}) })
generatedCode, err := builder.Build() generatedCode, err := builder.Build()

View file

@ -11,7 +11,7 @@ import (
const functionTemplate = ` const functionTemplate = `
func {{.Name}}({{.GenParams}}){{.GenReturns}} { func {{.Name}}({{.GenParams}}){{.GenReturns}} {
{{.Body}} {{.Body.Code}}
} }
` `
@ -20,7 +20,7 @@ type FunctionBuilder struct {
Name string Name string
Params []code.Param Params []code.Param
Returns []code.Type Returns []code.Type
Body string Body FunctionBody
} }
// Impl writes function declatation code to the buffer. // Impl writes function declatation code to the buffer.

View file

@ -14,7 +14,20 @@ func TestFunctionBuilderBuild_NoReturn(t *testing.T) {
Name: "init", Name: "init",
Params: nil, Params: nil,
Returns: nil, Returns: nil,
Body: ` logrus.SetLevel(logrus.DebugLevel)`, Body: codegen.FunctionBody{
codegen.ChainStatement{
codegen.Identifier("logrus"),
codegen.CallStatement{
FuncName: "SetLevel",
Params: codegen.StatementList{
codegen.ChainStatement{
codegen.Identifier("logrus"),
codegen.Identifier("DebugLevel"),
},
},
},
},
},
} }
expectedCode := ` expectedCode := `
func init() { func init() {
@ -57,18 +70,25 @@ func TestFunctionBuilderBuild_OneReturn(t *testing.T) {
Returns: []code.Type{ Returns: []code.Type{
code.SimpleType("User"), code.SimpleType("User"),
}, },
Body: ` return User{ Body: codegen.FunctionBody{
Username: username, codegen.ReturnStatement{
Age: age, codegen.StructStatement{
Parent: parent Type: "User",
}`, Pairs: []codegen.StructFieldPair{
{Key: "Username", Value: codegen.Identifier("username")},
{Key: "Age", Value: codegen.Identifier("age")},
{Key: "Parent", Value: codegen.Identifier("parent")},
},
},
},
},
} }
expectedCode := ` expectedCode := `
func NewUser(username string, age int, parent *User) User { func NewUser(username string, age int, parent *User) User {
return User{ return User{
Username: username, Username: username,
Age: age, Age: age,
Parent: parent Parent: parent,
} }
} }
` `
@ -101,7 +121,19 @@ func TestFunctionBuilderBuild_MultiReturn(t *testing.T) {
code.SimpleType("User"), code.SimpleType("User"),
code.TypeError, code.TypeError,
}, },
Body: ` return collection.Save(user)`, Body: codegen.FunctionBody{
codegen.ReturnStatement{
codegen.ChainStatement{
codegen.Identifier("collection"),
codegen.CallStatement{
FuncName: "Save",
Params: codegen.StatementList{
codegen.Identifier("user"),
},
},
},
},
},
} }
expectedCode := ` expectedCode := `
func Save(user User) (User, error) { func Save(user User) (User, error) {

View file

@ -10,7 +10,7 @@ import (
const methodTemplate = ` const methodTemplate = `
func ({{.GenReceiver}}) {{.Name}}({{.GenParams}}){{.GenReturns}} { func ({{.GenReceiver}}) {{.Name}}({{.GenParams}}){{.GenReturns}} {
{{.Body}} {{.Body.Code}}
} }
` `
@ -20,7 +20,7 @@ type MethodBuilder struct {
Name string Name string
Params []code.Param Params []code.Param
Returns []code.Type Returns []code.Type
Body string Body FunctionBody
} }
// MethodReceiver describes a specification of a method receiver. // MethodReceiver describes a specification of a method receiver.

View file

@ -15,7 +15,17 @@ func TestMethodBuilderBuild_IgnoreReceiverNoReturn(t *testing.T) {
Name: "Init", Name: "Init",
Params: nil, Params: nil,
Returns: nil, Returns: nil,
Body: ` db.Init(&User{})`, Body: codegen.FunctionBody{
codegen.ChainStatement{
codegen.Identifier("db"),
codegen.CallStatement{
FuncName: "Init",
Params: codegen.StatementList{
codegen.RawStatement("&User{}"),
},
},
},
},
} }
expectedCode := ` expectedCode := `
func (User) Init() { func (User) Init() {
@ -47,7 +57,19 @@ func TestMethodBuilderBuild_IgnorePoinerReceiverOneReturn(t *testing.T) {
Name: "Init", Name: "Init",
Params: nil, Params: nil,
Returns: []code.Type{code.TypeError}, Returns: []code.Type{code.TypeError},
Body: ` return db.Init(&User{})`, Body: codegen.FunctionBody{
codegen.ReturnStatement{
codegen.ChainStatement{
codegen.Identifier("db"),
codegen.CallStatement{
FuncName: "Init",
Params: codegen.StatementList{
codegen.RawStatement("&User{}"),
},
},
},
},
},
} }
expectedCode := ` expectedCode := `
func (*User) Init() error { func (*User) Init() error {
@ -81,8 +103,17 @@ func TestMethodBuilderBuild_UseReceiverMultiReturn(t *testing.T) {
{Name: "age", Type: code.TypeInt}, {Name: "age", Type: code.TypeInt},
}, },
Returns: []code.Type{code.SimpleType("User"), code.TypeError}, Returns: []code.Type{code.SimpleType("User"), code.TypeError},
Body: ` u.Age = age Body: codegen.FunctionBody{
return u`, codegen.AssignStatement{
Vars: []string{"u.Age"},
Values: codegen.StatementList{
codegen.Identifier("age"),
},
},
codegen.ReturnStatement{
codegen.Identifier("u"),
},
},
} }
expectedCode := ` expectedCode := `
func (u User) WithAge(age int) (User, error) { func (u User) WithAge(age int) (User, error) {
@ -118,7 +149,14 @@ func TestMethodBuilderBuild_UsePointerReceiverNoReturn(t *testing.T) {
{Name: "age", Type: code.TypeInt}, {Name: "age", Type: code.TypeInt},
}, },
Returns: nil, Returns: nil,
Body: ` u.Age = age`, Body: codegen.FunctionBody{
codegen.AssignStatement{
Vars: []string{"u.Age"},
Values: codegen.StatementList{
codegen.Identifier("age"),
},
},
},
} }
expectedCode := ` expectedCode := `
func (u *User) SetAge(age int) { func (u *User) SetAge(age int) {

94
internal/mongo/common.go Normal file
View file

@ -0,0 +1,94 @@
package mongo
import (
"strings"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec"
)
var returnNilErr = codegen.ReturnStatement{
codegen.Identifier("nil"),
codegen.Identifier("err"),
}
var ifErrReturnNilErr = codegen.IfBlock{
Condition: []codegen.Statement{
codegen.RawStatement("err != nil"),
},
Statements: []codegen.Statement{
returnNilErr,
},
}
var ifErrReturn0Err = codegen.IfBlock{
Condition: []codegen.Statement{
codegen.RawStatement("err != nil"),
},
Statements: []codegen.Statement{
codegen.ReturnStatement{
codegen.Identifier("0"),
codegen.Identifier("err"),
},
},
}
var ifErrReturnFalseErr = codegen.IfBlock{
Condition: []codegen.Statement{
codegen.RawStatement("err != nil"),
},
Statements: []codegen.Statement{
codegen.ReturnStatement{
codegen.Identifier("false"),
codegen.Identifier("err"),
},
},
}
type baseMethodGenerator struct {
structModel code.Struct
}
func (g baseMethodGenerator) bsonFieldReference(fieldReference spec.FieldReference) (string, error) {
var bsonTags []string
for _, field := range fieldReference {
tag, err := g.bsonTagFromField(field)
if err != nil {
return "", err
}
bsonTags = append(bsonTags, tag)
}
return strings.Join(bsonTags, "."), nil
}
func (g baseMethodGenerator) bsonTagFromField(field code.StructField) (string, error) {
bsonTag, ok := field.Tags["bson"]
if !ok {
return "", NewBsonTagNotFoundError(field.Name)
}
return bsonTag[0], nil
}
func (g baseMethodGenerator) convertQuerySpec(query spec.QuerySpec) (querySpec, error) {
var predicates []predicate
for _, predicateSpec := range query.Predicates {
bsonFieldReference, err := g.bsonFieldReference(predicateSpec.FieldReference)
if err != nil {
return querySpec{}, err
}
predicates = append(predicates, predicate{
Field: bsonFieldReference,
Comparator: predicateSpec.Comparator,
ParamIndex: predicateSpec.ParamIndex,
})
}
return querySpec{
Operator: query.Operator,
Predicates: predicates,
}, nil
}

39
internal/mongo/count.go Normal file
View file

@ -0,0 +1,39 @@
package mongo
import (
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec"
)
func (g RepositoryGenerator) generateCountBody(
operation spec.CountOperation) (codegen.FunctionBody, error) {
querySpec, err := g.convertQuerySpec(operation.Query)
if err != nil {
return nil, err
}
return codegen.FunctionBody{
codegen.DeclAssignStatement{
Vars: []string{"count", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("CountDocuments",
codegen.Identifier("arg0"),
querySpec.Code(),
).Build(),
},
},
ifErrReturn0Err,
codegen.ReturnStatement{
codegen.CallStatement{
FuncName: "int",
Params: codegen.StatementList{
codegen.Identifier("count"),
},
},
codegen.Identifier("nil"),
},
}, nil
}

View file

@ -0,0 +1,437 @@
package mongo_test
import (
"fmt"
"reflect"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestGenerateMethod_Count(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "simple count method",
MethodSpec: spec.MethodSpec{
Name: "CountByGender",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{genderField},
Comparator: spec.ComparatorEqual,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": arg1,
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with And operator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderAndCity",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Operator: spec.OperatorAnd,
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{genderField},
Comparator: spec.ComparatorEqual,
ParamIndex: 1,
},
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"$and": []bson.M{
{
"gender": arg1,
},
{
"age": arg2,
},
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with Or operator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderOrCity",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
{Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Operator: spec.OperatorOr,
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{genderField},
Comparator: spec.ComparatorEqual,
ParamIndex: 1,
},
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"$or": []bson.M{
{
"gender": arg1,
},
{
"age": arg2,
},
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with Not comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByGenderNot",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{genderField},
Comparator: spec.ComparatorNot,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"gender": bson.M{
"$ne": arg1,
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with LessThan comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeLessThan",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorLessThan,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{
"$lt": arg1,
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with LessThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeLessThanEqual",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorLessThanEqual,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{
"$lte": arg1,
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with GreaterThan comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeGreaterThan",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorGreaterThan,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{
"$gt": arg1,
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with GreaterThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeGreaterThanEqual",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorGreaterThanEqual,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{
"$gte": arg1,
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with Between comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeBetween",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.TypeInt},
{Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorBetween,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{
"$gte": arg1,
"$lte": arg2,
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
{
Name: "count with In comparator",
MethodSpec: spec.MethodSpec{
Name: "CountByAgeIn",
Params: []code.Param{
{Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Type: code.ArrayType{ContainedType: code.TypeInt}},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.CountOperation{
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{ageField},
Comparator: spec.ComparatorIn,
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{
"age": bson.M{
"$in": arg1,
},
})
if err != nil {
return 0, err
}
return int(count), nil`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
expectedReceiver := codegen.MethodReceiver{
Name: "r",
Type: "UserRepositoryMongo",
Pointer: true,
}
var expectedParams []code.Param
for i, param := range testCase.MethodSpec.Params {
expectedParams = append(expectedParams, code.Param{
Name: fmt.Sprintf("arg%d", i),
Type: param.Type,
})
}
actual, err := generator.GenerateMethod(testCase.MethodSpec)
if err != nil {
t.Fatal(err)
}
if expectedReceiver != actual.Receiver {
t.Errorf(
"incorrect method receiver: expected %+v, got %+v",
expectedReceiver,
actual.Receiver,
)
}
if testCase.MethodSpec.Name != actual.Name {
t.Errorf(
"incorrect method name: expected %s, got %s",
testCase.MethodSpec.Name,
actual.Name,
)
}
if !reflect.DeepEqual(expectedParams, actual.Params) {
t.Errorf(
"incorrect struct params: expected %+v, got %+v",
expectedParams,
actual.Params,
)
}
if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) {
t.Errorf(
"incorrect struct returns: expected %+v, got %+v",
testCase.MethodSpec.Returns,
actual.Returns,
)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil {
t.Error(err)
}
})
}
}

84
internal/mongo/delete.go Normal file
View file

@ -0,0 +1,84 @@
package mongo
import (
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec"
)
func (g RepositoryGenerator) generateDeleteBody(
operation spec.DeleteOperation) (codegen.FunctionBody, error) {
return deleteBodyGenerator{
baseMethodGenerator: g.baseMethodGenerator,
operation: operation,
}.generate()
}
type deleteBodyGenerator struct {
baseMethodGenerator
operation spec.DeleteOperation
}
func (g deleteBodyGenerator) generate() (codegen.FunctionBody, error) {
querySpec, err := g.convertQuerySpec(g.operation.Query)
if err != nil {
return nil, err
}
if g.operation.Mode == spec.QueryModeOne {
return g.generateDeleteOneBody(querySpec), nil
}
return g.generateDeleteManyBody(querySpec), nil
}
func (g deleteBodyGenerator) generateDeleteOneBody(
querySpec querySpec) codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclAssignStatement{
Vars: []string{"result", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("DeleteOne",
codegen.Identifier("arg0"),
querySpec.Code(),
).Build(),
},
},
ifErrReturnFalseErr,
codegen.ReturnStatement{
codegen.RawStatement("result.DeletedCount > 0"),
codegen.Identifier("nil"),
},
}
}
func (g deleteBodyGenerator) generateDeleteManyBody(
querySpec querySpec) codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclAssignStatement{
Vars: []string{"result", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("DeleteMany",
codegen.Identifier("arg0"),
querySpec.Code(),
).Build(),
},
},
ifErrReturn0Err,
codegen.ReturnStatement{
codegen.CallStatement{
FuncName: "int",
Params: codegen.StatementList{
codegen.NewChainBuilder("result").Chain("DeletedCount").Build(),
},
},
codegen.Identifier("nil"),
},
}
}

View file

@ -0,0 +1,477 @@
package mongo_test
import (
"fmt"
"reflect"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestGenerateMethod_Delete(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "simple delete one method",
MethodSpec: spec.MethodSpec{
Name: "DeleteByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{code.TypeBool, code.TypeError},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{idField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteOne(arg0, bson.M{
"_id": arg1,
})
if err != nil {
return false, err
}
return result.DeletedCount > 0, nil`,
},
{
Name: "simple delete many method",
MethodSpec: spec.MethodSpec{
Name: "DeleteByGender",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"gender": arg1,
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with And operator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByGenderAndAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorAnd,
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"$and": []bson.M{
{
"gender": arg1,
},
{
"age": arg2,
},
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with Or operator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByGenderOrAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorOr,
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"$or": []bson.M{
{
"gender": arg1,
},
{
"age": arg2,
},
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with Not comparator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByGenderNot",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorNot,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"gender": bson.M{
"$ne": arg1,
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with LessThan comparator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByAgeLessThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorLessThan,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"age": bson.M{
"$lt": arg1,
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with LessThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByAgeLessThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorLessThanEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"age": bson.M{
"$lte": arg1,
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with GreaterThan comparator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByAgeGreaterThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorGreaterThan,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"age": bson.M{
"$gt": arg1,
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with GreaterThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByAgeGreaterThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorGreaterThanEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"age": bson.M{
"$gte": arg1,
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with Between comparator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByAgeBetween",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.TypeInt},
{Name: "toAge", Type: code.TypeInt},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorBetween,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"age": bson.M{
"$gte": arg1,
"$lte": arg2,
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
{
Name: "delete with In comparator",
MethodSpec: spec.MethodSpec{
Name: "DeleteByGenderIn",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.DeleteOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorIn,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{
"gender": bson.M{
"$in": arg1,
},
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
expectedReceiver := codegen.MethodReceiver{
Name: "r",
Type: "UserRepositoryMongo",
Pointer: true,
}
var expectedParams []code.Param
for i, param := range testCase.MethodSpec.Params {
expectedParams = append(expectedParams, code.Param{
Name: fmt.Sprintf("arg%d", i),
Type: param.Type,
})
}
actual, err := generator.GenerateMethod(testCase.MethodSpec)
if err != nil {
t.Fatal(err)
}
if expectedReceiver != actual.Receiver {
t.Errorf(
"incorrect method receiver: expected %+v, got %+v",
expectedReceiver,
actual.Receiver,
)
}
if testCase.MethodSpec.Name != actual.Name {
t.Errorf(
"incorrect method name: expected %s, got %s",
testCase.MethodSpec.Name,
actual.Name,
)
}
if !reflect.DeepEqual(expectedParams, actual.Params) {
t.Errorf(
"incorrect struct params: expected %+v, got %+v",
expectedParams,
actual.Params,
)
}
if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) {
t.Errorf(
"incorrect struct returns: expected %+v, got %+v",
testCase.MethodSpec.Returns,
actual.Returns,
)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil {
t.Error(err)
}
})
}
}

160
internal/mongo/find.go Normal file
View file

@ -0,0 +1,160 @@
package mongo
import (
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec"
)
func (g RepositoryGenerator) generateFindBody(
operation spec.FindOperation) (codegen.FunctionBody, error) {
return findBodyGenerator{
baseMethodGenerator: g.baseMethodGenerator,
operation: operation,
}.generate()
}
type findBodyGenerator struct {
baseMethodGenerator
operation spec.FindOperation
}
func (g findBodyGenerator) generate() (codegen.FunctionBody, error) {
querySpec, err := g.convertQuerySpec(g.operation.Query)
if err != nil {
return nil, err
}
sortsCode, err := g.generateSortMap()
if err != nil {
return nil, err
}
if g.operation.Mode == spec.QueryModeOne {
return g.generateFindOneBody(querySpec, sortsCode), nil
}
return g.generateFindManyBody(querySpec, sortsCode), nil
}
func (g findBodyGenerator) generateFindOneBody(querySpec querySpec,
sortsCode codegen.MapStatement) codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclStatement{
Name: "entity",
Type: code.SimpleType(g.structModel.Name),
},
codegen.IfBlock{
Condition: []codegen.Statement{
codegen.DeclAssignStatement{
Vars: []string{"err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("FindOne",
codegen.Identifier("arg0"),
querySpec.Code(),
codegen.NewChainBuilder("options").
Call("FindOne").
Call("SetSort", sortsCode).
Build(),
).
Call("Decode",
codegen.RawStatement("&entity"),
).Build(),
},
},
codegen.RawStatement("err != nil"),
},
Statements: []codegen.Statement{
returnNilErr,
},
},
codegen.ReturnStatement{
codegen.RawStatement("&entity"),
codegen.Identifier("nil"),
},
}
}
func (g findBodyGenerator) generateFindManyBody(querySpec querySpec,
sortsCode codegen.MapStatement) codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclAssignStatement{
Vars: []string{"cursor", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("Find",
codegen.Identifier("arg0"),
querySpec.Code(),
codegen.NewChainBuilder("options").
Call("Find").
Call("SetSort", sortsCode).
Build(),
).Build(),
},
},
ifErrReturnNilErr,
codegen.DeclStatement{
Name: "entities",
Type: code.ArrayType{
ContainedType: code.PointerType{
ContainedType: code.SimpleType(g.structModel.Name),
},
},
},
codegen.IfBlock{
Condition: []codegen.Statement{
codegen.DeclAssignStatement{
Vars: []string{"err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("cursor").
Call("All",
codegen.Identifier("arg0"),
codegen.RawStatement("&entities"),
).Build(),
},
},
codegen.RawStatement("err != nil"),
},
Statements: []codegen.Statement{
returnNilErr,
},
},
codegen.ReturnStatement{
codegen.Identifier("entities"),
codegen.Identifier("nil"),
},
}
}
func (g findBodyGenerator) generateSortMap() (
codegen.MapStatement, error) {
sortsCode := codegen.MapStatement{
Type: "bson.M",
}
for _, s := range g.operation.Sorts {
bsonFieldReference, err := g.bsonFieldReference(s.FieldReference)
if err != nil {
return codegen.MapStatement{}, err
}
sortValueIdentifier := codegen.Identifier("1")
if s.Ordering == spec.OrderingDescending {
sortValueIdentifier = codegen.Identifier("-1")
}
sortsCode.Pairs = append(sortsCode.Pairs, codegen.MapPair{
Key: bsonFieldReference,
Value: sortValueIdentifier,
})
}
return sortsCode, nil
}

886
internal/mongo/find_test.go Normal file
View file

@ -0,0 +1,886 @@
package mongo_test
import (
"fmt"
"reflect"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestGenerateMethod_Find(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "simple find one method",
MethodSpec: spec.MethodSpec{
Name: "FindByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{idField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` var entity UserModel
if err := r.collection.FindOne(arg0, bson.M{
"_id": arg1,
}, options.FindOne().SetSort(bson.M{
})).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil`,
},
{
Name: "simple find many method",
MethodSpec: spec.MethodSpec{
Name: "FindByGender",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"gender": arg1,
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with deep field reference",
MethodSpec: spec.MethodSpec{
Name: "FindByNameFirst",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "firstName", Type: code.TypeString},
},
Returns: []code.Type{
code.PointerType{ContainedType: code.SimpleType("UserModel")},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{nameField, firstNameField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` var entity UserModel
if err := r.collection.FindOne(arg0, bson.M{
"name.first": arg1,
}, options.FindOne().SetSort(bson.M{
})).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil`,
},
{
Name: "find with And operator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderAndAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorAnd,
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"$and": []bson.M{
{
"gender": arg1,
},
{
"age": arg2,
},
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with Or operator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderOrAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Operator: spec.OperatorOr,
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
{
Comparator: spec.ComparatorEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"$or": []bson.M{
{
"gender": arg1,
},
{
"age": arg2,
},
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with Not comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderNot",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorNot,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"gender": bson.M{
"$ne": arg1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with LessThan comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeLessThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorLessThan,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{
"$lt": arg1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with LessThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeLessThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorLessThanEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{
"$lte": arg1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with GreaterThan comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeGreaterThan",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorGreaterThan,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{
"$gt": arg1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with GreaterThanEqual comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeGreaterThanEqual",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorGreaterThanEqual,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{
"$gte": arg1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with Between comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByAgeBetween",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "fromAge", Type: code.TypeInt},
{Name: "toAge", Type: code.TypeInt},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorBetween,
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{
"$gte": arg1,
"$lte": arg2,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with In comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderIn",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorIn,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"gender": bson.M{
"$in": arg1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with NotIn comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByGenderNotIn",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorNotIn,
FieldReference: spec.FieldReference{genderField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"gender": bson.M{
"$nin": arg1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with True comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByEnabledTrue",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorTrue,
FieldReference: spec.FieldReference{enabledField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"enabled": true,
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with False comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByEnabledFalse",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorFalse,
FieldReference: spec.FieldReference{enabledField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"enabled": false,
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with Exists comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByReferrerExists",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorExists,
FieldReference: spec.FieldReference{referrerField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"referrer": bson.M{
"$exists": 1,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with NotExists comparator",
MethodSpec: spec.MethodSpec{
Name: "FindByReferrerNotExists",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
Comparator: spec.ComparatorNotExists,
FieldReference: spec.FieldReference{referrerField},
ParamIndex: 1,
},
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
"referrer": bson.M{
"$exists": 0,
},
}, options.Find().SetSort(bson.M{
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with sort ascending",
MethodSpec: spec.MethodSpec{
Name: "FindAllOrderByAge",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Sorts: []spec.Sort{
{FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
}, options.Find().SetSort(bson.M{
"age": 1,
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with sort descending",
MethodSpec: spec.MethodSpec{
Name: "FindAllOrderByAgeDesc",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Sorts: []spec.Sort{
{FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
}, options.Find().SetSort(bson.M{
"age": -1,
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with deep sort ascending",
MethodSpec: spec.MethodSpec{
Name: "FindAllOrderByNameFirst",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Sorts: []spec.Sort{
{
FieldReference: spec.FieldReference{nameField, firstNameField},
Ordering: spec.OrderingAscending,
},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
}, options.Find().SetSort(bson.M{
"name.first": 1,
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
{
Name: "find with multiple sorts",
MethodSpec: spec.MethodSpec{
Name: "FindAllOrderByGenderAndAgeDesc",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
code.TypeError,
},
Operation: spec.FindOperation{
Mode: spec.QueryModeMany,
Sorts: []spec.Sort{
{FieldReference: spec.FieldReference{genderField}, Ordering: spec.OrderingAscending},
{FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending},
},
},
},
ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{
}, options.Find().SetSort(bson.M{
"gender": 1,
"age": -1,
}))
if err != nil {
return nil, err
}
var entities []*UserModel
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
expectedReceiver := codegen.MethodReceiver{
Name: "r",
Type: "UserRepositoryMongo",
Pointer: true,
}
var expectedParams []code.Param
for i, param := range testCase.MethodSpec.Params {
expectedParams = append(expectedParams, code.Param{
Name: fmt.Sprintf("arg%d", i),
Type: param.Type,
})
}
actual, err := generator.GenerateMethod(testCase.MethodSpec)
if err != nil {
t.Fatal(err)
}
if expectedReceiver != actual.Receiver {
t.Errorf(
"incorrect method receiver: expected %+v, got %+v",
expectedReceiver,
actual.Receiver,
)
}
if testCase.MethodSpec.Name != actual.Name {
t.Errorf(
"incorrect method name: expected %s, got %s",
testCase.MethodSpec.Name,
actual.Name,
)
}
if !reflect.DeepEqual(expectedParams, actual.Params) {
t.Errorf(
"incorrect struct params: expected %+v, got %+v",
expectedParams,
actual.Params,
)
}
if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) {
t.Errorf(
"incorrect struct returns: expected %+v, got %+v",
testCase.MethodSpec.Returns,
actual.Returns,
)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil {
t.Error(err)
}
})
}
}

View file

@ -1,10 +1,7 @@
package mongo package mongo
import ( import (
"bytes"
"fmt" "fmt"
"strings"
"text/template"
"github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/codegen"
@ -14,7 +11,9 @@ import (
// NewGenerator creates a new instance of MongoDB repository generator // NewGenerator creates a new instance of MongoDB repository generator
func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGenerator { func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGenerator {
return RepositoryGenerator{ return RepositoryGenerator{
StructModel: structModel, baseMethodGenerator: baseMethodGenerator{
structModel: structModel,
},
InterfaceName: interfaceName, InterfaceName: interfaceName,
} }
} }
@ -22,7 +21,7 @@ func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGener
// RepositoryGenerator is a MongoDB repository generator that provides // RepositoryGenerator is a MongoDB repository generator that provides
// necessary information required to construct an implementation. // necessary information required to construct an implementation.
type RepositoryGenerator struct { type RepositoryGenerator struct {
StructModel code.Struct baseMethodGenerator
InterfaceName string InterfaceName string
} }
@ -63,20 +62,6 @@ func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder {
// GenerateConstructor creates codegen.FunctionBuilder of a constructor for // GenerateConstructor creates codegen.FunctionBuilder of a constructor for
// mongo repository implementation struct. // mongo repository implementation struct.
func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) { func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) {
tmpl, err := template.New("mongo_constructor_body").Parse(constructorBody)
if err != nil {
return codegen.FunctionBuilder{}, err
}
tmplData := constructorBodyData{
ImplStructName: g.repoImplStructName(),
}
buffer := new(bytes.Buffer)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return codegen.FunctionBuilder{}, err
}
return codegen.FunctionBuilder{ return codegen.FunctionBuilder{
Name: "New" + g.InterfaceName, Name: "New" + g.InterfaceName,
Params: []code.Param{ Params: []code.Param{
@ -93,7 +78,17 @@ func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, err
Returns: []code.Type{ Returns: []code.Type{
code.SimpleType(g.InterfaceName), code.SimpleType(g.InterfaceName),
}, },
Body: buffer.String(), Body: codegen.FunctionBody{
codegen.ReturnStatement{
codegen.StructStatement{
Type: fmt.Sprintf("&%s", g.repoImplStructName()),
Pairs: []codegen.StructFieldPair{{
Key: "collection",
Value: codegen.Identifier("collection"),
}},
},
},
},
}, nil }, nil
} }
@ -126,220 +121,25 @@ func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen
}, nil }, nil
} }
func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) { func (g RepositoryGenerator) generateMethodImplementation(
methodSpec spec.MethodSpec) (codegen.FunctionBody, error) {
switch operation := methodSpec.Operation.(type) { switch operation := methodSpec.Operation.(type) {
case spec.InsertOperation: case spec.InsertOperation:
return g.generateInsertImplementation(operation) return g.generateInsertBody(operation), nil
case spec.FindOperation: case spec.FindOperation:
return g.generateFindImplementation(operation) return g.generateFindBody(operation)
case spec.UpdateOperation: case spec.UpdateOperation:
return g.generateUpdateImplementation(operation) return g.generateUpdateBody(operation)
case spec.DeleteOperation: case spec.DeleteOperation:
return g.generateDeleteImplementation(operation) return g.generateDeleteBody(operation)
case spec.CountOperation: case spec.CountOperation:
return g.generateCountImplementation(operation) return g.generateCountBody(operation)
default: default:
return "", NewOperationNotSupportedError(operation.Name()) return nil, NewOperationNotSupportedError(operation.Name())
} }
} }
func (g RepositoryGenerator) generateInsertImplementation(operation spec.InsertOperation) (string, error) {
if operation.Mode == spec.QueryModeOne {
return insertOneTemplate, nil
}
return insertManyTemplate, nil
}
func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) {
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil {
return "", err
}
sorts, err := g.mongoSorts(operation.Sorts)
if err != nil {
return "", err
}
tmplData := mongoFindTemplateData{
EntityType: g.StructModel.Name,
QuerySpec: querySpec,
Sorts: sorts,
}
if operation.Mode == spec.QueryModeOne {
return generateFromTemplate("mongo_repository_findone", findOneTemplate, tmplData)
}
return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData)
}
func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]findSort, error) {
var sorts []findSort
for _, s := range sortSpec {
bsonFieldReference, err := g.bsonFieldReference(s.FieldReference)
if err != nil {
return nil, err
}
sorts = append(sorts, findSort{
BsonTag: bsonFieldReference,
Ordering: s.Ordering,
})
}
return sorts, nil
}
func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) {
update, err := g.getMongoUpdate(operation.Update)
if err != nil {
return "", err
}
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil {
return "", err
}
tmplData := mongoUpdateTemplateData{
Update: update,
QuerySpec: querySpec,
}
if operation.Mode == spec.QueryModeOne {
return generateFromTemplate("mongo_repository_updateone", updateOneTemplate, tmplData)
}
return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData)
}
func (g RepositoryGenerator) getMongoUpdate(updateSpec spec.Update) (update, error) {
switch updateSpec := updateSpec.(type) {
case spec.UpdateModel:
return updateModel{}, nil
case spec.UpdateFields:
update := make(updateFields)
for _, field := range updateSpec {
bsonFieldReference, err := g.bsonFieldReference(field.FieldReference)
if err != nil {
return querySpec{}, err
}
updateKey := getUpdateOperatorKey(field.Operator)
if updateKey == "" {
return querySpec{}, NewUpdateOperatorNotSupportedError(field.Operator)
}
updateField := updateField{
BsonTag: bsonFieldReference,
ParamIndex: field.ParamIndex,
}
update[updateKey] = append(update[updateKey], updateField)
}
return update, nil
default:
return nil, NewUpdateTypeNotSupportedError(updateSpec)
}
}
func getUpdateOperatorKey(operator spec.UpdateOperator) string {
switch operator {
case spec.UpdateOperatorSet:
return "$set"
case spec.UpdateOperatorPush:
return "$push"
case spec.UpdateOperatorInc:
return "$inc"
default:
return ""
}
}
func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) {
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil {
return "", err
}
tmplData := mongoDeleteTemplateData{
QuerySpec: querySpec,
}
if operation.Mode == spec.QueryModeOne {
return generateFromTemplate("mongo_repository_deleteone", deleteOneTemplate, tmplData)
}
return generateFromTemplate("mongo_repository_deletemany", deleteManyTemplate, tmplData)
}
func (g RepositoryGenerator) generateCountImplementation(operation spec.CountOperation) (string, error) {
querySpec, err := g.mongoQuerySpec(operation.Query)
if err != nil {
return "", err
}
tmplData := mongoCountTemplateData{
QuerySpec: querySpec,
}
return generateFromTemplate("mongo_repository_count", countTemplate, tmplData)
}
func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, error) {
var predicates []predicate
for _, predicateSpec := range query.Predicates {
bsonFieldReference, err := g.bsonFieldReference(predicateSpec.FieldReference)
if err != nil {
return querySpec{}, err
}
predicates = append(predicates, predicate{
Field: bsonFieldReference,
Comparator: predicateSpec.Comparator,
ParamIndex: predicateSpec.ParamIndex,
})
}
return querySpec{
Operator: query.Operator,
Predicates: predicates,
}, nil
}
func (g RepositoryGenerator) bsonFieldReference(fieldReference spec.FieldReference) (string, error) {
var bsonTags []string
for _, field := range fieldReference {
tag, err := g.bsonTagFromField(field)
if err != nil {
return "", err
}
bsonTags = append(bsonTags, tag)
}
return strings.Join(bsonTags, "."), nil
}
func (g RepositoryGenerator) bsonTagFromField(field code.StructField) (string, error) {
bsonTag, ok := field.Tags["bson"]
if !ok {
return "", NewBsonTagNotFoundError(field.Name)
}
return bsonTag[0], nil
}
func (g RepositoryGenerator) repoImplStructName() string { func (g RepositoryGenerator) repoImplStructName() string {
return g.InterfaceName + "Mongo" return g.InterfaceName + "Mongo"
} }
func generateFromTemplate(name string, templateString string, tmplData interface{}) (string, error) {
tmpl, err := template.New(name).Parse(templateString)
if err != nil {
return "", err
}
buffer := new(bytes.Buffer)
if err := tmpl.Execute(buffer, tmplData); err != nil {
return "", err
}
return buffer.String(), nil
}

File diff suppressed because it is too large Load diff

81
internal/mongo/insert.go Normal file
View file

@ -0,0 +1,81 @@
package mongo
import (
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec"
)
func (g RepositoryGenerator) generateInsertBody(
operation spec.InsertOperation) codegen.FunctionBody {
if operation.Mode == spec.QueryModeOne {
return g.generateInsertOneBody()
}
return g.generateInsertManyBody()
}
func (g RepositoryGenerator) generateInsertOneBody() codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclAssignStatement{
Vars: []string{"result", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("InsertOne",
codegen.Identifier("arg0"),
codegen.Identifier("arg1"),
).Build(),
},
},
ifErrReturnNilErr,
codegen.ReturnStatement{
codegen.NewChainBuilder("result").Chain("InsertedID").Build(),
codegen.Identifier("nil"),
},
}
}
func (g RepositoryGenerator) generateInsertManyBody() codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclStatement{
Name: "entities",
Type: code.ArrayType{
ContainedType: code.InterfaceType{},
},
},
codegen.RawBlock{
Header: []string{"for _, model := range arg1"},
Statements: []codegen.Statement{
codegen.AssignStatement{
Vars: []string{"entities"},
Values: codegen.StatementList{
codegen.CallStatement{
FuncName: "append",
Params: codegen.StatementList{
codegen.Identifier("entities"),
codegen.Identifier("model"),
},
},
},
},
},
},
codegen.DeclAssignStatement{
Vars: []string{"result", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("InsertMany",
codegen.Identifier("arg0"),
codegen.Identifier("entities"),
).Build(),
},
},
ifErrReturnNilErr,
codegen.ReturnStatement{
codegen.NewChainBuilder("result").Chain("InsertedIDs").Build(),
codegen.Identifier("nil"),
},
}
}

View file

@ -0,0 +1,123 @@
package mongo_test
import (
"fmt"
"reflect"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestGenerateMethod_Insert(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "insert one method",
MethodSpec: spec.MethodSpec{
Name: "InsertOne",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
},
Returns: []code.Type{
code.InterfaceType{},
code.TypeError,
},
Operation: spec.InsertOperation{
Mode: spec.QueryModeOne,
},
},
ExpectedBody: ` result, err := r.collection.InsertOne(arg0, arg1)
if err != nil {
return nil, err
}
return result.InsertedID, nil`,
},
{
Name: "insert many method",
MethodSpec: spec.MethodSpec{
Name: "Insert",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "userModel", Type: code.ArrayType{
ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")},
}},
},
Returns: []code.Type{
code.ArrayType{ContainedType: code.InterfaceType{}},
code.TypeError,
},
Operation: spec.InsertOperation{
Mode: spec.QueryModeMany,
},
},
ExpectedBody: ` var entities []interface{}
for _, model := range arg1 {
entities = append(entities, model)
}
result, err := r.collection.InsertMany(arg0, entities)
if err != nil {
return nil, err
}
return result.InsertedIDs, nil`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
expectedReceiver := codegen.MethodReceiver{
Name: "r",
Type: "UserRepositoryMongo",
Pointer: true,
}
var expectedParams []code.Param
for i, param := range testCase.MethodSpec.Params {
expectedParams = append(expectedParams, code.Param{
Name: fmt.Sprintf("arg%d", i),
Type: param.Type,
})
}
actual, err := generator.GenerateMethod(testCase.MethodSpec)
if err != nil {
t.Fatal(err)
}
if expectedReceiver != actual.Receiver {
t.Errorf(
"incorrect method receiver: expected %+v, got %+v",
expectedReceiver,
actual.Receiver,
)
}
if testCase.MethodSpec.Name != actual.Name {
t.Errorf(
"incorrect method name: expected %s, got %s",
testCase.MethodSpec.Name,
actual.Name,
)
}
if !reflect.DeepEqual(expectedParams, actual.Params) {
t.Errorf(
"incorrect struct params: expected %+v, got %+v",
expectedParams,
actual.Params,
)
}
if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) {
t.Errorf(
"incorrect struct returns: expected %+v, got %+v",
testCase.MethodSpec.Returns,
actual.Returns,
)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil {
t.Error(err)
}
})
}
}

View file

@ -3,8 +3,9 @@ package mongo
import ( import (
"fmt" "fmt"
"sort" "sort"
"strings"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec" "github.com/sunboyy/repogen/internal/spec"
) )
@ -14,36 +15,54 @@ type updateField struct {
} }
type update interface { type update interface {
Code() string Code() codegen.Statement
} }
type updateModel struct { type updateModel struct {
} }
func (u updateModel) Code() string { func (u updateModel) Code() codegen.Statement {
return ` "$set": arg1,` return codegen.MapStatement{
Type: "bson.M",
Pairs: []codegen.MapPair{
{
Key: "$set",
Value: codegen.Identifier("arg1"),
},
},
}
} }
type updateFields map[string][]updateField type updateFields map[string][]updateField
func (u updateFields) Code() string { func (u updateFields) Code() codegen.Statement {
var keys []string var keys []string
for k := range u { for k := range u {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
var lines []string stmt := codegen.MapStatement{
Type: "bson.M",
}
for _, key := range keys { for _, key := range keys {
lines = append(lines, fmt.Sprintf(` "%s": bson.M{`, key)) applicationMap := codegen.MapStatement{
Type: "bson.M",
}
for _, field := range u[key] { for _, field := range u[key] {
lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex)) applicationMap.Pairs = append(applicationMap.Pairs, codegen.MapPair{
Key: field.BsonTag,
Value: codegen.Identifier(fmt.Sprintf("arg%d", field.ParamIndex)),
})
} }
lines = append(lines, ` },`) stmt.Pairs = append(stmt.Pairs, codegen.MapPair{
Key: key,
Value: applicationMap,
})
} }
return strings.Join(lines, "\n") return stmt
} }
type querySpec struct { type querySpec struct {
@ -51,32 +70,52 @@ type querySpec struct {
Predicates []predicate Predicates []predicate
} }
func (q querySpec) Code() string { func (q querySpec) Code() codegen.Statement {
var predicateCodes []string var predicatePairs []codegen.MapPair
for _, predicate := range q.Predicates { for _, predicate := range q.Predicates {
predicateCodes = append(predicateCodes, predicate.Code()) predicatePairs = append(predicatePairs, predicate.Code())
}
var predicateMaps []codegen.Statement
for _, pair := range predicatePairs {
predicateMaps = append(predicateMaps, codegen.MapStatement{
Pairs: []codegen.MapPair{pair},
})
} }
var lines []string stmt := codegen.MapStatement{
Type: "bson.M",
}
switch q.Operator { switch q.Operator {
case spec.OperatorOr: case spec.OperatorOr:
lines = append(lines, ` "$or": []bson.M{`) stmt.Pairs = append(stmt.Pairs, codegen.MapPair{
for _, predicateCode := range predicateCodes { Key: "$or",
lines = append(lines, fmt.Sprintf(` {%s},`, predicateCode)) Value: codegen.SliceStatement{
} Type: code.ArrayType{
lines = append(lines, ` },`) ContainedType: code.ExternalType{
PackageAlias: "bson",
Name: "M",
},
},
Values: predicateMaps,
},
})
case spec.OperatorAnd: case spec.OperatorAnd:
lines = append(lines, ` "$and": []bson.M{`) stmt.Pairs = append(stmt.Pairs, codegen.MapPair{
for _, predicateCode := range predicateCodes { Key: "$and",
lines = append(lines, fmt.Sprintf(` {%s},`, predicateCode)) Value: codegen.SliceStatement{
} Type: code.ArrayType{
lines = append(lines, ` },`) ContainedType: code.ExternalType{
PackageAlias: "bson",
Name: "M",
},
},
Values: predicateMaps,
},
})
default: default:
for _, predicateCode := range predicateCodes { stmt.Pairs = predicatePairs
lines = append(lines, fmt.Sprintf(` %s,`, predicateCode))
} }
} return stmt
return strings.Join(lines, "\n")
} }
type predicate struct { type predicate struct {
@ -85,34 +124,86 @@ type predicate struct {
ParamIndex int ParamIndex int
} }
func (p predicate) Code() string { func (p predicate) Code() codegen.MapPair {
argStmt := codegen.Identifier(fmt.Sprintf("arg%d", p.ParamIndex))
switch p.Comparator { switch p.Comparator {
case spec.ComparatorEqual: case spec.ComparatorEqual:
return fmt.Sprintf(`"%s": arg%d`, p.Field, p.ParamIndex) return p.createValueMapPair(argStmt)
case spec.ComparatorNot: case spec.ComparatorNot:
return fmt.Sprintf(`"%s": bson.M{"$ne": arg%d}`, p.Field, p.ParamIndex) return p.createSingleComparisonMapPair("$ne", argStmt)
case spec.ComparatorLessThan: case spec.ComparatorLessThan:
return fmt.Sprintf(`"%s": bson.M{"$lt": arg%d}`, p.Field, p.ParamIndex) return p.createSingleComparisonMapPair("$lt", argStmt)
case spec.ComparatorLessThanEqual: case spec.ComparatorLessThanEqual:
return fmt.Sprintf(`"%s": bson.M{"$lte": arg%d}`, p.Field, p.ParamIndex) return p.createSingleComparisonMapPair("$lte", argStmt)
case spec.ComparatorGreaterThan: case spec.ComparatorGreaterThan:
return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, p.ParamIndex) return p.createSingleComparisonMapPair("$gt", argStmt)
case spec.ComparatorGreaterThanEqual: case spec.ComparatorGreaterThanEqual:
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, p.ParamIndex) return p.createSingleComparisonMapPair("$gte", argStmt)
case spec.ComparatorBetween: case spec.ComparatorBetween:
return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, p.ParamIndex, p.ParamIndex+1) argStmt2 := codegen.Identifier(fmt.Sprintf("arg%d", p.ParamIndex+1))
return p.createBetweenMapPair(argStmt, argStmt2)
case spec.ComparatorIn: case spec.ComparatorIn:
return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, p.ParamIndex) return p.createSingleComparisonMapPair("$in", argStmt)
case spec.ComparatorNotIn: case spec.ComparatorNotIn:
return fmt.Sprintf(`"%s": bson.M{"$nin": arg%d}`, p.Field, p.ParamIndex) return p.createSingleComparisonMapPair("$nin", argStmt)
case spec.ComparatorTrue: case spec.ComparatorTrue:
return fmt.Sprintf(`"%s": true`, p.Field) return p.createValueMapPair(codegen.Identifier("true"))
case spec.ComparatorFalse: case spec.ComparatorFalse:
return fmt.Sprintf(`"%s": false`, p.Field) return p.createValueMapPair(codegen.Identifier("false"))
case spec.ComparatorExists: case spec.ComparatorExists:
return fmt.Sprintf(`"%s": bson.M{"$exists": 1}`, p.Field) return p.createExistsMapPair("1")
case spec.ComparatorNotExists: case spec.ComparatorNotExists:
return fmt.Sprintf(`"%s": bson.M{"$exists": 0}`, p.Field) return p.createExistsMapPair("0")
}
return codegen.MapPair{}
}
func (p predicate) createValueMapPair(
argStmt codegen.Statement) codegen.MapPair {
return codegen.MapPair{
Key: p.Field,
Value: argStmt,
}
}
func (p predicate) createSingleComparisonMapPair(comparatorKey string,
argStmt codegen.Statement) codegen.MapPair {
return codegen.MapPair{
Key: p.Field,
Value: codegen.MapStatement{
Type: "bson.M",
Pairs: []codegen.MapPair{{Key: comparatorKey, Value: argStmt}},
},
}
}
func (p predicate) createBetweenMapPair(argStmt codegen.Statement,
argStmt2 codegen.Statement) codegen.MapPair {
return codegen.MapPair{
Key: p.Field,
Value: codegen.MapStatement{
Type: "bson.M",
Pairs: []codegen.MapPair{
{Key: "$gte", Value: argStmt},
{Key: "$lte", Value: argStmt2},
},
},
}
}
func (p predicate) createExistsMapPair(existsValue string) codegen.MapPair {
return codegen.MapPair{
Key: p.Field,
Value: codegen.MapStatement{
Type: "bson.M",
Pairs: []codegen.MapPair{{
Key: "$exists",
Value: codegen.Identifier(existsValue),
}},
},
} }
return ""
} }

View file

@ -1,128 +0,0 @@
package mongo
import (
"github.com/sunboyy/repogen/internal/spec"
)
const constructorBody = ` return &{{.ImplStructName}}{
collection: collection,
}`
type constructorBodyData struct {
ImplStructName string
}
const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1)
if err != nil {
return nil, err
}
return result.InsertedID, nil`
const insertManyTemplate = ` var entities []interface{}
for _, model := range arg1 {
entities = append(entities, model)
}
result, err := r.collection.InsertMany(arg0, entities)
if err != nil {
return nil, err
}
return result.InsertedIDs, nil`
type mongoFindTemplateData struct {
EntityType string
QuerySpec querySpec
Sorts []findSort
}
type findSort struct {
BsonTag string
Ordering spec.Ordering
}
func (s findSort) OrderNum() int {
if s.Ordering == spec.OrderingAscending {
return 1
}
return -1
}
const findOneTemplate = ` var entity {{.EntityType}}
if err := r.collection.FindOne(arg0, bson.M{
{{.QuerySpec.Code}}
}, options.FindOne().SetSort(bson.M{
{{range $index, $element := .Sorts}} "{{$element.BsonTag}}": {{$element.OrderNum}},
{{end}} })).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil`
const findManyTemplate = ` cursor, err := r.collection.Find(arg0, bson.M{
{{.QuerySpec.Code}}
}, options.Find().SetSort(bson.M{
{{range $index, $element := .Sorts}} "{{$element.BsonTag}}": {{$element.OrderNum}},
{{end}} }))
if err != nil {
return nil, err
}
var entities []*{{.EntityType}}
if err := cursor.All(arg0, &entities); err != nil {
return nil, err
}
return entities, nil`
type mongoUpdateTemplateData struct {
Update update
QuerySpec querySpec
}
const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{
{{.QuerySpec.Code}}
}, bson.M{
{{.Update.Code}}
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, err`
const updateManyTemplate = ` result, err := r.collection.UpdateMany(arg0, bson.M{
{{.QuerySpec.Code}}
}, bson.M{
{{.Update.Code}}
})
if err != nil {
return 0, err
}
return int(result.MatchedCount), err`
type mongoDeleteTemplateData struct {
QuerySpec querySpec
}
const deleteOneTemplate = ` result, err := r.collection.DeleteOne(arg0, bson.M{
{{.QuerySpec.Code}}
})
if err != nil {
return false, err
}
return result.DeletedCount > 0, nil`
const deleteManyTemplate = ` result, err := r.collection.DeleteMany(arg0, bson.M{
{{.QuerySpec.Code}}
})
if err != nil {
return 0, err
}
return int(result.DeletedCount), nil`
type mongoCountTemplateData struct {
QuerySpec querySpec
}
const countTemplate = ` count, err := r.collection.CountDocuments(arg0, bson.M{
{{.QuerySpec.Code}}
})
if err != nil {
return 0, err
}
return int(count), nil`

133
internal/mongo/update.go Normal file
View file

@ -0,0 +1,133 @@
package mongo
import (
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/spec"
)
func (g RepositoryGenerator) generateUpdateBody(
operation spec.UpdateOperation) (codegen.FunctionBody, error) {
return updateBodyGenerator{
baseMethodGenerator: g.baseMethodGenerator,
operation: operation,
}.generate()
}
type updateBodyGenerator struct {
baseMethodGenerator
operation spec.UpdateOperation
}
func (g updateBodyGenerator) generate() (codegen.FunctionBody, error) {
update, err := g.convertUpdate(g.operation.Update)
if err != nil {
return nil, err
}
querySpec, err := g.convertQuerySpec(g.operation.Query)
if err != nil {
return nil, err
}
if g.operation.Mode == spec.QueryModeOne {
return g.generateUpdateOneBody(update, querySpec), nil
}
return g.generateUpdateManyBody(update, querySpec), nil
}
func (g updateBodyGenerator) generateUpdateOneBody(update update,
querySpec querySpec) codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclAssignStatement{
Vars: []string{"result", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("UpdateOne",
codegen.Identifier("arg0"),
querySpec.Code(),
update.Code(),
).Build(),
},
},
ifErrReturnFalseErr,
codegen.ReturnStatement{
codegen.RawStatement("result.MatchedCount > 0"),
codegen.Identifier("nil"),
},
}
}
func (g updateBodyGenerator) generateUpdateManyBody(update update,
querySpec querySpec) codegen.FunctionBody {
return codegen.FunctionBody{
codegen.DeclAssignStatement{
Vars: []string{"result", "err"},
Values: codegen.StatementList{
codegen.NewChainBuilder("r").
Chain("collection").
Call("UpdateMany",
codegen.Identifier("arg0"),
querySpec.Code(),
update.Code(),
).Build(),
},
},
ifErrReturn0Err,
codegen.ReturnStatement{
codegen.CallStatement{
FuncName: "int",
Params: codegen.StatementList{
codegen.NewChainBuilder("result").
Chain("MatchedCount").Build(),
},
},
codegen.Identifier("nil"),
},
}
}
func (g updateBodyGenerator) convertUpdate(updateSpec spec.Update) (update, error) {
switch updateSpec := updateSpec.(type) {
case spec.UpdateModel:
return updateModel{}, nil
case spec.UpdateFields:
update := make(updateFields)
for _, field := range updateSpec {
bsonFieldReference, err := g.bsonFieldReference(field.FieldReference)
if err != nil {
return nil, err
}
updateKey := getUpdateOperatorKey(field.Operator)
if updateKey == "" {
return nil, NewUpdateOperatorNotSupportedError(field.Operator)
}
updateField := updateField{
BsonTag: bsonFieldReference,
ParamIndex: field.ParamIndex,
}
update[updateKey] = append(update[updateKey], updateField)
}
return update, nil
default:
return nil, NewUpdateTypeNotSupportedError(updateSpec)
}
}
func getUpdateOperatorKey(operator spec.UpdateOperator) string {
switch operator {
case spec.UpdateOperatorSet:
return "$set"
case spec.UpdateOperatorPush:
return "$push"
case spec.UpdateOperatorInc:
return "$inc"
default:
return ""
}
}

View file

@ -0,0 +1,389 @@
package mongo_test
import (
"fmt"
"reflect"
"testing"
"github.com/sunboyy/repogen/internal/code"
"github.com/sunboyy/repogen/internal/codegen"
"github.com/sunboyy/repogen/internal/mongo"
"github.com/sunboyy/repogen/internal/spec"
"github.com/sunboyy/repogen/internal/testutils"
)
func TestGenerateMethod_Update(t *testing.T) {
testTable := []GenerateMethodTestCase{
{
Name: "update model method",
MethodSpec: spec.MethodSpec{
Name: "UpdateByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "model", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.TypeBool,
code.TypeError,
},
Operation: spec.UpdateOperation{
Update: spec.UpdateModel{},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{idField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{
"_id": arg2,
}, bson.M{
"$set": arg1,
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, nil`,
},
{
Name: "simple update one method",
MethodSpec: spec.MethodSpec{
Name: "UpdateAgeByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.TypeBool,
code.TypeError,
},
Operation: spec.UpdateOperation{
Update: spec.UpdateFields{
spec.UpdateField{
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
Operator: spec.UpdateOperatorSet,
},
},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{idField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{
"_id": arg2,
}, bson.M{
"$set": bson.M{
"age": arg1,
},
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, nil`,
},
{
Name: "simple update many method",
MethodSpec: spec.MethodSpec{
Name: "UpdateAgeByGender",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
{Name: "gender", Type: code.SimpleType("Gender")},
},
Returns: []code.Type{
code.TypeInt,
code.TypeError,
},
Operation: spec.UpdateOperation{
Update: spec.UpdateFields{
spec.UpdateField{
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
Operator: spec.UpdateOperatorSet,
},
},
Mode: spec.QueryModeMany,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{genderField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.UpdateMany(arg0, bson.M{
"gender": arg2,
}, bson.M{
"$set": bson.M{
"age": arg1,
},
})
if err != nil {
return 0, err
}
return int(result.MatchedCount), nil`,
},
{
Name: "simple update push method",
MethodSpec: spec.MethodSpec{
Name: "UpdateConsentHistoryPushByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "consentHistory", Type: code.SimpleType("ConsentHistory")},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.TypeBool,
code.TypeError,
},
Operation: spec.UpdateOperation{
Update: spec.UpdateFields{
spec.UpdateField{
FieldReference: spec.FieldReference{consentHistoryField},
ParamIndex: 1,
Operator: spec.UpdateOperatorPush,
},
},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{idField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{
"_id": arg2,
}, bson.M{
"$push": bson.M{
"consent_history": arg1,
},
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, nil`,
},
{
Name: "simple update inc method",
MethodSpec: spec.MethodSpec{
Name: "UpdateAgeIncByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "age", Type: code.TypeInt},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.TypeBool,
code.TypeError,
},
Operation: spec.UpdateOperation{
Update: spec.UpdateFields{
spec.UpdateField{
FieldReference: spec.FieldReference{ageField},
ParamIndex: 1,
Operator: spec.UpdateOperatorInc,
},
},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{idField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{
"_id": arg2,
}, bson.M{
"$inc": bson.M{
"age": arg1,
},
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, nil`,
},
{
Name: "simple update set and push method",
MethodSpec: spec.MethodSpec{
Name: "UpdateEnabledAndConsentHistoryPushByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "enabled", Type: code.TypeBool},
{Name: "consentHistory", Type: code.SimpleType("ConsentHistory")},
{Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.TypeBool,
code.TypeError,
},
Operation: spec.UpdateOperation{
Update: spec.UpdateFields{
spec.UpdateField{
FieldReference: spec.FieldReference{enabledField},
ParamIndex: 1,
Operator: spec.UpdateOperatorSet,
},
spec.UpdateField{
FieldReference: spec.FieldReference{consentHistoryField},
ParamIndex: 2,
Operator: spec.UpdateOperatorPush,
},
},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{idField},
Comparator: spec.ComparatorEqual,
ParamIndex: 3,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{
"_id": arg3,
}, bson.M{
"$push": bson.M{
"consent_history": arg2,
},
"$set": bson.M{
"enabled": arg1,
},
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, nil`,
},
{
Name: "update with deeply referenced field",
MethodSpec: spec.MethodSpec{
Name: "UpdateNameFirstByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "firstName", Type: code.TypeString},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.TypeBool,
code.TypeError,
},
Operation: spec.UpdateOperation{
Update: spec.UpdateFields{
spec.UpdateField{
FieldReference: spec.FieldReference{nameField, firstNameField},
ParamIndex: 1,
Operator: spec.UpdateOperatorSet,
},
},
Mode: spec.QueryModeOne,
Query: spec.QuerySpec{
Predicates: []spec.Predicate{
{
FieldReference: spec.FieldReference{idField},
Comparator: spec.ComparatorEqual,
ParamIndex: 2,
},
},
},
},
},
ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{
"_id": arg2,
}, bson.M{
"$set": bson.M{
"name.first": arg1,
},
})
if err != nil {
return false, err
}
return result.MatchedCount > 0, nil`,
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
generator := mongo.NewGenerator(userModel, "UserRepository")
expectedReceiver := codegen.MethodReceiver{
Name: "r",
Type: "UserRepositoryMongo",
Pointer: true,
}
var expectedParams []code.Param
for i, param := range testCase.MethodSpec.Params {
expectedParams = append(expectedParams, code.Param{
Name: fmt.Sprintf("arg%d", i),
Type: param.Type,
})
}
actual, err := generator.GenerateMethod(testCase.MethodSpec)
if err != nil {
t.Fatal(err)
}
if expectedReceiver != actual.Receiver {
t.Errorf(
"incorrect method receiver: expected %+v, got %+v",
expectedReceiver,
actual.Receiver,
)
}
if testCase.MethodSpec.Name != actual.Name {
t.Errorf(
"incorrect method name: expected %s, got %s",
testCase.MethodSpec.Name,
actual.Name,
)
}
if !reflect.DeepEqual(expectedParams, actual.Params) {
t.Errorf(
"incorrect struct params: expected %+v, got %+v",
expectedParams,
actual.Params,
)
}
if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) {
t.Errorf(
"incorrect struct returns: expected %+v, got %+v",
testCase.MethodSpec.Returns,
actual.Returns,
)
}
if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil {
t.Error(err)
}
})
}
}

View file

@ -33,8 +33,16 @@ func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.Obje
func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context, arg1 Gender, arg2 int) (*UserModel, error) { func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context, arg1 Gender, arg2 int) (*UserModel, error) {
cursor, err := r.collection.Find(arg0, bson.M{ cursor, err := r.collection.Find(arg0, bson.M{
"$and": []bson.M{ "$and": []bson.M{
{"gender": bson.M{"$ne": arg1}}, {
{"age": bson.M{"$lt": arg2}}, "gender": bson.M{
"$ne": arg1,
},
},
{
"age": bson.M{
"$lt": arg2,
},
},
}, },
}, options.Find().SetSort(bson.M{})) }, options.Find().SetSort(bson.M{}))
if err != nil { if err != nil {
@ -49,7 +57,9 @@ func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context
func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Context, arg1 int) ([]*UserModel, error) { func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Context, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(arg0, bson.M{ cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{"$lte": arg1}, "age": bson.M{
"$lte": arg1,
},
}, options.Find().SetSort(bson.M{ }, options.Find().SetSort(bson.M{
"age": 1, "age": 1,
})) }))
@ -65,7 +75,9 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Cont
func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Context, arg1 int) ([]*UserModel, error) { func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Context, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(arg0, bson.M{ cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{"$gt": arg1}, "age": bson.M{
"$gt": arg1,
},
}, options.Find().SetSort(bson.M{ }, options.Find().SetSort(bson.M{
"age": 1, "age": 1,
})) }))
@ -81,7 +93,9 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Con
func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 context.Context, arg1 int) ([]*UserModel, error) { func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 context.Context, arg1 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(arg0, bson.M{ cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{"$gte": arg1}, "age": bson.M{
"$gte": arg1,
},
}, options.Find().SetSort(bson.M{ }, options.Find().SetSort(bson.M{
"age": -1, "age": -1,
})) }))
@ -97,7 +111,10 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 conte
func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) { func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(arg0, bson.M{ cursor, err := r.collection.Find(arg0, bson.M{
"age": bson.M{"$gte": arg1, "$lte": arg2}, "age": bson.M{
"$gte": arg1,
"$lte": arg2,
},
}, options.Find().SetSort(bson.M{})) }, options.Find().SetSort(bson.M{}))
if err != nil { if err != nil {
return nil, err return nil, err
@ -112,8 +129,12 @@ func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, a
func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*UserModel, error) { func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*UserModel, error) {
cursor, err := r.collection.Find(arg0, bson.M{ cursor, err := r.collection.Find(arg0, bson.M{
"$or": []bson.M{ "$or": []bson.M{
{"gender": arg1}, {
{"age": arg2}, "gender": arg1,
},
{
"age": arg2,
},
}, },
}, options.Find().SetSort(bson.M{})) }, options.Find().SetSort(bson.M{}))
if err != nil { if err != nil {