From 8081ffcb0fd31b3bdaed9aa43e33032b5e57e2a9 Mon Sep 17 00:00:00 2001 From: sunboyy Date: Sat, 27 Nov 2021 11:54:18 +0700 Subject: [PATCH] Fix error when a method argument is assigned with map type (#26) --- internal/code/extractor.go | 5 +++++ internal/code/extractor_test.go | 13 +++++++++++++ internal/code/models.go | 16 ++++++++++++++++ internal/code/models_test.go | 13 +++++++++++++ main.go | 4 ++-- 5 files changed, 49 insertions(+), 2 deletions(-) diff --git a/internal/code/extractor.go b/internal/code/extractor.go index cbf62bb..0236583 100644 --- a/internal/code/extractor.go +++ b/internal/code/extractor.go @@ -167,6 +167,11 @@ func getType(expr ast.Expr) Type { containedType := getType(expr.Elt) return ArrayType{ContainedType: containedType} + case *ast.MapType: + keyType := getType(expr.Key) + valueType := getType(expr.Value) + return MapType{KeyType: keyType, ValueType: valueType} + case *ast.InterfaceType: var methods []Method for _, method := range expr.Methods.List { diff --git a/internal/code/extractor_test.go b/internal/code/extractor_test.go index 86ced4d..17ffd40 100644 --- a/internal/code/extractor_test.go +++ b/internal/code/extractor_test.go @@ -98,6 +98,7 @@ type UserRepository interface { FindAll(context.Context) ([]*UserModel, error) FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error) InsertOne(ctx context.Context, user *UserModel) (interface{}, error) + UpdateAgreementByID(ctx context.Context, agreement map[string]bool, id primitive.ObjectID) (bool, error) CustomMethod(interface { Run(arg1 int) }) interface { @@ -154,6 +155,18 @@ type UserRepository interface { code.SimpleType("error"), }, }, + { + Name: "UpdateAgreementByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "agreement", Type: code.MapType{KeyType: code.SimpleType("string"), ValueType: code.SimpleType("bool")}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.SimpleType("bool"), + code.SimpleType("error"), + }, + }, { Name: "CustomMethod", Params: []code.Param{ diff --git a/internal/code/models.go b/internal/code/models.go index 094efa0..c086e36 100644 --- a/internal/code/models.go +++ b/internal/code/models.go @@ -172,3 +172,19 @@ func (t ArrayType) Code() string { func (t ArrayType) IsNumber() bool { return false } + +// MapType is a model of map +type MapType struct { + KeyType Type + ValueType Type +} + +// Code returns token string in code format +func (t MapType) Code() string { + return fmt.Sprintf("map[%s]%s", t.KeyType.Code(), t.ValueType.Code()) +} + +// IsNumber returns false +func (t MapType) IsNumber() bool { + return false +} diff --git a/internal/code/models_test.go b/internal/code/models_test.go index 66b6f54..34698ef 100644 --- a/internal/code/models_test.go +++ b/internal/code/models_test.go @@ -114,6 +114,14 @@ func TestTypeCode(t *testing.T) { Type: code.ArrayType{ContainedType: code.SimpleType("UserModel")}, ExpectedCode: "[]UserModel", }, + { + Name: "map type", + Type: code.MapType{ + KeyType: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, + ValueType: code.PointerType{ContainedType: code.SimpleType("UserModel")}, + }, + ExpectedCode: "map[primitive.ObjectID]*UserModel", + }, } for _, testCase := range testTable { @@ -195,6 +203,11 @@ func TestTypeIsNumber(t *testing.T) { Type: code.ArrayType{ContainedType: code.SimpleType("int")}, IsNumber: false, }, + { + Name: "map type", + Type: code.MapType{KeyType: code.SimpleType("int"), ValueType: code.SimpleType("float64")}, + IsNumber: false, + }, { Name: "interface type", Type: code.InterfaceType{}, diff --git a/main.go b/main.go index f0b5c86..313046a 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,7 @@ const usageText = `repogen generates MongoDB repository implementation from repo Supported options:` -const version = "v0.2.0" +const version = "v0.2.1" func main() { flag.Usage = printUsage @@ -41,7 +41,7 @@ func main() { if *sourcePtr == "" { printUsage() - log.Fatal("-source flag required") + log.Fatal("-src flag required") } if *modelPtr == "" { printUsage()