Fix error when a method argument is assigned with map type (#26)

This commit is contained in:
sunboyy 2021-11-27 11:54:18 +07:00 committed by GitHub
parent dd206ba46b
commit 8081ffcb0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 2 deletions

View file

@ -167,6 +167,11 @@ func getType(expr ast.Expr) Type {
containedType := getType(expr.Elt) containedType := getType(expr.Elt)
return ArrayType{ContainedType: containedType} return ArrayType{ContainedType: containedType}
case *ast.MapType:
keyType := getType(expr.Key)
valueType := getType(expr.Value)
return MapType{KeyType: keyType, ValueType: valueType}
case *ast.InterfaceType: case *ast.InterfaceType:
var methods []Method var methods []Method
for _, method := range expr.Methods.List { for _, method := range expr.Methods.List {

View file

@ -98,6 +98,7 @@ type UserRepository interface {
FindAll(context.Context) ([]*UserModel, error) FindAll(context.Context) ([]*UserModel, error)
FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error) FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error)
InsertOne(ctx context.Context, user *UserModel) (interface{}, error) InsertOne(ctx context.Context, user *UserModel) (interface{}, error)
UpdateAgreementByID(ctx context.Context, agreement map[string]bool, id primitive.ObjectID) (bool, error)
CustomMethod(interface { CustomMethod(interface {
Run(arg1 int) Run(arg1 int)
}) interface { }) interface {
@ -154,6 +155,18 @@ type UserRepository interface {
code.SimpleType("error"), code.SimpleType("error"),
}, },
}, },
{
Name: "UpdateAgreementByID",
Params: []code.Param{
{Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}},
{Name: "agreement", Type: code.MapType{KeyType: code.SimpleType("string"), ValueType: code.SimpleType("bool")}},
{Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
},
Returns: []code.Type{
code.SimpleType("bool"),
code.SimpleType("error"),
},
},
{ {
Name: "CustomMethod", Name: "CustomMethod",
Params: []code.Param{ Params: []code.Param{

View file

@ -172,3 +172,19 @@ func (t ArrayType) Code() string {
func (t ArrayType) IsNumber() bool { func (t ArrayType) IsNumber() bool {
return false return false
} }
// MapType is a model of map
type MapType struct {
KeyType Type
ValueType Type
}
// Code returns token string in code format
func (t MapType) Code() string {
return fmt.Sprintf("map[%s]%s", t.KeyType.Code(), t.ValueType.Code())
}
// IsNumber returns false
func (t MapType) IsNumber() bool {
return false
}

View file

@ -114,6 +114,14 @@ func TestTypeCode(t *testing.T) {
Type: code.ArrayType{ContainedType: code.SimpleType("UserModel")}, Type: code.ArrayType{ContainedType: code.SimpleType("UserModel")},
ExpectedCode: "[]UserModel", ExpectedCode: "[]UserModel",
}, },
{
Name: "map type",
Type: code.MapType{
KeyType: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"},
ValueType: code.PointerType{ContainedType: code.SimpleType("UserModel")},
},
ExpectedCode: "map[primitive.ObjectID]*UserModel",
},
} }
for _, testCase := range testTable { for _, testCase := range testTable {
@ -195,6 +203,11 @@ func TestTypeIsNumber(t *testing.T) {
Type: code.ArrayType{ContainedType: code.SimpleType("int")}, Type: code.ArrayType{ContainedType: code.SimpleType("int")},
IsNumber: false, IsNumber: false,
}, },
{
Name: "map type",
Type: code.MapType{KeyType: code.SimpleType("int"), ValueType: code.SimpleType("float64")},
IsNumber: false,
},
{ {
Name: "interface type", Name: "interface type",
Type: code.InterfaceType{}, Type: code.InterfaceType{},

View file

@ -21,7 +21,7 @@ const usageText = `repogen generates MongoDB repository implementation from repo
Supported options:` Supported options:`
const version = "v0.2.0" const version = "v0.2.1"
func main() { func main() {
flag.Usage = printUsage flag.Usage = printUsage
@ -41,7 +41,7 @@ func main() {
if *sourcePtr == "" { if *sourcePtr == "" {
printUsage() printUsage()
log.Fatal("-source flag required") log.Fatal("-src flag required")
} }
if *modelPtr == "" { if *modelPtr == "" {
printUsage() printUsage()