repogen/internal/mongo/templates.go

99 lines
2.2 KiB
Go

package mongo
import (
"fmt"
"strings"
"github.com/sunboyy/repogen/internal/code"
)
// baseTemplate is the Go-template source for the header of a generated
// repository file. It renders the package clause, the imports, a
// New<InterfaceName> constructor that wraps a *mongo.Collection, and the
// <StructName> struct that stores that collection. The constructor returns
// &<StructName> as <InterfaceName>, so the struct must implement the interface
// (its methods are rendered separately from methodTemplate). The placeholders
// match the fields of mongoBaseTemplateData.
//
// NOTE(review): "primitive" is imported unconditionally but nothing in this
// header uses it. "bson" is only used by the method implementations. Unless the
// rendered output is post-processed (e.g. by goimports), a file that never uses
// them would fail with "imported and not used". Confirm against the generator.
const baseTemplate = `// Code generated by repogen. DO NOT EDIT.
package {{.PackageName}}
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
func New{{.InterfaceName}}(collection *mongo.Collection) {{.InterfaceName}} {
return &{{.StructName}}{
collection: collection,
}
}
type {{.StructName}} struct {
collection *mongo.Collection
}
`
// mongoBaseTemplateData carries the values substituted into baseTemplate.
// PackageName becomes the package clause of the generated file.
// InterfaceName is both the constructor suffix (New<InterfaceName>) and its
// return type. StructName is the concrete type the constructor instantiates.
type mongoBaseTemplateData struct {
	PackageName, InterfaceName, StructName string
}
// methodTemplate renders one method of the generated repository struct. The
// method has a pointer receiver named "r" and takes ctx plus the parameter list
// produced by mongoMethodTemplateData.Parameters (named arg0, arg1, ...). Its
// result list comes from mongoMethodTemplateData.Returns, and Implementation
// becomes the body. The implementation templates (findOneTemplate,
// findManyTemplate) depend on exactly these names: r, ctx and argN.
//
// NOTE(review): with no parameters this renders "ctx context.Context, )". A
// trailing comma is legal in a Go parameter list, but gofmt will drop it.
const methodTemplate = `
func (r *{{.StructName}}) {{.MethodName}}(ctx context.Context, {{.Parameters}}){{.Returns}} {
{{.Implementation}}
}
`
// mongoMethodTemplateData carries the values substituted into methodTemplate
// for a single generated repository method. ParamTypes and ReturnTypes are not
// referenced by the template directly. They are turned into Go source text by
// the Parameters and Returns methods. Implementation is placed inside the
// method's braces as its body.
type mongoMethodTemplateData struct {
	StructName, MethodName  string
	ParamTypes, ReturnTypes []code.Type
	Implementation          string
}
// Parameters renders ParamTypes as a comma-separated Go parameter list. The
// parameter at position i is named "arg<i>" and typed with that entry's Code()
// text, e.g. "arg0 <type0>, arg1 <type1>". It returns "" when ParamTypes is
// empty. The argN names are what the implementation templates refer to.
func (data mongoMethodTemplateData) Parameters() string {
	rendered := make([]string, len(data.ParamTypes))
	for idx := range data.ParamTypes {
		rendered[idx] = fmt.Sprintf("arg%d %s", idx, data.ParamTypes[idx].Code())
	}
	return strings.Join(rendered, ", ")
}
// Returns renders ReturnTypes as the result portion of a Go method signature,
// including the separating leading space:
//
//	no results      -> ""
//	one result      -> " T"
//	several results -> " (T1, T2, ...)"
//
// Each type's text comes from its Code() method.
func (data mongoMethodTemplateData) Returns() string {
	switch len(data.ReturnTypes) {
	case 0:
		return ""
	case 1:
		return " " + data.ReturnTypes[0].Code()
	default:
		codes := make([]string, len(data.ReturnTypes))
		for i, rt := range data.ReturnTypes {
			codes[i] = rt.Code()
		}
		return " (" + strings.Join(codes, ", ") + ")"
	}
}
// findOneTemplate is an Implementation body for generated methods that fetch a
// single entity. It runs r.collection.FindOne with a bson.M filter in which
// each QueryFields[i] is paired with the method argument arg<i>. It decodes the
// result into a local value of type {{.EntityType}} and returns a pointer to
// it. Any error from FindOne/Decode is returned as (nil, err). Presumably this
// includes the driver's "no documents" error; callers would need to check for
// that case themselves. Confirm.
// Rendered with a mongoFindTemplateData.
const findOneTemplate = ` var entity {{.EntityType}}
if err := r.collection.FindOne(ctx, bson.M{
{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}},
{{end}} }).Decode(&entity); err != nil {
return nil, err
}
return &entity, nil`
// mongoFindTemplateData carries the values referenced by findOneTemplate and
// findManyTemplate. EntityType is the Go type name that query results are
// decoded into. QueryFields lists the document keys placed in the bson.M
// filter, in argument order: QueryFields[i] is paired with arg<i>.
type mongoFindTemplateData struct {
	EntityType  string
	QueryFields []string
}
// findManyTemplate is an Implementation body for generated methods that fetch
// multiple entities. It runs r.collection.Find with a bson.M filter in which
// each QueryFields[i] is paired with the method argument arg<i>. It then
// decodes every result with cursor.All into a []*{{.EntityType}}, which it
// returns. Errors from Find or cursor.All are returned as (nil, err).
// Rendered with a mongoFindTemplateData.
//
// NOTE(review): the cursor is never closed explicitly here. This relies on the
// mongo-driver documenting that Cursor.All closes the cursor once iteration
// finishes. Confirm for the driver version in use.
const findManyTemplate = ` cursor, err := r.collection.Find(ctx, bson.M{
{{range $index, $field := .QueryFields}} "{{$field}}": arg{{$index}},
{{end}} })
if err != nil {
return nil, err
}
var entities []*{{.EntityType}}
if err := cursor.All(ctx, &entities); err != nil {
return nil, err
}
return entities, nil`