From bdb63c8129f222bebf7f9eef3107096ddec743be Mon Sep 17 00:00:00 2001 From: sunboyy Date: Tue, 18 Apr 2023 20:21:46 +0700 Subject: [PATCH] Move function body generation into codegen package (#35) * Move function body generation into codegen package --- internal/codegen/body.go | 264 ++++ internal/codegen/body_test.go | 305 ++++ internal/codegen/builder_test.go | 36 +- internal/codegen/function.go | 4 +- internal/codegen/function_test.go | 48 +- internal/codegen/method.go | 4 +- internal/codegen/method_test.go | 48 +- internal/mongo/common.go | 94 ++ internal/mongo/count.go | 39 + internal/mongo/count_test.go | 437 ++++++ internal/mongo/delete.go | 84 ++ internal/mongo/delete_test.go | 477 ++++++ internal/mongo/find.go | 160 +++ internal/mongo/find_test.go | 886 ++++++++++++ internal/mongo/generator.go | 248 +--- internal/mongo/generator_test.go | 2231 +---------------------------- internal/mongo/insert.go | 81 ++ internal/mongo/insert_test.go | 123 ++ internal/mongo/models.go | 181 ++- internal/mongo/templates.go | 128 -- internal/mongo/update.go | 133 ++ internal/mongo/update_test.go | 389 +++++ test/generator_test_expected.txt | 37 +- 23 files changed, 3795 insertions(+), 2642 deletions(-) create mode 100644 internal/codegen/body.go create mode 100644 internal/codegen/body_test.go create mode 100644 internal/mongo/common.go create mode 100644 internal/mongo/count.go create mode 100644 internal/mongo/count_test.go create mode 100644 internal/mongo/delete.go create mode 100644 internal/mongo/delete_test.go create mode 100644 internal/mongo/find.go create mode 100644 internal/mongo/find_test.go create mode 100644 internal/mongo/insert.go create mode 100644 internal/mongo/insert_test.go delete mode 100644 internal/mongo/templates.go create mode 100644 internal/mongo/update.go create mode 100644 internal/mongo/update_test.go diff --git a/internal/codegen/body.go b/internal/codegen/body.go new file mode 100644 index 0000000..e6ba09e --- /dev/null +++ b/internal/codegen/body.go @@ -0,0 +1,264 @@ +package codegen + +import ( + "fmt" + "strings" + + "github.com/sunboyy/repogen/internal/code" +) + +type FunctionBody []Statement + +func (b FunctionBody) Code() string { + var lines []string + for _, statement := range b { + stmtLines := statement.CodeLines() + for _, line := range stmtLines { + lines = append(lines, fmt.Sprintf("\t%s", line)) + } + } + return strings.Join(lines, "\n") +} + +type Statement interface { + CodeLines() []string +} + +type RawStatement string + +func (stmt RawStatement) CodeLines() []string { + return []string{string(stmt)} +} + +type Identifier string + +func (id Identifier) CodeLines() []string { + return []string{string(id)} +} + +type DeclStatement struct { + Name string + Type code.Type +} + +func (stmt DeclStatement) CodeLines() []string { + return []string{fmt.Sprintf("var %s %s", stmt.Name, stmt.Type.Code())} +} + +type DeclAssignStatement struct { + Vars []string + Values StatementList +} + +func (stmt DeclAssignStatement) CodeLines() []string { + vars := strings.Join(stmt.Vars, ", ") + lines := stmt.Values.CodeLines() + lines[0] = fmt.Sprintf("%s := %s", vars, lines[0]) + return lines +} + +type AssignStatement struct { + Vars []string + Values StatementList +} + +func (stmt AssignStatement) CodeLines() []string { + vars := strings.Join(stmt.Vars, ", ") + lines := stmt.Values.CodeLines() + lines[0] = fmt.Sprintf("%s = %s", vars, lines[0]) + return lines +} + +type StatementList []Statement + +func (l StatementList) CodeLines() []string { + if len(l) == 0 { + return []string{""} + } + return concatenateStatements(", ", []Statement(l)) +} + +type ReturnStatement StatementList + +func (stmt ReturnStatement) CodeLines() []string { + lines := StatementList(stmt).CodeLines() + lines[0] = fmt.Sprintf("return %s", lines[0]) + return lines +} + +type ChainStatement []Statement + +func (stmt ChainStatement) CodeLines() []string { + return concatenateStatements(".", []Statement(stmt)) +} + +type CallStatement struct { + FuncName string + Params StatementList +} + +func (stmt CallStatement) CodeLines() []string { + lines := stmt.Params.CodeLines() + lines[0] = fmt.Sprintf("%s(%s", stmt.FuncName, lines[0]) + lines[len(lines)-1] += ")" + return lines +} + +type SliceStatement struct { + Type code.Type + Values []Statement +} + +func (stmt SliceStatement) CodeLines() []string { + lines := []string{stmt.Type.Code() + "{"} + for _, value := range stmt.Values { + stmtLines := value.CodeLines() + stmtLines[len(stmtLines)-1] += "," + for _, line := range stmtLines { + lines = append(lines, "\t"+line) + } + } + lines = append(lines, "}") + return lines +} + +type MapStatement struct { + Type string + Pairs []MapPair +} + +func (stmt MapStatement) CodeLines() []string { + return generateCollectionCodeLines(stmt.Type, stmt.Pairs) +} + +type MapPair struct { + Key string + Value Statement +} + +func (p MapPair) ItemCodeLines() []string { + lines := p.Value.CodeLines() + lines[0] = fmt.Sprintf(`"%s": %s`, p.Key, lines[0]) + return lines +} + +type StructStatement struct { + Type string + Pairs []StructFieldPair +} + +func (stmt StructStatement) CodeLines() []string { + return generateCollectionCodeLines(stmt.Type, stmt.Pairs) +} + +type StructFieldPair struct { + Key string + Value Statement +} + +func (p StructFieldPair) ItemCodeLines() []string { + lines := p.Value.CodeLines() + lines[0] = fmt.Sprintf(`%s: %s`, p.Key, lines[0]) + return lines +} + +type collectionItem interface { + ItemCodeLines() []string +} + +func generateCollectionCodeLines[T collectionItem](typ string, pairs []T) []string { + lines := []string{fmt.Sprintf("%s{", typ)} + for _, pair := range pairs { + pairLines := pair.ItemCodeLines() + pairLines[len(pairLines)-1] += "," + for _, line := range pairLines { + lines = append(lines, fmt.Sprintf("\t%s", line)) + } + } + lines = append(lines, "}") + return lines +} + +type RawBlock struct { + Header []string + Statements []Statement +} + +func (b RawBlock) CodeLines() []string { + lines := make([]string, len(b.Header)) + copy(lines, b.Header) + lines[len(lines)-1] += " {" + for _, stmt := range b.Statements { + stmtLines := stmt.CodeLines() + for _, line := range stmtLines { + lines = append(lines, fmt.Sprintf("\t%s", line)) + } + } + lines = append(lines, "}") + return lines +} + +type IfBlock struct { + Condition []Statement + Statements []Statement +} + +func (b IfBlock) CodeLines() []string { + conditionCode := concatenateStatements("; ", b.Condition) + conditionCode[0] = "if " + conditionCode[0] + + return RawBlock{ + Header: conditionCode, + Statements: b.Statements, + }.CodeLines() +} + +func concatenateStatements(sep string, statements []Statement) []string { + var lines []string + lastLine := "" + for _, stmt := range statements { + stmtLines := stmt.CodeLines() + + if lastLine != "" { + lastLine += sep + } + lastLine += stmtLines[0] + + if len(stmtLines) > 1 { + lines = append(lines, lastLine) + lines = append(lines, stmtLines[1:len(stmtLines)-1]...) + lastLine = stmtLines[len(stmtLines)-1] + } + } + + if lastLine != "" { + lines = append(lines, lastLine) + } + + return lines +} + +type ChainBuilder []Statement + +func NewChainBuilder(object string) ChainBuilder { + return ChainBuilder{ + Identifier(object), + } +} + +func (b ChainBuilder) Chain(field string) ChainBuilder { + b = append(b, Identifier(field)) + return b +} + +func (b ChainBuilder) Call(method string, params ...Statement) ChainBuilder { + b = append(b, CallStatement{ + FuncName: method, + Params: params, + }) + return b +} + +func (b ChainBuilder) Build() ChainStatement { + return ChainStatement(b) +} diff --git a/internal/codegen/body_test.go b/internal/codegen/body_test.go new file mode 100644 index 0000000..be0c32f --- /dev/null +++ b/internal/codegen/body_test.go @@ -0,0 +1,305 @@ +package codegen_test + +import ( + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" +) + +func TestIdentifier(t *testing.T) { + identifier := codegen.Identifier("user") + expected := []string{"user"} + + actual := identifier.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestDeclStatement(t *testing.T) { + stmt := codegen.DeclStatement{ + Name: "arrs", + Type: code.ArrayType{ContainedType: code.SimpleType("int")}, + } + expected := []string{"var arrs []int"} + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestDeclAssignStatement(t *testing.T) { + stmt := codegen.DeclAssignStatement{ + Vars: []string{"value", "err"}, + Values: codegen.StatementList{ + codegen.Identifier("1"), + codegen.Identifier("nil"), + }, + } + expected := []string{"value, err := 1, nil"} + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestAssignStatement(t *testing.T) { + stmt := codegen.AssignStatement{ + Vars: []string{"value", "err"}, + Values: codegen.StatementList{ + codegen.Identifier("1"), + codegen.Identifier("nil"), + }, + } + expected := []string{"value, err = 1, nil"} + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestReturnStatement(t *testing.T) { + stmt := codegen.ReturnStatement{ + codegen.Identifier("result"), + codegen.Identifier("nil"), + } + expected := []string{"return result, nil"} + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestChainStatement(t *testing.T) { + stmt := codegen.ChainStatement{ + codegen.Identifier("r"), + codegen.Identifier("userRepository"), + codegen.CallStatement{ + FuncName: "Insert", + Params: codegen.StatementList{ + codegen.StructStatement{ + Type: "User", + Pairs: []codegen.StructFieldPair{ + { + Key: "ID", + Value: codegen.Identifier("arg0"), + }, + { + Key: "Name", + Value: codegen.Identifier("arg1"), + }, + }, + }, + }, + }, + codegen.CallStatement{ + FuncName: "Do", + }, + } + expected := []string{ + "r.userRepository.Insert(User{", + " ID: arg0,", + " Name: arg1,", + "}).Do()", + } + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestCallStatement(t *testing.T) { + stmt := codegen.CallStatement{ + FuncName: "FindByID", + Params: codegen.StatementList{ + codegen.Identifier("ctx"), + codegen.Identifier("user"), + }, + } + expected := []string{"FindByID(ctx, user)"} + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestSliceStatement(t *testing.T) { + stmt := codegen.SliceStatement{ + Type: code.ArrayType{ + ContainedType: code.SimpleType("string"), + }, + Values: []codegen.Statement{ + codegen.Identifier(`"hello"`), + codegen.ChainStatement{ + codegen.CallStatement{ + FuncName: "GetUser", + Params: codegen.StatementList{ + codegen.Identifier("userID"), + }, + }, + codegen.Identifier("Name"), + }, + }, + } + expected := []string{ + "[]string{", + ` "hello",`, + ` GetUser(userID).Name,`, + "}", + } + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestMapStatement(t *testing.T) { + stmt := codegen.MapStatement{ + Type: "map[string]int", + Pairs: []codegen.MapPair{ + { + Key: "key1", + Value: codegen.Identifier("value1"), + }, + { + Key: "key2", + Value: codegen.Identifier("value2"), + }, + }, + } + expected := []string{ + "map[string]int{", + ` "key1": value1,`, + ` "key2": value2,`, + "}", + } + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestStructStatement(t *testing.T) { + stmt := codegen.StructStatement{ + Type: "User", + Pairs: []codegen.StructFieldPair{ + { + Key: "ID", + Value: codegen.Identifier("arg0"), + }, + { + Key: "Name", + Value: codegen.Identifier("arg1"), + }, + }, + } + expected := []string{ + "User{", + ` ID: arg0,`, + ` Name: arg1,`, + "}", + } + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestIfBlockStatement(t *testing.T) { + stmt := codegen.IfBlock{ + Condition: []codegen.Statement{ + codegen.DeclAssignStatement{ + Vars: []string{"err"}, + Values: codegen.StatementList{ + codegen.CallStatement{ + FuncName: "Insert", + Params: codegen.StatementList{ + codegen.Identifier("ctx"), + codegen.StructStatement{ + Type: "User", + Pairs: []codegen.StructFieldPair{ + { + Key: "ID", + Value: codegen.Identifier("id"), + }, + { + Key: "Name", + Value: codegen.Identifier("name"), + }, + }, + }, + }, + }, + }, + }, + codegen.RawStatement("err != nil"), + }, + Statements: []codegen.Statement{ + codegen.ReturnStatement{ + codegen.Identifier("nil"), + codegen.Identifier("err"), + }, + }, + } + expected := []string{ + "if err := Insert(ctx, User{", + " ID: id,", + " Name: name,", + "}); err != nil {", + " return nil, err", + "}", + } + + actual := stmt.CodeLines() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} + +func TestChainBuilder(t *testing.T) { + expected := codegen.ChainStatement{ + codegen.Identifier("r"), + codegen.Identifier("repository"), + codegen.CallStatement{ + FuncName: "Find", + Params: []codegen.Statement{ + codegen.Identifier("ctx"), + }, + }, + codegen.CallStatement{ + FuncName: "Decode", + }, + } + + actual := codegen.NewChainBuilder("r"). + Chain("repository"). + Call("Find", codegen.Identifier("ctx")). + Call("Decode"). + Build() + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected=%+v actual=%+v", expected, actual) + } +} diff --git a/internal/codegen/builder_test.go b/internal/codegen/builder_test.go index 8f69ba6..194730e 100644 --- a/internal/codegen/builder_test.go +++ b/internal/codegen/builder_test.go @@ -79,17 +79,43 @@ func TestBuilderBuild(t *testing.T) { {Name: "username", Type: code.TypeString}, }, Returns: []code.Type{code.SimpleType("User")}, - Body: ` return User{ - ID: primitive.NewObjectID(), - Username: username, - }`, + Body: codegen.FunctionBody{ + codegen.ReturnStatement{ + codegen.StructStatement{ + Type: "User", + Pairs: []codegen.StructFieldPair{ + { + Key: "ID", + Value: codegen.ChainStatement{ + codegen.Identifier("primitive"), + codegen.CallStatement{ + FuncName: "NewObjectID", + }, + }, + }, + { + Key: "Username", + Value: codegen.Identifier("username"), + }, + }, + }, + }, + }, }) builder.AddImplementer(codegen.MethodBuilder{ Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")}, Name: "IDHex", Params: nil, Returns: []code.Type{code.TypeString}, - Body: " return u.ID.Hex()", + Body: codegen.FunctionBody{ + codegen.ReturnStatement{ + codegen.ChainStatement{ + codegen.Identifier("u"), + codegen.Identifier("ID"), + codegen.CallStatement{FuncName: "Hex"}, + }, + }, + }, }) generatedCode, err := builder.Build() diff --git a/internal/codegen/function.go b/internal/codegen/function.go index 8f469c3..73d7684 100644 --- a/internal/codegen/function.go +++ b/internal/codegen/function.go @@ -11,7 +11,7 @@ import ( const functionTemplate = ` func {{.Name}}({{.GenParams}}){{.GenReturns}} { -{{.Body}} +{{.Body.Code}} } ` @@ -20,7 +20,7 @@ type FunctionBuilder struct { Name string Params []code.Param Returns []code.Type - Body string + Body FunctionBody } // Impl writes function declatation code to the buffer. diff --git a/internal/codegen/function_test.go b/internal/codegen/function_test.go index e14932b..aa9a4ee 100644 --- a/internal/codegen/function_test.go +++ b/internal/codegen/function_test.go @@ -14,7 +14,20 @@ func TestFunctionBuilderBuild_NoReturn(t *testing.T) { Name: "init", Params: nil, Returns: nil, - Body: ` logrus.SetLevel(logrus.DebugLevel)`, + Body: codegen.FunctionBody{ + codegen.ChainStatement{ + codegen.Identifier("logrus"), + codegen.CallStatement{ + FuncName: "SetLevel", + Params: codegen.StatementList{ + codegen.ChainStatement{ + codegen.Identifier("logrus"), + codegen.Identifier("DebugLevel"), + }, + }, + }, + }, + }, } expectedCode := ` func init() { @@ -57,18 +70,25 @@ func TestFunctionBuilderBuild_OneReturn(t *testing.T) { Returns: []code.Type{ code.SimpleType("User"), }, - Body: ` return User{ - Username: username, - Age: age, - Parent: parent - }`, + Body: codegen.FunctionBody{ + codegen.ReturnStatement{ + codegen.StructStatement{ + Type: "User", + Pairs: []codegen.StructFieldPair{ + {Key: "Username", Value: codegen.Identifier("username")}, + {Key: "Age", Value: codegen.Identifier("age")}, + {Key: "Parent", Value: codegen.Identifier("parent")}, + }, + }, + }, + }, } expectedCode := ` func NewUser(username string, age int, parent *User) User { return User{ Username: username, Age: age, - Parent: parent + Parent: parent, } } ` @@ -101,7 +121,19 @@ func TestFunctionBuilderBuild_MultiReturn(t *testing.T) { code.SimpleType("User"), code.TypeError, }, - Body: ` return collection.Save(user)`, + Body: codegen.FunctionBody{ + codegen.ReturnStatement{ + codegen.ChainStatement{ + codegen.Identifier("collection"), + codegen.CallStatement{ + FuncName: "Save", + Params: codegen.StatementList{ + codegen.Identifier("user"), + }, + }, + }, + }, + }, } expectedCode := ` func Save(user User) (User, error) { diff --git a/internal/codegen/method.go b/internal/codegen/method.go index 874aa87..8730aa1 100644 --- a/internal/codegen/method.go +++ b/internal/codegen/method.go @@ -10,7 +10,7 @@ import ( const methodTemplate = ` func ({{.GenReceiver}}) {{.Name}}({{.GenParams}}){{.GenReturns}} { -{{.Body}} +{{.Body.Code}} } ` @@ -20,7 +20,7 @@ type MethodBuilder struct { Name string Params []code.Param Returns []code.Type - Body string + Body FunctionBody } // MethodReceiver describes a specification of a method receiver. diff --git a/internal/codegen/method_test.go b/internal/codegen/method_test.go index db73ed3..fb0a8d9 100644 --- a/internal/codegen/method_test.go +++ b/internal/codegen/method_test.go @@ -15,7 +15,17 @@ func TestMethodBuilderBuild_IgnoreReceiverNoReturn(t *testing.T) { Name: "Init", Params: nil, Returns: nil, - Body: ` db.Init(&User{})`, + Body: codegen.FunctionBody{ + codegen.ChainStatement{ + codegen.Identifier("db"), + codegen.CallStatement{ + FuncName: "Init", + Params: codegen.StatementList{ + codegen.RawStatement("&User{}"), + }, + }, + }, + }, } expectedCode := ` func (User) Init() { @@ -47,7 +57,19 @@ func TestMethodBuilderBuild_IgnorePoinerReceiverOneReturn(t *testing.T) { Name: "Init", Params: nil, Returns: []code.Type{code.TypeError}, - Body: ` return db.Init(&User{})`, + Body: codegen.FunctionBody{ + codegen.ReturnStatement{ + codegen.ChainStatement{ + codegen.Identifier("db"), + codegen.CallStatement{ + FuncName: "Init", + Params: codegen.StatementList{ + codegen.RawStatement("&User{}"), + }, + }, + }, + }, + }, } expectedCode := ` func (*User) Init() error { @@ -81,8 +103,17 @@ func TestMethodBuilderBuild_UseReceiverMultiReturn(t *testing.T) { {Name: "age", Type: code.TypeInt}, }, Returns: []code.Type{code.SimpleType("User"), code.TypeError}, - Body: ` u.Age = age - return u`, + Body: codegen.FunctionBody{ + codegen.AssignStatement{ + Vars: []string{"u.Age"}, + Values: codegen.StatementList{ + codegen.Identifier("age"), + }, + }, + codegen.ReturnStatement{ + codegen.Identifier("u"), + }, + }, } expectedCode := ` func (u User) WithAge(age int) (User, error) { @@ -118,7 +149,14 @@ func TestMethodBuilderBuild_UsePointerReceiverNoReturn(t *testing.T) { {Name: "age", Type: code.TypeInt}, }, Returns: nil, - Body: ` u.Age = age`, + Body: codegen.FunctionBody{ + codegen.AssignStatement{ + Vars: []string{"u.Age"}, + Values: codegen.StatementList{ + codegen.Identifier("age"), + }, + }, + }, } expectedCode := ` func (u *User) SetAge(age int) { diff --git a/internal/mongo/common.go b/internal/mongo/common.go new file mode 100644 index 0000000..23ecef4 --- /dev/null +++ b/internal/mongo/common.go @@ -0,0 +1,94 @@ +package mongo + +import ( + "strings" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/spec" +) + +var returnNilErr = codegen.ReturnStatement{ + codegen.Identifier("nil"), + codegen.Identifier("err"), +} + +var ifErrReturnNilErr = codegen.IfBlock{ + Condition: []codegen.Statement{ + codegen.RawStatement("err != nil"), + }, + Statements: []codegen.Statement{ + returnNilErr, + }, +} + +var ifErrReturn0Err = codegen.IfBlock{ + Condition: []codegen.Statement{ + codegen.RawStatement("err != nil"), + }, + Statements: []codegen.Statement{ + codegen.ReturnStatement{ + codegen.Identifier("0"), + codegen.Identifier("err"), + }, + }, +} + +var ifErrReturnFalseErr = codegen.IfBlock{ + Condition: []codegen.Statement{ + codegen.RawStatement("err != nil"), + }, + Statements: []codegen.Statement{ + codegen.ReturnStatement{ + codegen.Identifier("false"), + codegen.Identifier("err"), + }, + }, +} + +type baseMethodGenerator struct { + structModel code.Struct +} + +func (g baseMethodGenerator) bsonFieldReference(fieldReference spec.FieldReference) (string, error) { + var bsonTags []string + for _, field := range fieldReference { + tag, err := g.bsonTagFromField(field) + if err != nil { + return "", err + } + bsonTags = append(bsonTags, tag) + } + return strings.Join(bsonTags, "."), nil +} + +func (g baseMethodGenerator) bsonTagFromField(field code.StructField) (string, error) { + bsonTag, ok := field.Tags["bson"] + if !ok { + return "", NewBsonTagNotFoundError(field.Name) + } + + return bsonTag[0], nil +} + +func (g baseMethodGenerator) convertQuerySpec(query spec.QuerySpec) (querySpec, error) { + var predicates []predicate + + for _, predicateSpec := range query.Predicates { + bsonFieldReference, err := g.bsonFieldReference(predicateSpec.FieldReference) + if err != nil { + return querySpec{}, err + } + + predicates = append(predicates, predicate{ + Field: bsonFieldReference, + Comparator: predicateSpec.Comparator, + ParamIndex: predicateSpec.ParamIndex, + }) + } + + return querySpec{ + Operator: query.Operator, + Predicates: predicates, + }, nil +} diff --git a/internal/mongo/count.go b/internal/mongo/count.go new file mode 100644 index 0000000..25898e2 --- /dev/null +++ b/internal/mongo/count.go @@ -0,0 +1,39 @@ +package mongo + +import ( + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/spec" +) + +func (g RepositoryGenerator) generateCountBody( + operation spec.CountOperation) (codegen.FunctionBody, error) { + + querySpec, err := g.convertQuerySpec(operation.Query) + if err != nil { + return nil, err + } + + return codegen.FunctionBody{ + codegen.DeclAssignStatement{ + Vars: []string{"count", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("CountDocuments", + codegen.Identifier("arg0"), + querySpec.Code(), + ).Build(), + }, + }, + ifErrReturn0Err, + codegen.ReturnStatement{ + codegen.CallStatement{ + FuncName: "int", + Params: codegen.StatementList{ + codegen.Identifier("count"), + }, + }, + codegen.Identifier("nil"), + }, + }, nil +} diff --git a/internal/mongo/count_test.go b/internal/mongo/count_test.go new file mode 100644 index 0000000..87c9c0c --- /dev/null +++ b/internal/mongo/count_test.go @@ -0,0 +1,437 @@ +package mongo_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" +) + +func TestGenerateMethod_Count(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "simple count method", + MethodSpec: spec.MethodSpec{ + Name: "CountByGender", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{genderField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "gender": arg1, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with And operator", + MethodSpec: spec.MethodSpec{ + Name: "CountByGenderAndCity", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{genderField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "$and": []bson.M{ + { + "gender": arg1, + }, + { + "age": arg2, + }, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with Or operator", + MethodSpec: spec.MethodSpec{ + Name: "CountByGenderOrCity", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + {Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{genderField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "$or": []bson.M{ + { + "gender": arg1, + }, + { + "age": arg2, + }, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with Not comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByGenderNot", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{genderField}, + Comparator: spec.ComparatorNot, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "gender": bson.M{ + "$ne": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with LessThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeLessThan", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorLessThan, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{ + "$lt": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with LessThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeLessThanEqual", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorLessThanEqual, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{ + "$lte": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with GreaterThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeGreaterThan", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorGreaterThan, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{ + "$gt": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with GreaterThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeGreaterThanEqual", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorGreaterThanEqual, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{ + "$gte": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with Between comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeBetween", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.TypeInt}, + {Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorBetween, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{ + "$gte": arg1, + "$lte": arg2, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + { + Name: "count with In comparator", + MethodSpec: spec.MethodSpec{ + Name: "CountByAgeIn", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Type: code.ArrayType{ContainedType: code.TypeInt}}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ageField}, + Comparator: spec.ComparatorIn, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ + "age": bson.M{ + "$in": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(count), nil`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + expectedReceiver := codegen.MethodReceiver{ + Name: "r", + Type: "UserRepositoryMongo", + Pointer: true, + } + var expectedParams []code.Param + for i, param := range testCase.MethodSpec.Params { + expectedParams = append(expectedParams, code.Param{ + Name: fmt.Sprintf("arg%d", i), + Type: param.Type, + }) + } + + actual, err := generator.GenerateMethod(testCase.MethodSpec) + + if err != nil { + t.Fatal(err) + } + if expectedReceiver != actual.Receiver { + t.Errorf( + "incorrect method receiver: expected %+v, got %+v", + expectedReceiver, + actual.Receiver, + ) + } + if testCase.MethodSpec.Name != actual.Name { + t.Errorf( + "incorrect method name: expected %s, got %s", + testCase.MethodSpec.Name, + actual.Name, + ) + } + if !reflect.DeepEqual(expectedParams, actual.Params) { + t.Errorf( + "incorrect struct params: expected %+v, got %+v", + expectedParams, + actual.Params, + ) + } + if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + t.Errorf( + "incorrect struct returns: expected %+v, got %+v", + testCase.MethodSpec.Returns, + actual.Returns, + ) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil { + t.Error(err) + } + }) + } +} diff --git a/internal/mongo/delete.go b/internal/mongo/delete.go new file mode 100644 index 0000000..aa8e1b7 --- /dev/null +++ b/internal/mongo/delete.go @@ -0,0 +1,84 @@ +package mongo + +import ( + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/spec" +) + +func (g RepositoryGenerator) generateDeleteBody( + operation spec.DeleteOperation) (codegen.FunctionBody, error) { + + return deleteBodyGenerator{ + baseMethodGenerator: g.baseMethodGenerator, + operation: operation, + }.generate() +} + +type deleteBodyGenerator struct { + baseMethodGenerator + operation spec.DeleteOperation +} + +func (g deleteBodyGenerator) generate() (codegen.FunctionBody, error) { + querySpec, err := g.convertQuerySpec(g.operation.Query) + if err != nil { + return nil, err + } + + if g.operation.Mode == spec.QueryModeOne { + return g.generateDeleteOneBody(querySpec), nil + } + + return g.generateDeleteManyBody(querySpec), nil +} + +func (g deleteBodyGenerator) generateDeleteOneBody( + querySpec querySpec) codegen.FunctionBody { + + return codegen.FunctionBody{ + codegen.DeclAssignStatement{ + Vars: []string{"result", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("DeleteOne", + codegen.Identifier("arg0"), + querySpec.Code(), + ).Build(), + }, + }, + ifErrReturnFalseErr, + codegen.ReturnStatement{ + codegen.RawStatement("result.DeletedCount > 0"), + codegen.Identifier("nil"), + }, + } +} + +func (g deleteBodyGenerator) generateDeleteManyBody( + querySpec querySpec) codegen.FunctionBody { + + return codegen.FunctionBody{ + codegen.DeclAssignStatement{ + Vars: []string{"result", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("DeleteMany", + codegen.Identifier("arg0"), + querySpec.Code(), + ).Build(), + }, + }, + ifErrReturn0Err, + codegen.ReturnStatement{ + codegen.CallStatement{ + FuncName: "int", + Params: codegen.StatementList{ + codegen.NewChainBuilder("result").Chain("DeletedCount").Build(), + }, + }, + codegen.Identifier("nil"), + }, + } +} diff --git a/internal/mongo/delete_test.go b/internal/mongo/delete_test.go new file mode 100644 index 0000000..1f16aa6 --- /dev/null +++ b/internal/mongo/delete_test.go @@ -0,0 +1,477 @@ +package mongo_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" +) + +func TestGenerateMethod_Delete(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "simple delete one method", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{code.TypeBool, code.TypeError}, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{idField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteOne(arg0, bson.M{ + "_id": arg1, + }) + if err != nil { + return false, err + } + return result.DeletedCount > 0, nil`, + }, + { + Name: "simple delete many method", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByGender", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "gender": arg1, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with And operator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByGenderAndAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "$and": []bson.M{ + { + "gender": arg1, + }, + { + "age": arg2, + }, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with Or operator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByGenderOrAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "$or": []bson.M{ + { + "gender": arg1, + }, + { + "age": arg2, + }, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with Not comparator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByGenderNot", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorNot, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "gender": bson.M{ + "$ne": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with LessThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByAgeLessThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorLessThan, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{ + "$lt": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with LessThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByAgeLessThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorLessThanEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{ + "$lte": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with GreaterThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByAgeGreaterThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorGreaterThan, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{ + "$gt": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with GreaterThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByAgeGreaterThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorGreaterThanEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{ + "$gte": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with Between comparator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByAgeBetween", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "fromAge", Type: code.TypeInt}, + {Name: "toAge", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorBetween, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "age": bson.M{ + "$gte": arg1, + "$lte": arg2, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + { + Name: "delete with In comparator", + MethodSpec: spec.MethodSpec{ + Name: "DeleteByGenderIn", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorIn, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ + "gender": bson.M{ + "$in": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.DeletedCount), nil`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + expectedReceiver := codegen.MethodReceiver{ + Name: "r", + Type: "UserRepositoryMongo", + Pointer: true, + } + var expectedParams []code.Param + for i, param := range testCase.MethodSpec.Params { + expectedParams = append(expectedParams, code.Param{ + Name: fmt.Sprintf("arg%d", i), + Type: param.Type, + }) + } + + actual, err := generator.GenerateMethod(testCase.MethodSpec) + + if err != nil { + t.Fatal(err) + } + if expectedReceiver != actual.Receiver { + t.Errorf( + "incorrect method receiver: expected %+v, got %+v", + expectedReceiver, + actual.Receiver, + ) + } + if testCase.MethodSpec.Name != actual.Name { + t.Errorf( + "incorrect method name: expected %s, got %s", + testCase.MethodSpec.Name, + actual.Name, + ) + } + if !reflect.DeepEqual(expectedParams, actual.Params) { + t.Errorf( + "incorrect struct params: expected %+v, got %+v", + expectedParams, + actual.Params, + ) + } + if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + t.Errorf( + "incorrect struct returns: expected %+v, got %+v", + testCase.MethodSpec.Returns, + actual.Returns, + ) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil { + t.Error(err) + } + }) + } +} diff --git a/internal/mongo/find.go b/internal/mongo/find.go new file mode 100644 index 0000000..ed6f155 --- /dev/null +++ b/internal/mongo/find.go @@ -0,0 +1,160 @@ +package mongo + +import ( + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/spec" +) + +func (g RepositoryGenerator) generateFindBody( + operation spec.FindOperation) (codegen.FunctionBody, error) { + + return findBodyGenerator{ + baseMethodGenerator: g.baseMethodGenerator, + operation: operation, + }.generate() +} + +type findBodyGenerator struct { + baseMethodGenerator + operation spec.FindOperation +} + +func (g findBodyGenerator) generate() (codegen.FunctionBody, error) { + querySpec, err := g.convertQuerySpec(g.operation.Query) + if err != nil { + return nil, err + } + + sortsCode, err := g.generateSortMap() + if err != nil { + return nil, err + } + + if g.operation.Mode == spec.QueryModeOne { + return g.generateFindOneBody(querySpec, sortsCode), nil + } + + return g.generateFindManyBody(querySpec, sortsCode), nil +} + +func (g findBodyGenerator) generateFindOneBody(querySpec querySpec, + sortsCode codegen.MapStatement) codegen.FunctionBody { + + return codegen.FunctionBody{ + codegen.DeclStatement{ + Name: "entity", + Type: code.SimpleType(g.structModel.Name), + }, + codegen.IfBlock{ + Condition: []codegen.Statement{ + codegen.DeclAssignStatement{ + Vars: []string{"err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("FindOne", + codegen.Identifier("arg0"), + querySpec.Code(), + codegen.NewChainBuilder("options"). + Call("FindOne"). + Call("SetSort", sortsCode). + Build(), + ). + Call("Decode", + codegen.RawStatement("&entity"), + ).Build(), + }, + }, + codegen.RawStatement("err != nil"), + }, + Statements: []codegen.Statement{ + returnNilErr, + }, + }, + codegen.ReturnStatement{ + codegen.RawStatement("&entity"), + codegen.Identifier("nil"), + }, + } +} + +func (g findBodyGenerator) generateFindManyBody(querySpec querySpec, + sortsCode codegen.MapStatement) codegen.FunctionBody { + + return codegen.FunctionBody{ + codegen.DeclAssignStatement{ + Vars: []string{"cursor", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("Find", + codegen.Identifier("arg0"), + querySpec.Code(), + codegen.NewChainBuilder("options"). + Call("Find"). + Call("SetSort", sortsCode). + Build(), + ).Build(), + }, + }, + ifErrReturnNilErr, + codegen.DeclStatement{ + Name: "entities", + Type: code.ArrayType{ + ContainedType: code.PointerType{ + ContainedType: code.SimpleType(g.structModel.Name), + }, + }, + }, + codegen.IfBlock{ + Condition: []codegen.Statement{ + codegen.DeclAssignStatement{ + Vars: []string{"err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("cursor"). + Call("All", + codegen.Identifier("arg0"), + codegen.RawStatement("&entities"), + ).Build(), + }, + }, + codegen.RawStatement("err != nil"), + }, + Statements: []codegen.Statement{ + returnNilErr, + }, + }, + codegen.ReturnStatement{ + codegen.Identifier("entities"), + codegen.Identifier("nil"), + }, + } +} + +func (g findBodyGenerator) generateSortMap() ( + codegen.MapStatement, error) { + + sortsCode := codegen.MapStatement{ + Type: "bson.M", + } + + for _, s := range g.operation.Sorts { + bsonFieldReference, err := g.bsonFieldReference(s.FieldReference) + if err != nil { + return codegen.MapStatement{}, err + } + + sortValueIdentifier := codegen.Identifier("1") + if s.Ordering == spec.OrderingDescending { + sortValueIdentifier = codegen.Identifier("-1") + } + + sortsCode.Pairs = append(sortsCode.Pairs, codegen.MapPair{ + Key: bsonFieldReference, + Value: sortValueIdentifier, + }) + } + + return sortsCode, nil +} diff --git a/internal/mongo/find_test.go b/internal/mongo/find_test.go new file mode 100644 index 0000000..de1299d --- /dev/null +++ b/internal/mongo/find_test.go @@ -0,0 +1,886 @@ +package mongo_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" +) + +func TestGenerateMethod_Find(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "simple find one method", + MethodSpec: spec.MethodSpec{ + Name: "FindByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{idField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` var entity UserModel + if err := r.collection.FindOne(arg0, bson.M{ + "_id": arg1, + }, options.FindOne().SetSort(bson.M{ + })).Decode(&entity); err != nil { + return nil, err + } + return &entity, nil`, + }, + { + Name: "simple find many method", + MethodSpec: spec.MethodSpec{ + Name: "FindByGender", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "gender": arg1, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with deep field reference", + MethodSpec: spec.MethodSpec{ + Name: "FindByNameFirst", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "firstName", Type: code.TypeString}, + }, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{nameField, firstNameField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` var entity UserModel + if err := r.collection.FindOne(arg0, bson.M{ + "name.first": arg1, + }, options.FindOne().SetSort(bson.M{ + })).Decode(&entity); err != nil { + return nil, err + } + return &entity, nil`, + }, + { + Name: "find with And operator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderAndAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "$and": []bson.M{ + { + "gender": arg1, + }, + { + "age": arg2, + }, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with Or operator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderOrAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + { + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "$or": []bson.M{ + { + "gender": arg1, + }, + { + "age": arg2, + }, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with Not comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderNot", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorNot, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "gender": bson.M{ + "$ne": arg1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with LessThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeLessThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorLessThan, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{ + "$lt": arg1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with LessThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeLessThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorLessThanEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{ + "$lte": arg1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with GreaterThan comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeGreaterThan", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorGreaterThan, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{ + "$gt": arg1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with GreaterThanEqual comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeGreaterThanEqual", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorGreaterThanEqual, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{ + "$gte": arg1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with Between comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByAgeBetween", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "fromAge", Type: code.TypeInt}, + {Name: "toAge", Type: code.TypeInt}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorBetween, + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "age": bson.M{ + "$gte": arg1, + "$lte": arg2, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with In comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderIn", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorIn, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "gender": bson.M{ + "$in": arg1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with NotIn comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByGenderNotIn", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorNotIn, + FieldReference: spec.FieldReference{genderField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "gender": bson.M{ + "$nin": arg1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with True comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByEnabledTrue", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorTrue, + FieldReference: spec.FieldReference{enabledField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "enabled": true, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with False comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByEnabledFalse", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorFalse, + FieldReference: spec.FieldReference{enabledField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "enabled": false, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with Exists comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByReferrerExists", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorExists, + FieldReference: spec.FieldReference{referrerField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "referrer": bson.M{ + "$exists": 1, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with NotExists comparator", + MethodSpec: spec.MethodSpec{ + Name: "FindByReferrerNotExists", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + Comparator: spec.ComparatorNotExists, + FieldReference: spec.FieldReference{referrerField}, + ParamIndex: 1, + }, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + "referrer": bson.M{ + "$exists": 0, + }, + }, options.Find().SetSort(bson.M{ + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with sort ascending", + MethodSpec: spec.MethodSpec{ + Name: "FindAllOrderByAge", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending}, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + }, options.Find().SetSort(bson.M{ + "age": 1, + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with sort descending", + MethodSpec: spec.MethodSpec{ + Name: "FindAllOrderByAgeDesc", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + }, options.Find().SetSort(bson.M{ + "age": -1, + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with deep sort ascending", + MethodSpec: spec.MethodSpec{ + Name: "FindAllOrderByNameFirst", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + { + FieldReference: spec.FieldReference{nameField, firstNameField}, + Ordering: spec.OrderingAscending, + }, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + }, options.Find().SetSort(bson.M{ + "name.first": 1, + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + { + Name: "find with multiple sorts", + MethodSpec: spec.MethodSpec{ + Name: "FindAllOrderByGenderAndAgeDesc", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.TypeError, + }, + Operation: spec.FindOperation{ + Mode: spec.QueryModeMany, + Sorts: []spec.Sort{ + {FieldReference: spec.FieldReference{genderField}, Ordering: spec.OrderingAscending}, + {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + }, + }, + }, + ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ + }, options.Find().SetSort(bson.M{ + "gender": 1, + "age": -1, + })) + if err != nil { + return nil, err + } + var entities []*UserModel + if err := cursor.All(arg0, &entities); err != nil { + return nil, err + } + return entities, nil`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + expectedReceiver := codegen.MethodReceiver{ + Name: "r", + Type: "UserRepositoryMongo", + Pointer: true, + } + var expectedParams []code.Param + for i, param := range testCase.MethodSpec.Params { + expectedParams = append(expectedParams, code.Param{ + Name: fmt.Sprintf("arg%d", i), + Type: param.Type, + }) + } + + actual, err := generator.GenerateMethod(testCase.MethodSpec) + + if err != nil { + t.Fatal(err) + } + if expectedReceiver != actual.Receiver { + t.Errorf( + "incorrect method receiver: expected %+v, got %+v", + expectedReceiver, + actual.Receiver, + ) + } + if testCase.MethodSpec.Name != actual.Name { + t.Errorf( + "incorrect method name: expected %s, got %s", + testCase.MethodSpec.Name, + actual.Name, + ) + } + if !reflect.DeepEqual(expectedParams, actual.Params) { + t.Errorf( + "incorrect struct params: expected %+v, got %+v", + expectedParams, + actual.Params, + ) + } + if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + t.Errorf( + "incorrect struct returns: expected %+v, got %+v", + testCase.MethodSpec.Returns, + actual.Returns, + ) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil { + t.Error(err) + } + }) + } +} diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index 3265a74..26c040e 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -1,10 +1,7 @@ package mongo import ( - "bytes" "fmt" - "strings" - "text/template" "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/codegen" @@ -14,7 +11,9 @@ import ( // NewGenerator creates a new instance of MongoDB repository generator func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGenerator { return RepositoryGenerator{ - StructModel: structModel, + baseMethodGenerator: baseMethodGenerator{ + structModel: structModel, + }, InterfaceName: interfaceName, } } @@ -22,7 +21,7 @@ func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGener // RepositoryGenerator is a MongoDB repository generator that provides // necessary information required to construct an implementation. type RepositoryGenerator struct { - StructModel code.Struct + baseMethodGenerator InterfaceName string } @@ -63,20 +62,6 @@ func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder { // GenerateConstructor creates codegen.FunctionBuilder of a constructor for // mongo repository implementation struct. func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) { - tmpl, err := template.New("mongo_constructor_body").Parse(constructorBody) - if err != nil { - return codegen.FunctionBuilder{}, err - } - - tmplData := constructorBodyData{ - ImplStructName: g.repoImplStructName(), - } - - buffer := new(bytes.Buffer) - if err := tmpl.Execute(buffer, tmplData); err != nil { - return codegen.FunctionBuilder{}, err - } - return codegen.FunctionBuilder{ Name: "New" + g.InterfaceName, Params: []code.Param{ @@ -93,7 +78,17 @@ func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, err Returns: []code.Type{ code.SimpleType(g.InterfaceName), }, - Body: buffer.String(), + Body: codegen.FunctionBody{ + codegen.ReturnStatement{ + codegen.StructStatement{ + Type: fmt.Sprintf("&%s", g.repoImplStructName()), + Pairs: []codegen.StructFieldPair{{ + Key: "collection", + Value: codegen.Identifier("collection"), + }}, + }, + }, + }, }, nil } @@ -126,220 +121,25 @@ func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen }, nil } -func (g RepositoryGenerator) generateMethodImplementation(methodSpec spec.MethodSpec) (string, error) { +func (g RepositoryGenerator) generateMethodImplementation( + methodSpec spec.MethodSpec) (codegen.FunctionBody, error) { + switch operation := methodSpec.Operation.(type) { case spec.InsertOperation: - return g.generateInsertImplementation(operation) + return g.generateInsertBody(operation), nil case spec.FindOperation: - return g.generateFindImplementation(operation) + return g.generateFindBody(operation) case spec.UpdateOperation: - return g.generateUpdateImplementation(operation) + return g.generateUpdateBody(operation) case spec.DeleteOperation: - return g.generateDeleteImplementation(operation) + return g.generateDeleteBody(operation) case spec.CountOperation: - return g.generateCountImplementation(operation) + return g.generateCountBody(operation) default: - return "", NewOperationNotSupportedError(operation.Name()) + return nil, NewOperationNotSupportedError(operation.Name()) } } -func (g RepositoryGenerator) generateInsertImplementation(operation spec.InsertOperation) (string, error) { - if operation.Mode == spec.QueryModeOne { - return insertOneTemplate, nil - } - return insertManyTemplate, nil -} - -func (g RepositoryGenerator) generateFindImplementation(operation spec.FindOperation) (string, error) { - querySpec, err := g.mongoQuerySpec(operation.Query) - if err != nil { - return "", err - } - - sorts, err := g.mongoSorts(operation.Sorts) - if err != nil { - return "", err - } - - tmplData := mongoFindTemplateData{ - EntityType: g.StructModel.Name, - QuerySpec: querySpec, - Sorts: sorts, - } - - if operation.Mode == spec.QueryModeOne { - return generateFromTemplate("mongo_repository_findone", findOneTemplate, tmplData) - } - return generateFromTemplate("mongo_repository_findmany", findManyTemplate, tmplData) -} - -func (g RepositoryGenerator) mongoSorts(sortSpec []spec.Sort) ([]findSort, error) { - var sorts []findSort - - for _, s := range sortSpec { - bsonFieldReference, err := g.bsonFieldReference(s.FieldReference) - if err != nil { - return nil, err - } - - sorts = append(sorts, findSort{ - BsonTag: bsonFieldReference, - Ordering: s.Ordering, - }) - } - - return sorts, nil -} - -func (g RepositoryGenerator) generateUpdateImplementation(operation spec.UpdateOperation) (string, error) { - update, err := g.getMongoUpdate(operation.Update) - if err != nil { - return "", err - } - - querySpec, err := g.mongoQuerySpec(operation.Query) - if err != nil { - return "", err - } - - tmplData := mongoUpdateTemplateData{ - Update: update, - QuerySpec: querySpec, - } - - if operation.Mode == spec.QueryModeOne { - return generateFromTemplate("mongo_repository_updateone", updateOneTemplate, tmplData) - } - return generateFromTemplate("mongo_repository_updatemany", updateManyTemplate, tmplData) -} - -func (g RepositoryGenerator) getMongoUpdate(updateSpec spec.Update) (update, error) { - switch updateSpec := updateSpec.(type) { - case spec.UpdateModel: - return updateModel{}, nil - case spec.UpdateFields: - update := make(updateFields) - for _, field := range updateSpec { - bsonFieldReference, err := g.bsonFieldReference(field.FieldReference) - if err != nil { - return querySpec{}, err - } - - updateKey := getUpdateOperatorKey(field.Operator) - if updateKey == "" { - return querySpec{}, NewUpdateOperatorNotSupportedError(field.Operator) - } - updateField := updateField{ - BsonTag: bsonFieldReference, - ParamIndex: field.ParamIndex, - } - update[updateKey] = append(update[updateKey], updateField) - } - return update, nil - default: - return nil, NewUpdateTypeNotSupportedError(updateSpec) - } -} - -func getUpdateOperatorKey(operator spec.UpdateOperator) string { - switch operator { - case spec.UpdateOperatorSet: - return "$set" - case spec.UpdateOperatorPush: - return "$push" - case spec.UpdateOperatorInc: - return "$inc" - default: - return "" - } -} - -func (g RepositoryGenerator) generateDeleteImplementation(operation spec.DeleteOperation) (string, error) { - querySpec, err := g.mongoQuerySpec(operation.Query) - if err != nil { - return "", err - } - - tmplData := mongoDeleteTemplateData{ - QuerySpec: querySpec, - } - - if operation.Mode == spec.QueryModeOne { - return generateFromTemplate("mongo_repository_deleteone", deleteOneTemplate, tmplData) - } - return generateFromTemplate("mongo_repository_deletemany", deleteManyTemplate, tmplData) -} - -func (g RepositoryGenerator) generateCountImplementation(operation spec.CountOperation) (string, error) { - querySpec, err := g.mongoQuerySpec(operation.Query) - if err != nil { - return "", err - } - - tmplData := mongoCountTemplateData{ - QuerySpec: querySpec, - } - - return generateFromTemplate("mongo_repository_count", countTemplate, tmplData) -} - -func (g RepositoryGenerator) mongoQuerySpec(query spec.QuerySpec) (querySpec, error) { - var predicates []predicate - - for _, predicateSpec := range query.Predicates { - bsonFieldReference, err := g.bsonFieldReference(predicateSpec.FieldReference) - if err != nil { - return querySpec{}, err - } - - predicates = append(predicates, predicate{ - Field: bsonFieldReference, - Comparator: predicateSpec.Comparator, - ParamIndex: predicateSpec.ParamIndex, - }) - } - - return querySpec{ - Operator: query.Operator, - Predicates: predicates, - }, nil -} - -func (g RepositoryGenerator) bsonFieldReference(fieldReference spec.FieldReference) (string, error) { - var bsonTags []string - for _, field := range fieldReference { - tag, err := g.bsonTagFromField(field) - if err != nil { - return "", err - } - bsonTags = append(bsonTags, tag) - } - return strings.Join(bsonTags, "."), nil -} - -func (g RepositoryGenerator) bsonTagFromField(field code.StructField) (string, error) { - bsonTag, ok := field.Tags["bson"] - if !ok { - return "", NewBsonTagNotFoundError(field.Name) - } - - return bsonTag[0], nil -} - func (g RepositoryGenerator) repoImplStructName() string { return g.InterfaceName + "Mongo" } - -func generateFromTemplate(name string, templateString string, tmplData interface{}) (string, error) { - tmpl, err := template.New(name).Parse(templateString) - if err != nil { - return "", err - } - - buffer := new(bytes.Buffer) - if err := tmpl.Execute(buffer, tmplData); err != nil { - return "", err - } - - return buffer.String(), nil -} diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index cdee726..34900f8 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -2,7 +2,6 @@ package mongo_test import ( "errors" - "fmt" "reflect" "testing" @@ -10,7 +9,6 @@ import ( "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/mongo" "github.com/sunboyy/repogen/internal/spec" - "github.com/sunboyy/repogen/internal/testutils" ) var ( @@ -80,10 +78,6 @@ var userModel = code.Struct{ }, } -const expectedConstructorBody = ` return &UserRepositoryMongo{ - collection: collection, - }` - func TestImports(t *testing.T) { generator := mongo.NewGenerator(userModel, "UserRepository") expected := [][]code.Import{ @@ -158,7 +152,17 @@ func TestGenerateConstructor(t *testing.T) { Returns: []code.Type{ code.SimpleType("UserRepository"), }, - Body: expectedConstructorBody, + Body: codegen.FunctionBody{ + codegen.ReturnStatement{ + codegen.StructStatement{ + Type: "&UserRepositoryMongo", + Pairs: []codegen.StructFieldPair{{ + Key: "collection", + Value: codegen.Identifier("collection"), + }}, + }, + }, + }, } actual, err := generator.GenerateConstructor() @@ -180,8 +184,11 @@ func TestGenerateConstructor(t *testing.T) { actual.Params, ) } - if err := testutils.ExpectMultiLineString(expected.Body, actual.Body); err != nil { - t.Error(err) + if !reflect.DeepEqual(expected.Body, actual.Body) { + t.Errorf("incorrect function body: expected %+v got %+v", + expected.Body, + actual.Body, + ) } } @@ -191,2212 +198,6 @@ type GenerateMethodTestCase struct { ExpectedBody string } -func TestGenerateMethod_Insert(t *testing.T) { - testTable := []GenerateMethodTestCase{ - { - Name: "insert one method", - MethodSpec: spec.MethodSpec{ - Name: "InsertOne", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - }, - Returns: []code.Type{ - code.InterfaceType{}, - code.TypeError, - }, - Operation: spec.InsertOperation{ - Mode: spec.QueryModeOne, - }, - }, - ExpectedBody: ` result, err := r.collection.InsertOne(arg0, arg1) - if err != nil { - return nil, err - } - return result.InsertedID, nil`, - }, - { - Name: "insert many method", - MethodSpec: spec.MethodSpec{ - Name: "Insert", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "userModel", Type: code.ArrayType{ - ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - }}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.InterfaceType{}}, - code.TypeError, - }, - Operation: spec.InsertOperation{ - Mode: spec.QueryModeMany, - }, - }, - ExpectedBody: ` var entities []interface{} - for _, model := range arg1 { - entities = append(entities, model) - } - result, err := r.collection.InsertMany(arg0, entities) - if err != nil { - return nil, err - } - return result.InsertedIDs, nil`, - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) - } - expected := codegen.MethodBuilder{ - Receiver: codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, - }, - Name: testCase.MethodSpec.Name, - Params: expectedParams, - Returns: testCase.MethodSpec.Returns, - Body: testCase.ExpectedBody, - } - - actual, err := generator.GenerateMethod(testCase.MethodSpec) - - if err != nil { - t.Fatal(err) - } - if expected.Receiver != actual.Receiver { - t.Errorf( - "incorrect method receiver: expected %+v, got %+v", - expected.Receiver, - actual.Receiver, - ) - } - if expected.Name != actual.Name { - t.Errorf( - "incorrect method name: expected %s, got %s", - expected.Name, - actual.Name, - ) - } - if !reflect.DeepEqual(expected.Params, actual.Params) { - t.Errorf( - "incorrect struct params: expected %+v, got %+v", - expected.Params, - actual.Params, - ) - } - if !reflect.DeepEqual(expected.Returns, actual.Returns) { - t.Errorf( - "incorrect struct returns: expected %+v, got %+v", - expected.Returns, - actual.Returns, - ) - } - if err := testutils.ExpectMultiLineString(expected.Body, actual.Body); err != nil { - t.Error(err) - } - }) - } -} - -func TestGenerateMethod_Find(t *testing.T) { - testTable := []GenerateMethodTestCase{ - { - Name: "simple find one method", - MethodSpec: spec.MethodSpec{ - Name: "FindByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{idField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` var entity UserModel - if err := r.collection.FindOne(arg0, bson.M{ - "_id": arg1, - }, options.FindOne().SetSort(bson.M{ - })).Decode(&entity); err != nil { - return nil, err - } - return &entity, nil`, - }, - { - Name: "simple find many method", - MethodSpec: spec.MethodSpec{ - Name: "FindByGender", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "gender": arg1, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with deep field reference", - MethodSpec: spec.MethodSpec{ - Name: "FindByNameFirst", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "firstName", Type: code.TypeString}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{nameField, firstNameField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` var entity UserModel - if err := r.collection.FindOne(arg0, bson.M{ - "name.first": arg1, - }, options.FindOne().SetSort(bson.M{ - })).Decode(&entity); err != nil { - return nil, err - } - return &entity, nil`, - }, - { - Name: "find with And operator", - MethodSpec: spec.MethodSpec{ - Name: "FindByGenderAndAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorAnd, - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "$and": []bson.M{ - {"gender": arg1}, - {"age": arg2}, - }, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with Or operator", - MethodSpec: spec.MethodSpec{ - Name: "FindByGenderOrAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorOr, - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "$or": []bson.M{ - {"gender": arg1}, - {"age": arg2}, - }, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with Not comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByGenderNot", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorNot, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "gender": bson.M{"$ne": arg1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with LessThan comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByAgeLessThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorLessThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$lt": arg1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with LessThanEqual comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByAgeLessThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorLessThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$lte": arg1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with GreaterThan comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByAgeGreaterThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorGreaterThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$gt": arg1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with GreaterThanEqual comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByAgeGreaterThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorGreaterThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$gte": arg1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with Between comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByAgeBetween", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "fromAge", Type: code.TypeInt}, - {Name: "toAge", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorBetween, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$gte": arg1, "$lte": arg2}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with In comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByGenderIn", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorIn, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "gender": bson.M{"$in": arg1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with NotIn comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByGenderNotIn", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorNotIn, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "gender": bson.M{"$nin": arg1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with True comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByEnabledTrue", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorTrue, - FieldReference: spec.FieldReference{enabledField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "enabled": true, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with False comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByEnabledFalse", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorFalse, - FieldReference: spec.FieldReference{enabledField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "enabled": false, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with Exists comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByReferrerExists", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorExists, - FieldReference: spec.FieldReference{referrerField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "referrer": bson.M{"$exists": 1}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with NotExists comparator", - MethodSpec: spec.MethodSpec{ - Name: "FindByReferrerNotExists", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorNotExists, - FieldReference: spec.FieldReference{referrerField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - "referrer": bson.M{"$exists": 0}, - }, options.Find().SetSort(bson.M{ - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with sort ascending", - MethodSpec: spec.MethodSpec{ - Name: "FindAllOrderByAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending}, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - - }, options.Find().SetSort(bson.M{ - "age": 1, - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with sort descending", - MethodSpec: spec.MethodSpec{ - Name: "FindAllOrderByAgeDesc", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - - }, options.Find().SetSort(bson.M{ - "age": -1, - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with deep sort ascending", - MethodSpec: spec.MethodSpec{ - Name: "FindAllOrderByNameFirst", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Sorts: []spec.Sort{ - { - FieldReference: spec.FieldReference{nameField, firstNameField}, - Ordering: spec.OrderingAscending, - }, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - - }, options.Find().SetSort(bson.M{ - "name.first": 1, - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - { - Name: "find with multiple sorts", - MethodSpec: spec.MethodSpec{ - Name: "FindAllOrderByGenderAndAgeDesc", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - Operation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{genderField}, Ordering: spec.OrderingAscending}, - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, - }, - }, - }, - ExpectedBody: ` cursor, err := r.collection.Find(arg0, bson.M{ - - }, options.Find().SetSort(bson.M{ - "gender": 1, - "age": -1, - })) - if err != nil { - return nil, err - } - var entities []*UserModel - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil`, - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) - } - expected := codegen.MethodBuilder{ - Receiver: codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, - }, - Name: testCase.MethodSpec.Name, - Params: expectedParams, - Returns: testCase.MethodSpec.Returns, - Body: testCase.ExpectedBody, - } - - actual, err := generator.GenerateMethod(testCase.MethodSpec) - - if err != nil { - t.Fatal(err) - } - if expected.Receiver != actual.Receiver { - t.Errorf( - "incorrect method receiver: expected %+v, got %+v", - expected.Receiver, - actual.Receiver, - ) - } - if expected.Name != actual.Name { - t.Errorf( - "incorrect method name: expected %s, got %s", - expected.Name, - actual.Name, - ) - } - if !reflect.DeepEqual(expected.Params, actual.Params) { - t.Errorf( - "incorrect struct params: expected %+v, got %+v", - expected.Params, - actual.Params, - ) - } - if !reflect.DeepEqual(expected.Returns, actual.Returns) { - t.Errorf( - "incorrect struct returns: expected %+v, got %+v", - expected.Returns, - actual.Returns, - ) - } - if err := testutils.ExpectMultiLineString(expected.Body, actual.Body); err != nil { - t.Error(err) - } - }) - } -} - -func TestGenerateMethod_Update(t *testing.T) { - testTable := []GenerateMethodTestCase{ - { - Name: "update model method", - MethodSpec: spec.MethodSpec{ - Name: "UpdateByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "model", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - Operation: spec.UpdateOperation{ - Update: spec.UpdateModel{}, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ - "_id": arg2, - }, bson.M{ - "$set": arg1, - }) - if err != nil { - return false, err - } - return result.MatchedCount > 0, err`, - }, - { - Name: "simple update one method", - MethodSpec: spec.MethodSpec{ - Name: "UpdateAgeByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - Operation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, - }, - }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ - "_id": arg2, - }, bson.M{ - "$set": bson.M{ - "age": arg1, - }, - }) - if err != nil { - return false, err - } - return result.MatchedCount > 0, err`, - }, - { - Name: "simple update many method", - MethodSpec: spec.MethodSpec{ - Name: "UpdateAgeByGender", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, - }, - }, - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.UpdateMany(arg0, bson.M{ - "gender": arg2, - }, bson.M{ - "$set": bson.M{ - "age": arg1, - }, - }) - if err != nil { - return 0, err - } - return int(result.MatchedCount), err`, - }, - { - Name: "simple update push method", - MethodSpec: spec.MethodSpec{ - Name: "UpdateConsentHistoryPushByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - Operation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{consentHistoryField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorPush, - }, - }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ - "_id": arg2, - }, bson.M{ - "$push": bson.M{ - "consent_history": arg1, - }, - }) - if err != nil { - return false, err - } - return result.MatchedCount > 0, err`, - }, - { - Name: "simple update inc method", - MethodSpec: spec.MethodSpec{ - Name: "UpdateAgeIncByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - Operation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorInc, - }, - }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ - "_id": arg2, - }, bson.M{ - "$inc": bson.M{ - "age": arg1, - }, - }) - if err != nil { - return false, err - } - return result.MatchedCount > 0, err`, - }, - { - Name: "simple update set and push method", - MethodSpec: spec.MethodSpec{ - Name: "UpdateEnabledAndConsentHistoryPushByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "enabled", Type: code.TypeBool}, - {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, - {Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - Operation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{enabledField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, - }, - spec.UpdateField{ - FieldReference: spec.FieldReference{consentHistoryField}, - ParamIndex: 2, - Operator: spec.UpdateOperatorPush, - }, - }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 3, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ - "_id": arg3, - }, bson.M{ - "$push": bson.M{ - "consent_history": arg2, - }, - "$set": bson.M{ - "enabled": arg1, - }, - }) - if err != nil { - return false, err - } - return result.MatchedCount > 0, err`, - }, - { - Name: "update with deeply referenced field", - MethodSpec: spec.MethodSpec{ - Name: "UpdateNameFirstByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "firstName", Type: code.TypeString}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - Operation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{nameField, firstNameField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, - }, - }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ - "_id": arg2, - }, bson.M{ - "$set": bson.M{ - "name.first": arg1, - }, - }) - if err != nil { - return false, err - } - return result.MatchedCount > 0, err`, - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) - } - expected := codegen.MethodBuilder{ - Receiver: codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, - }, - Name: testCase.MethodSpec.Name, - Params: expectedParams, - Returns: testCase.MethodSpec.Returns, - Body: testCase.ExpectedBody, - } - - actual, err := generator.GenerateMethod(testCase.MethodSpec) - - if err != nil { - t.Fatal(err) - } - if expected.Receiver != actual.Receiver { - t.Errorf( - "incorrect method receiver: expected %+v, got %+v", - expected.Receiver, - actual.Receiver, - ) - } - if expected.Name != actual.Name { - t.Errorf( - "incorrect method name: expected %s, got %s", - expected.Name, - actual.Name, - ) - } - if !reflect.DeepEqual(expected.Params, actual.Params) { - t.Errorf( - "incorrect struct params: expected %+v, got %+v", - expected.Params, - actual.Params, - ) - } - if !reflect.DeepEqual(expected.Returns, actual.Returns) { - t.Errorf( - "incorrect struct returns: expected %+v, got %+v", - expected.Returns, - actual.Returns, - ) - } - if err := testutils.ExpectMultiLineString(expected.Body, actual.Body); err != nil { - t.Error(err) - } - }) - } -} - -func TestGenerateMethod_Delete(t *testing.T) { - testTable := []GenerateMethodTestCase{ - { - Name: "simple delete one method", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{code.TypeBool, code.TypeError}, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{idField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteOne(arg0, bson.M{ - "_id": arg1, - }) - if err != nil { - return false, err - } - return result.DeletedCount > 0, nil`, - }, - { - Name: "simple delete many method", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByGender", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "gender": arg1, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with And operator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByGenderAndAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorAnd, - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "$and": []bson.M{ - {"gender": arg1}, - {"age": arg2}, - }, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with Or operator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByGenderOrAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorOr, - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "$or": []bson.M{ - {"gender": arg1}, - {"age": arg2}, - }, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with Not comparator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByGenderNot", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorNot, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "gender": bson.M{"$ne": arg1}, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with LessThan comparator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByAgeLessThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorLessThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "age": bson.M{"$lt": arg1}, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with LessThanEqual comparator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByAgeLessThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorLessThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "age": bson.M{"$lte": arg1}, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with GreaterThan comparator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByAgeGreaterThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorGreaterThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "age": bson.M{"$gt": arg1}, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with GreaterThanEqual comparator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByAgeGreaterThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorGreaterThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "age": bson.M{"$gte": arg1}, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with Between comparator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByAgeBetween", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "fromAge", Type: code.TypeInt}, - {Name: "toAge", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorBetween, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "age": bson.M{"$gte": arg1, "$lte": arg2}, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - { - Name: "delete with In comparator", - MethodSpec: spec.MethodSpec{ - Name: "DeleteByGenderIn", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - Comparator: spec.ComparatorIn, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` result, err := r.collection.DeleteMany(arg0, bson.M{ - "gender": bson.M{"$in": arg1}, - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil`, - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) - } - expected := codegen.MethodBuilder{ - Receiver: codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, - }, - Name: testCase.MethodSpec.Name, - Params: expectedParams, - Returns: testCase.MethodSpec.Returns, - Body: testCase.ExpectedBody, - } - - actual, err := generator.GenerateMethod(testCase.MethodSpec) - - if err != nil { - t.Fatal(err) - } - if expected.Receiver != actual.Receiver { - t.Errorf( - "incorrect method receiver: expected %+v, got %+v", - expected.Receiver, - actual.Receiver, - ) - } - if expected.Name != actual.Name { - t.Errorf( - "incorrect method name: expected %s, got %s", - expected.Name, - actual.Name, - ) - } - if !reflect.DeepEqual(expected.Params, actual.Params) { - t.Errorf( - "incorrect struct params: expected %+v, got %+v", - expected.Params, - actual.Params, - ) - } - if !reflect.DeepEqual(expected.Returns, actual.Returns) { - t.Errorf( - "incorrect struct returns: expected %+v, got %+v", - expected.Returns, - actual.Returns, - ) - } - if err := testutils.ExpectMultiLineString(expected.Body, actual.Body); err != nil { - t.Error(err) - } - }) - } -} - -func TestGenerateMethod_Count(t *testing.T) { - testTable := []GenerateMethodTestCase{ - { - Name: "simple count method", - MethodSpec: spec.MethodSpec{ - Name: "CountByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "gender": arg1, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with And operator", - MethodSpec: spec.MethodSpec{ - Name: "CountByGenderAndCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Operator: spec.OperatorAnd, - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "$and": []bson.M{ - {"gender": arg1}, - {"age": arg2}, - }, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with Or operator", - MethodSpec: spec.MethodSpec{ - Name: "CountByGenderOrCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Operator: spec.OperatorOr, - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "$or": []bson.M{ - {"gender": arg1}, - {"age": arg2}, - }, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with Not comparator", - MethodSpec: spec.MethodSpec{ - Name: "CountByGenderNot", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorNot, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "gender": bson.M{"$ne": arg1}, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with LessThan comparator", - MethodSpec: spec.MethodSpec{ - Name: "CountByAgeLessThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThan, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "age": bson.M{"$lt": arg1}, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with LessThanEqual comparator", - MethodSpec: spec.MethodSpec{ - Name: "CountByAgeLessThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThanEqual, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "age": bson.M{"$lte": arg1}, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with GreaterThan comparator", - MethodSpec: spec.MethodSpec{ - Name: "CountByAgeGreaterThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThan, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "age": bson.M{"$gt": arg1}, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with GreaterThanEqual comparator", - MethodSpec: spec.MethodSpec{ - Name: "CountByAgeGreaterThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThanEqual, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "age": bson.M{"$gte": arg1}, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with Between comparator", - MethodSpec: spec.MethodSpec{ - Name: "CountByAgeBetween", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorBetween, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "age": bson.M{"$gte": arg1, "$lte": arg2}, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - { - Name: "count with In comparator", - MethodSpec: spec.MethodSpec{ - Name: "CountByAgeIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.TypeInt}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - Operation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorIn, - ParamIndex: 1, - }, - }, - }, - }, - }, - ExpectedBody: ` count, err := r.collection.CountDocuments(arg0, bson.M{ - "age": bson.M{"$in": arg1}, - }) - if err != nil { - return 0, err - } - return int(count), nil`, - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) - } - expected := codegen.MethodBuilder{ - Receiver: codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, - }, - Name: testCase.MethodSpec.Name, - Params: expectedParams, - Returns: testCase.MethodSpec.Returns, - Body: testCase.ExpectedBody, - } - - actual, err := generator.GenerateMethod(testCase.MethodSpec) - - if err != nil { - t.Fatal(err) - } - if expected.Receiver != actual.Receiver { - t.Errorf( - "incorrect method receiver: expected %+v, got %+v", - expected.Receiver, - actual.Receiver, - ) - } - if expected.Name != actual.Name { - t.Errorf( - "incorrect method name: expected %s, got %s", - expected.Name, - actual.Name, - ) - } - if !reflect.DeepEqual(expected.Params, actual.Params) { - t.Errorf( - "incorrect struct params: expected %+v, got %+v", - expected.Params, - actual.Params, - ) - } - if !reflect.DeepEqual(expected.Returns, actual.Returns) { - t.Errorf( - "incorrect struct returns: expected %+v, got %+v", - expected.Returns, - actual.Returns, - ) - } - if err := testutils.ExpectMultiLineString(expected.Body, actual.Body); err != nil { - t.Error(err) - } - }) - } -} - type GenerateMethodInvalidTestCase struct { Name string Method spec.MethodSpec diff --git a/internal/mongo/insert.go b/internal/mongo/insert.go new file mode 100644 index 0000000..996acbc --- /dev/null +++ b/internal/mongo/insert.go @@ -0,0 +1,81 @@ +package mongo + +import ( + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/spec" +) + +func (g RepositoryGenerator) generateInsertBody( + operation spec.InsertOperation) codegen.FunctionBody { + + if operation.Mode == spec.QueryModeOne { + return g.generateInsertOneBody() + } + return g.generateInsertManyBody() +} + +func (g RepositoryGenerator) generateInsertOneBody() codegen.FunctionBody { + return codegen.FunctionBody{ + codegen.DeclAssignStatement{ + Vars: []string{"result", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("InsertOne", + codegen.Identifier("arg0"), + codegen.Identifier("arg1"), + ).Build(), + }, + }, + ifErrReturnNilErr, + codegen.ReturnStatement{ + codegen.NewChainBuilder("result").Chain("InsertedID").Build(), + codegen.Identifier("nil"), + }, + } +} + +func (g RepositoryGenerator) generateInsertManyBody() codegen.FunctionBody { + return codegen.FunctionBody{ + codegen.DeclStatement{ + Name: "entities", + Type: code.ArrayType{ + ContainedType: code.InterfaceType{}, + }, + }, + codegen.RawBlock{ + Header: []string{"for _, model := range arg1"}, + Statements: []codegen.Statement{ + codegen.AssignStatement{ + Vars: []string{"entities"}, + Values: codegen.StatementList{ + codegen.CallStatement{ + FuncName: "append", + Params: codegen.StatementList{ + codegen.Identifier("entities"), + codegen.Identifier("model"), + }, + }, + }, + }, + }, + }, + codegen.DeclAssignStatement{ + Vars: []string{"result", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("InsertMany", + codegen.Identifier("arg0"), + codegen.Identifier("entities"), + ).Build(), + }, + }, + ifErrReturnNilErr, + codegen.ReturnStatement{ + codegen.NewChainBuilder("result").Chain("InsertedIDs").Build(), + codegen.Identifier("nil"), + }, + } +} diff --git a/internal/mongo/insert_test.go b/internal/mongo/insert_test.go new file mode 100644 index 0000000..c921192 --- /dev/null +++ b/internal/mongo/insert_test.go @@ -0,0 +1,123 @@ +package mongo_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" +) + +func TestGenerateMethod_Insert(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "insert one method", + MethodSpec: spec.MethodSpec{ + Name: "InsertOne", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + }, + Returns: []code.Type{ + code.InterfaceType{}, + code.TypeError, + }, + Operation: spec.InsertOperation{ + Mode: spec.QueryModeOne, + }, + }, + ExpectedBody: ` result, err := r.collection.InsertOne(arg0, arg1) + if err != nil { + return nil, err + } + return result.InsertedID, nil`, + }, + { + Name: "insert many method", + MethodSpec: spec.MethodSpec{ + Name: "Insert", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "userModel", Type: code.ArrayType{ + ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}, + }}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.InterfaceType{}}, + code.TypeError, + }, + Operation: spec.InsertOperation{ + Mode: spec.QueryModeMany, + }, + }, + ExpectedBody: ` var entities []interface{} + for _, model := range arg1 { + entities = append(entities, model) + } + result, err := r.collection.InsertMany(arg0, entities) + if err != nil { + return nil, err + } + return result.InsertedIDs, nil`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + expectedReceiver := codegen.MethodReceiver{ + Name: "r", + Type: "UserRepositoryMongo", + Pointer: true, + } + var expectedParams []code.Param + for i, param := range testCase.MethodSpec.Params { + expectedParams = append(expectedParams, code.Param{ + Name: fmt.Sprintf("arg%d", i), + Type: param.Type, + }) + } + + actual, err := generator.GenerateMethod(testCase.MethodSpec) + + if err != nil { + t.Fatal(err) + } + if expectedReceiver != actual.Receiver { + t.Errorf( + "incorrect method receiver: expected %+v, got %+v", + expectedReceiver, + actual.Receiver, + ) + } + if testCase.MethodSpec.Name != actual.Name { + t.Errorf( + "incorrect method name: expected %s, got %s", + testCase.MethodSpec.Name, + actual.Name, + ) + } + if !reflect.DeepEqual(expectedParams, actual.Params) { + t.Errorf( + "incorrect struct params: expected %+v, got %+v", + expectedParams, + actual.Params, + ) + } + if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + t.Errorf( + "incorrect struct returns: expected %+v, got %+v", + testCase.MethodSpec.Returns, + actual.Returns, + ) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil { + t.Error(err) + } + }) + } +} diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 5d5ee75..4b863de 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -3,8 +3,9 @@ package mongo import ( "fmt" "sort" - "strings" + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/spec" ) @@ -14,36 +15,54 @@ type updateField struct { } type update interface { - Code() string + Code() codegen.Statement } type updateModel struct { } -func (u updateModel) Code() string { - return ` "$set": arg1,` +func (u updateModel) Code() codegen.Statement { + return codegen.MapStatement{ + Type: "bson.M", + Pairs: []codegen.MapPair{ + { + Key: "$set", + Value: codegen.Identifier("arg1"), + }, + }, + } } type updateFields map[string][]updateField -func (u updateFields) Code() string { +func (u updateFields) Code() codegen.Statement { var keys []string for k := range u { keys = append(keys, k) } sort.Strings(keys) - var lines []string + stmt := codegen.MapStatement{ + Type: "bson.M", + } for _, key := range keys { - lines = append(lines, fmt.Sprintf(` "%s": bson.M{`, key)) - - for _, field := range u[key] { - lines = append(lines, fmt.Sprintf(` "%s": arg%d,`, field.BsonTag, field.ParamIndex)) + applicationMap := codegen.MapStatement{ + Type: "bson.M", } - lines = append(lines, ` },`) + for _, field := range u[key] { + applicationMap.Pairs = append(applicationMap.Pairs, codegen.MapPair{ + Key: field.BsonTag, + Value: codegen.Identifier(fmt.Sprintf("arg%d", field.ParamIndex)), + }) + } + + stmt.Pairs = append(stmt.Pairs, codegen.MapPair{ + Key: key, + Value: applicationMap, + }) } - return strings.Join(lines, "\n") + return stmt } type querySpec struct { @@ -51,32 +70,52 @@ type querySpec struct { Predicates []predicate } -func (q querySpec) Code() string { - var predicateCodes []string +func (q querySpec) Code() codegen.Statement { + var predicatePairs []codegen.MapPair for _, predicate := range q.Predicates { - predicateCodes = append(predicateCodes, predicate.Code()) + predicatePairs = append(predicatePairs, predicate.Code()) + } + var predicateMaps []codegen.Statement + for _, pair := range predicatePairs { + predicateMaps = append(predicateMaps, codegen.MapStatement{ + Pairs: []codegen.MapPair{pair}, + }) } - var lines []string + stmt := codegen.MapStatement{ + Type: "bson.M", + } switch q.Operator { case spec.OperatorOr: - lines = append(lines, ` "$or": []bson.M{`) - for _, predicateCode := range predicateCodes { - lines = append(lines, fmt.Sprintf(` {%s},`, predicateCode)) - } - lines = append(lines, ` },`) + stmt.Pairs = append(stmt.Pairs, codegen.MapPair{ + Key: "$or", + Value: codegen.SliceStatement{ + Type: code.ArrayType{ + ContainedType: code.ExternalType{ + PackageAlias: "bson", + Name: "M", + }, + }, + Values: predicateMaps, + }, + }) case spec.OperatorAnd: - lines = append(lines, ` "$and": []bson.M{`) - for _, predicateCode := range predicateCodes { - lines = append(lines, fmt.Sprintf(` {%s},`, predicateCode)) - } - lines = append(lines, ` },`) + stmt.Pairs = append(stmt.Pairs, codegen.MapPair{ + Key: "$and", + Value: codegen.SliceStatement{ + Type: code.ArrayType{ + ContainedType: code.ExternalType{ + PackageAlias: "bson", + Name: "M", + }, + }, + Values: predicateMaps, + }, + }) default: - for _, predicateCode := range predicateCodes { - lines = append(lines, fmt.Sprintf(` %s,`, predicateCode)) - } + stmt.Pairs = predicatePairs } - return strings.Join(lines, "\n") + return stmt } type predicate struct { @@ -85,34 +124,86 @@ type predicate struct { ParamIndex int } -func (p predicate) Code() string { +func (p predicate) Code() codegen.MapPair { + argStmt := codegen.Identifier(fmt.Sprintf("arg%d", p.ParamIndex)) + switch p.Comparator { case spec.ComparatorEqual: - return fmt.Sprintf(`"%s": arg%d`, p.Field, p.ParamIndex) + return p.createValueMapPair(argStmt) case spec.ComparatorNot: - return fmt.Sprintf(`"%s": bson.M{"$ne": arg%d}`, p.Field, p.ParamIndex) + return p.createSingleComparisonMapPair("$ne", argStmt) case spec.ComparatorLessThan: - return fmt.Sprintf(`"%s": bson.M{"$lt": arg%d}`, p.Field, p.ParamIndex) + return p.createSingleComparisonMapPair("$lt", argStmt) case spec.ComparatorLessThanEqual: - return fmt.Sprintf(`"%s": bson.M{"$lte": arg%d}`, p.Field, p.ParamIndex) + return p.createSingleComparisonMapPair("$lte", argStmt) case spec.ComparatorGreaterThan: - return fmt.Sprintf(`"%s": bson.M{"$gt": arg%d}`, p.Field, p.ParamIndex) + return p.createSingleComparisonMapPair("$gt", argStmt) case spec.ComparatorGreaterThanEqual: - return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d}`, p.Field, p.ParamIndex) + return p.createSingleComparisonMapPair("$gte", argStmt) case spec.ComparatorBetween: - return fmt.Sprintf(`"%s": bson.M{"$gte": arg%d, "$lte": arg%d}`, p.Field, p.ParamIndex, p.ParamIndex+1) + argStmt2 := codegen.Identifier(fmt.Sprintf("arg%d", p.ParamIndex+1)) + return p.createBetweenMapPair(argStmt, argStmt2) case spec.ComparatorIn: - return fmt.Sprintf(`"%s": bson.M{"$in": arg%d}`, p.Field, p.ParamIndex) + return p.createSingleComparisonMapPair("$in", argStmt) case spec.ComparatorNotIn: - return fmt.Sprintf(`"%s": bson.M{"$nin": arg%d}`, p.Field, p.ParamIndex) + return p.createSingleComparisonMapPair("$nin", argStmt) case spec.ComparatorTrue: - return fmt.Sprintf(`"%s": true`, p.Field) + return p.createValueMapPair(codegen.Identifier("true")) case spec.ComparatorFalse: - return fmt.Sprintf(`"%s": false`, p.Field) + return p.createValueMapPair(codegen.Identifier("false")) case spec.ComparatorExists: - return fmt.Sprintf(`"%s": bson.M{"$exists": 1}`, p.Field) + return p.createExistsMapPair("1") case spec.ComparatorNotExists: - return fmt.Sprintf(`"%s": bson.M{"$exists": 0}`, p.Field) + return p.createExistsMapPair("0") + } + return codegen.MapPair{} +} + +func (p predicate) createValueMapPair( + argStmt codegen.Statement) codegen.MapPair { + + return codegen.MapPair{ + Key: p.Field, + Value: argStmt, + } +} + +func (p predicate) createSingleComparisonMapPair(comparatorKey string, + argStmt codegen.Statement) codegen.MapPair { + + return codegen.MapPair{ + Key: p.Field, + Value: codegen.MapStatement{ + Type: "bson.M", + Pairs: []codegen.MapPair{{Key: comparatorKey, Value: argStmt}}, + }, + } +} + +func (p predicate) createBetweenMapPair(argStmt codegen.Statement, + argStmt2 codegen.Statement) codegen.MapPair { + + return codegen.MapPair{ + Key: p.Field, + Value: codegen.MapStatement{ + Type: "bson.M", + Pairs: []codegen.MapPair{ + {Key: "$gte", Value: argStmt}, + {Key: "$lte", Value: argStmt2}, + }, + }, + } +} + +func (p predicate) createExistsMapPair(existsValue string) codegen.MapPair { + return codegen.MapPair{ + Key: p.Field, + Value: codegen.MapStatement{ + Type: "bson.M", + Pairs: []codegen.MapPair{{ + Key: "$exists", + Value: codegen.Identifier(existsValue), + }}, + }, } - return "" } diff --git a/internal/mongo/templates.go b/internal/mongo/templates.go deleted file mode 100644 index 28df6ba..0000000 --- a/internal/mongo/templates.go +++ /dev/null @@ -1,128 +0,0 @@ -package mongo - -import ( - "github.com/sunboyy/repogen/internal/spec" -) - -const constructorBody = ` return &{{.ImplStructName}}{ - collection: collection, - }` - -type constructorBodyData struct { - ImplStructName string -} - -const insertOneTemplate = ` result, err := r.collection.InsertOne(arg0, arg1) - if err != nil { - return nil, err - } - return result.InsertedID, nil` - -const insertManyTemplate = ` var entities []interface{} - for _, model := range arg1 { - entities = append(entities, model) - } - result, err := r.collection.InsertMany(arg0, entities) - if err != nil { - return nil, err - } - return result.InsertedIDs, nil` - -type mongoFindTemplateData struct { - EntityType string - QuerySpec querySpec - Sorts []findSort -} - -type findSort struct { - BsonTag string - Ordering spec.Ordering -} - -func (s findSort) OrderNum() int { - if s.Ordering == spec.OrderingAscending { - return 1 - } - return -1 -} - -const findOneTemplate = ` var entity {{.EntityType}} - if err := r.collection.FindOne(arg0, bson.M{ -{{.QuerySpec.Code}} - }, options.FindOne().SetSort(bson.M{ -{{range $index, $element := .Sorts}} "{{$element.BsonTag}}": {{$element.OrderNum}}, -{{end}} })).Decode(&entity); err != nil { - return nil, err - } - return &entity, nil` - -const findManyTemplate = ` cursor, err := r.collection.Find(arg0, bson.M{ -{{.QuerySpec.Code}} - }, options.Find().SetSort(bson.M{ -{{range $index, $element := .Sorts}} "{{$element.BsonTag}}": {{$element.OrderNum}}, -{{end}} })) - if err != nil { - return nil, err - } - var entities []*{{.EntityType}} - if err := cursor.All(arg0, &entities); err != nil { - return nil, err - } - return entities, nil` - -type mongoUpdateTemplateData struct { - Update update - QuerySpec querySpec -} - -const updateOneTemplate = ` result, err := r.collection.UpdateOne(arg0, bson.M{ -{{.QuerySpec.Code}} - }, bson.M{ -{{.Update.Code}} - }) - if err != nil { - return false, err - } - return result.MatchedCount > 0, err` - -const updateManyTemplate = ` result, err := r.collection.UpdateMany(arg0, bson.M{ -{{.QuerySpec.Code}} - }, bson.M{ -{{.Update.Code}} - }) - if err != nil { - return 0, err - } - return int(result.MatchedCount), err` - -type mongoDeleteTemplateData struct { - QuerySpec querySpec -} - -const deleteOneTemplate = ` result, err := r.collection.DeleteOne(arg0, bson.M{ -{{.QuerySpec.Code}} - }) - if err != nil { - return false, err - } - return result.DeletedCount > 0, nil` - -const deleteManyTemplate = ` result, err := r.collection.DeleteMany(arg0, bson.M{ -{{.QuerySpec.Code}} - }) - if err != nil { - return 0, err - } - return int(result.DeletedCount), nil` - -type mongoCountTemplateData struct { - QuerySpec querySpec -} - -const countTemplate = ` count, err := r.collection.CountDocuments(arg0, bson.M{ -{{.QuerySpec.Code}} - }) - if err != nil { - return 0, err - } - return int(count), nil` diff --git a/internal/mongo/update.go b/internal/mongo/update.go new file mode 100644 index 0000000..a29f77e --- /dev/null +++ b/internal/mongo/update.go @@ -0,0 +1,133 @@ +package mongo + +import ( + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/spec" +) + +func (g RepositoryGenerator) generateUpdateBody( + operation spec.UpdateOperation) (codegen.FunctionBody, error) { + + return updateBodyGenerator{ + baseMethodGenerator: g.baseMethodGenerator, + operation: operation, + }.generate() +} + +type updateBodyGenerator struct { + baseMethodGenerator + operation spec.UpdateOperation +} + +func (g updateBodyGenerator) generate() (codegen.FunctionBody, error) { + update, err := g.convertUpdate(g.operation.Update) + if err != nil { + return nil, err + } + + querySpec, err := g.convertQuerySpec(g.operation.Query) + if err != nil { + return nil, err + } + + if g.operation.Mode == spec.QueryModeOne { + return g.generateUpdateOneBody(update, querySpec), nil + } + + return g.generateUpdateManyBody(update, querySpec), nil +} + +func (g updateBodyGenerator) generateUpdateOneBody(update update, + querySpec querySpec) codegen.FunctionBody { + + return codegen.FunctionBody{ + codegen.DeclAssignStatement{ + Vars: []string{"result", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("UpdateOne", + codegen.Identifier("arg0"), + querySpec.Code(), + update.Code(), + ).Build(), + }, + }, + ifErrReturnFalseErr, + codegen.ReturnStatement{ + codegen.RawStatement("result.MatchedCount > 0"), + codegen.Identifier("nil"), + }, + } +} + +func (g updateBodyGenerator) generateUpdateManyBody(update update, + querySpec querySpec) codegen.FunctionBody { + + return codegen.FunctionBody{ + codegen.DeclAssignStatement{ + Vars: []string{"result", "err"}, + Values: codegen.StatementList{ + codegen.NewChainBuilder("r"). + Chain("collection"). + Call("UpdateMany", + codegen.Identifier("arg0"), + querySpec.Code(), + update.Code(), + ).Build(), + }, + }, + ifErrReturn0Err, + codegen.ReturnStatement{ + codegen.CallStatement{ + FuncName: "int", + Params: codegen.StatementList{ + codegen.NewChainBuilder("result"). + Chain("MatchedCount").Build(), + }, + }, + codegen.Identifier("nil"), + }, + } +} + +func (g updateBodyGenerator) convertUpdate(updateSpec spec.Update) (update, error) { + switch updateSpec := updateSpec.(type) { + case spec.UpdateModel: + return updateModel{}, nil + case spec.UpdateFields: + update := make(updateFields) + for _, field := range updateSpec { + bsonFieldReference, err := g.bsonFieldReference(field.FieldReference) + if err != nil { + return nil, err + } + + updateKey := getUpdateOperatorKey(field.Operator) + if updateKey == "" { + return nil, NewUpdateOperatorNotSupportedError(field.Operator) + } + updateField := updateField{ + BsonTag: bsonFieldReference, + ParamIndex: field.ParamIndex, + } + update[updateKey] = append(update[updateKey], updateField) + } + return update, nil + default: + return nil, NewUpdateTypeNotSupportedError(updateSpec) + } +} + +func getUpdateOperatorKey(operator spec.UpdateOperator) string { + switch operator { + case spec.UpdateOperatorSet: + return "$set" + case spec.UpdateOperatorPush: + return "$push" + case spec.UpdateOperatorInc: + return "$inc" + default: + return "" + } +} diff --git a/internal/mongo/update_test.go b/internal/mongo/update_test.go new file mode 100644 index 0000000..d67a6e5 --- /dev/null +++ b/internal/mongo/update_test.go @@ -0,0 +1,389 @@ +package mongo_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" + "github.com/sunboyy/repogen/internal/codegen" + "github.com/sunboyy/repogen/internal/mongo" + "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" +) + +func TestGenerateMethod_Update(t *testing.T) { + testTable := []GenerateMethodTestCase{ + { + Name: "update model method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "model", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.TypeBool, + code.TypeError, + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateModel{}, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{idField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$set": arg1, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, nil`, + }, + { + Name: "simple update one method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateAgeByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.TypeBool, + code.TypeError, + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, + }, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{idField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$set": bson.M{ + "age": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, nil`, + }, + { + Name: "simple update many method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateAgeByGender", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + {Name: "gender", Type: code.SimpleType("Gender")}, + }, + Returns: []code.Type{ + code.TypeInt, + code.TypeError, + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, + }, + }, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{genderField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.UpdateMany(arg0, bson.M{ + "gender": arg2, + }, bson.M{ + "$set": bson.M{ + "age": arg1, + }, + }) + if err != nil { + return 0, err + } + return int(result.MatchedCount), nil`, + }, + { + Name: "simple update push method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateConsentHistoryPushByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.TypeBool, + code.TypeError, + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{consentHistoryField}, + ParamIndex: 1, + Operator: spec.UpdateOperatorPush, + }, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{idField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$push": bson.M{ + "consent_history": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, nil`, + }, + { + Name: "simple update inc method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateAgeIncByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "age", Type: code.TypeInt}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.TypeBool, + code.TypeError, + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ageField}, + ParamIndex: 1, + Operator: spec.UpdateOperatorInc, + }, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{idField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$inc": bson.M{ + "age": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, nil`, + }, + { + Name: "simple update set and push method", + MethodSpec: spec.MethodSpec{ + Name: "UpdateEnabledAndConsentHistoryPushByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "enabled", Type: code.TypeBool}, + {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, + {Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.TypeBool, + code.TypeError, + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{enabledField}, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, + }, + spec.UpdateField{ + FieldReference: spec.FieldReference{consentHistoryField}, + ParamIndex: 2, + Operator: spec.UpdateOperatorPush, + }, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{idField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 3, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg3, + }, bson.M{ + "$push": bson.M{ + "consent_history": arg2, + }, + "$set": bson.M{ + "enabled": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, nil`, + }, + { + Name: "update with deeply referenced field", + MethodSpec: spec.MethodSpec{ + Name: "UpdateNameFirstByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "firstName", Type: code.TypeString}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.TypeBool, + code.TypeError, + }, + Operation: spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{nameField, firstNameField}, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, + }, + }, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{idField}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }, + }, + }, + }, + ExpectedBody: ` result, err := r.collection.UpdateOne(arg0, bson.M{ + "_id": arg2, + }, bson.M{ + "$set": bson.M{ + "name.first": arg1, + }, + }) + if err != nil { + return false, err + } + return result.MatchedCount > 0, nil`, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + generator := mongo.NewGenerator(userModel, "UserRepository") + expectedReceiver := codegen.MethodReceiver{ + Name: "r", + Type: "UserRepositoryMongo", + Pointer: true, + } + var expectedParams []code.Param + for i, param := range testCase.MethodSpec.Params { + expectedParams = append(expectedParams, code.Param{ + Name: fmt.Sprintf("arg%d", i), + Type: param.Type, + }) + } + + actual, err := generator.GenerateMethod(testCase.MethodSpec) + + if err != nil { + t.Fatal(err) + } + if expectedReceiver != actual.Receiver { + t.Errorf( + "incorrect method receiver: expected %+v, got %+v", + expectedReceiver, + actual.Receiver, + ) + } + if testCase.MethodSpec.Name != actual.Name { + t.Errorf( + "incorrect method name: expected %s, got %s", + testCase.MethodSpec.Name, + actual.Name, + ) + } + if !reflect.DeepEqual(expectedParams, actual.Params) { + t.Errorf( + "incorrect struct params: expected %+v, got %+v", + expectedParams, + actual.Params, + ) + } + if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + t.Errorf( + "incorrect struct returns: expected %+v, got %+v", + testCase.MethodSpec.Returns, + actual.Returns, + ) + } + if err := testutils.ExpectMultiLineString(testCase.ExpectedBody, actual.Body.Code()); err != nil { + t.Error(err) + } + }) + } +} diff --git a/test/generator_test_expected.txt b/test/generator_test_expected.txt index 672f3e1..8be45cc 100644 --- a/test/generator_test_expected.txt +++ b/test/generator_test_expected.txt @@ -33,8 +33,16 @@ func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.Obje func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context, arg1 Gender, arg2 int) (*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "$and": []bson.M{ - {"gender": bson.M{"$ne": arg1}}, - {"age": bson.M{"$lt": arg2}}, + { + "gender": bson.M{ + "$ne": arg1, + }, + }, + { + "age": bson.M{ + "$lt": arg2, + }, + }, }, }, options.Find().SetSort(bson.M{})) if err != nil { @@ -49,7 +57,9 @@ func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$lte": arg1}, + "age": bson.M{ + "$lte": arg1, + }, }, options.Find().SetSort(bson.M{ "age": 1, })) @@ -65,7 +75,9 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Cont func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$gt": arg1}, + "age": bson.M{ + "$gt": arg1, + }, }, options.Find().SetSort(bson.M{ "age": 1, })) @@ -81,7 +93,9 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Con func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 context.Context, arg1 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$gte": arg1}, + "age": bson.M{ + "$gte": arg1, + }, }, options.Find().SetSort(bson.M{ "age": -1, })) @@ -97,7 +111,10 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 conte func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ - "age": bson.M{"$gte": arg1, "$lte": arg2}, + "age": bson.M{ + "$gte": arg1, + "$lte": arg2, + }, }, options.Find().SetSort(bson.M{})) if err != nil { return nil, err @@ -112,8 +129,12 @@ func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, a func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*UserModel, error) { cursor, err := r.collection.Find(arg0, bson.M{ "$or": []bson.M{ - {"gender": arg1}, - {"age": arg2}, + { + "gender": arg1, + }, + { + "age": arg2, + }, }, }, options.Find().SetSort(bson.M{})) if err != nil {