diff --git a/internal/code/extractor.go b/internal/code/extractor.go new file mode 100644 index 0000000..313aa8c --- /dev/null +++ b/internal/code/extractor.go @@ -0,0 +1,157 @@ +package code + +import ( + "fmt" + "go/ast" + "strconv" + "strings" +) + +// ExtractComponents converts ast file into code components model +func ExtractComponents(f *ast.File) File { + var file File + file.PackageName = f.Name.Name + + for _, decl := range f.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + + for _, spec := range genDecl.Specs { + importSpec, ok := spec.(*ast.ImportSpec) + if ok { + var imp Import + if importSpec.Name != nil { + imp.Name = importSpec.Name.Name + } + importPath, err := strconv.Unquote(importSpec.Path.Value) + if err != nil { + fmt.Printf("cannot unquote import %s : %s \n", importSpec.Path.Value, err) + continue + } + imp.Path = importPath + + file.Imports = append(file.Imports, imp) + } + + typeSpec, ok := spec.(*ast.TypeSpec) + if ok { + structType, ok := typeSpec.Type.(*ast.StructType) + if ok { + str := Struct{ + Name: typeSpec.Name.Name, + } + + for _, field := range structType.Fields.List { + var strField StructField + for _, name := range field.Names { + strField.Name = name.Name + break + } + strField.Type = getType(field.Type) + if field.Tag != nil { + strField.Tags = extractStructTag(field.Tag.Value) + } + + str.Fields = append(str.Fields, strField) + } + + file.Structs = append(file.Structs, str) + } + + interfaceType, ok := typeSpec.Type.(*ast.InterfaceType) + if ok { + intf := Interface{ + Name: typeSpec.Name.Name, + } + + for _, method := range interfaceType.Methods.List { + var meth Method + for _, name := range method.Names { + meth.Name = name.Name + break + } + + funcType, ok := method.Type.(*ast.FuncType) + if !ok { + continue + } + + for _, param := range funcType.Params.List { + var p Param + for _, name := range param.Names { + p.Name = name.Name + break + } + p.Type = getType(param.Type) + + meth.Params = append(meth.Params, p) + } + + for _, result := range funcType.Results.List { + meth.Returns = append(meth.Returns, getType(result.Type)) + } + + intf.Methods = append(intf.Methods, meth) + } + + file.Interfaces = append(file.Interfaces, intf) + } + } + } + } + return file +} + +func extractStructTag(tagValue string) map[string][]string { + tagTokens := strings.Fields(tagValue[1 : len(tagValue)-1]) + + tags := make(map[string][]string) + for _, tagToken := range tagTokens { + colonIndex := strings.Index(tagToken, ":") + if colonIndex == -1 { + continue + } + tagKey := tagToken[:colonIndex] + tagValue, err := strconv.Unquote(tagToken[colonIndex+1:]) + if err != nil { + fmt.Printf("cannot unquote struct tag %s : %s\n", tagToken[colonIndex+1:], err) + continue + } + tagValues := strings.Split(tagValue, ",") + tags[tagKey] = tagValues + } + + return tags +} + +func getType(expr ast.Expr) Type { + identExpr, ok := expr.(*ast.Ident) + if ok { + return SimpleType(identExpr.Name) + } + + selectorExpr, ok := expr.(*ast.SelectorExpr) + if ok { + xExpr, ok := selectorExpr.X.(*ast.Ident) + if !ok { + return ExternalType{Name: selectorExpr.Sel.Name} + } + return ExternalType{PackageAlias: xExpr.Name, Name: selectorExpr.Sel.Name} + } + + starExpr, ok := expr.(*ast.StarExpr) + if ok { + containedType := getType(starExpr.X) + return PointerType{ContainedType: containedType} + } + + arrayType, ok := expr.(*ast.ArrayType) + if ok { + containedType := getType(arrayType.Elt) + return ArrayType{containedType} + } + + return nil +} diff --git a/internal/code/extractor_test.go b/internal/code/extractor_test.go new file mode 100644 index 0000000..d77af3d --- /dev/null +++ b/internal/code/extractor_test.go @@ -0,0 +1,225 @@ +package code_test + +import ( + "go/parser" + "go/token" + "reflect" + "testing" + + "github.com/sunboyy/repogen/internal/code" +) + +type TestCase struct { + Name string + Source string + ExpectedOutput code.File +} + +func TestExtractComponents(t *testing.T) { + testTable := []TestCase{ + { + Name: "package name", + Source: `package user`, + ExpectedOutput: code.File{ + PackageName: "user", + }, + }, + { + Name: "single line imports", + Source: `package user + +import ctx "context" +import "go.mongodb.org/mongo-driver/bson/primitive"`, + ExpectedOutput: code.File{ + PackageName: "user", + Imports: []code.Import{ + {Name: "ctx", Path: "context"}, + {Path: "go.mongodb.org/mongo-driver/bson/primitive"}, + }, + }, + }, + { + Name: "multiple line imports", + Source: `package user + +import ( + ctx "context" + "go.mongodb.org/mongo-driver/bson/primitive" +)`, + ExpectedOutput: code.File{ + PackageName: "user", + Imports: []code.Import{ + {Name: "ctx", Path: "context"}, + {Path: "go.mongodb.org/mongo-driver/bson/primitive"}, + }, + }, + }, + { + Name: "struct declaration", + Source: `package user + +type UserModel struct { + ID primitive.ObjectID ` + "`bson:\"_id,omitempty\" json:\"id\"`" + ` + Username string ` + "`bson:\"username\" json:\"username\"`" + ` +}`, + ExpectedOutput: code.File{ + PackageName: "user", + Structs: code.Structs{ + { + Name: "UserModel", + Fields: code.StructFields{ + { + Name: "ID", + Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, + Tags: map[string][]string{ + "bson": {"_id", "omitempty"}, + "json": {"id"}, + }, + }, + { + Name: "Username", + Type: code.SimpleType("string"), + Tags: map[string][]string{ + "bson": {"username"}, + "json": {"username"}, + }, + }, + }, + }, + }, + }, + }, + { + Name: "interface declaration", + Source: `package user + +type UserRepository interface { + FindOneByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error) + FindAll(context.Context) ([]*UserModel, error) +}`, + ExpectedOutput: code.File{ + PackageName: "user", + Interfaces: code.Interfaces{ + { + Name: "UserRepository", + Methods: []code.Method{ + { + Name: "FindOneByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.SimpleType("error"), + }, + }, + { + Name: "FindAll", + Params: []code.Param{ + {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + }, + }, + }, + }, + }, + { + Name: "integration", + Source: `package user + +import ( + "context" + + "go.mongodb.org/mongo-driver/bson/primitive" +) + +type UserModel struct { + ID primitive.ObjectID ` + "`bson:\"_id,omitempty\" json:\"id\"`" + ` + Username string ` + "`bson:\"username\" json:\"username\"`" + ` +} + +type UserRepository interface { + FindOneByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error) + FindAll(ctx context.Context) ([]*UserModel, error) +} +`, + ExpectedOutput: code.File{ + PackageName: "user", + Imports: []code.Import{ + {Path: "context"}, + {Path: "go.mongodb.org/mongo-driver/bson/primitive"}, + }, + Structs: code.Structs{ + { + Name: "UserModel", + Fields: code.StructFields{ + { + Name: "ID", + Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, + Tags: map[string][]string{ + "bson": {"_id", "omitempty"}, + "json": {"id"}, + }, + }, + { + Name: "Username", + Type: code.SimpleType("string"), + Tags: map[string][]string{ + "bson": {"username"}, + "json": {"username"}, + }, + }, + }, + }, + }, + Interfaces: code.Interfaces{ + { + Name: "UserRepository", + Methods: []code.Method{ + { + Name: "FindOneByID", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, + }, + Returns: []code.Type{ + code.PointerType{ContainedType: code.SimpleType("UserModel")}, + code.SimpleType("error"), + }, + }, + { + Name: "FindAll", + Params: []code.Param{ + {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + }, + Returns: []code.Type{ + code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, + code.SimpleType("error"), + }, + }, + }, + }, + }, + }, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.Name, func(t *testing.T) { + fset := token.NewFileSet() + f, _ := parser.ParseFile(fset, "", testCase.Source, parser.ParseComments) + + file := code.ExtractComponents(f) + + if !reflect.DeepEqual(file, testCase.ExpectedOutput) { + t.Errorf("Expected = %v\nReceived = %v", testCase.ExpectedOutput, file) + } + }) + } +} diff --git a/internal/code/models.go b/internal/code/models.go new file mode 100644 index 0000000..4e1bc4d --- /dev/null +++ b/internal/code/models.go @@ -0,0 +1,136 @@ +package code + +import ( + "fmt" +) + +// File is a container of all required components for code generation in the file +type File struct { + PackageName string + Imports []Import + Structs Structs + Interfaces Interfaces +} + +// Import is a model for package imports +type Import struct { + Name string + Path string +} + +// Structs is a group of Struct model +type Structs []Struct + +// ByName return struct with matching name. Another return value shows whether there is a struct +// with that name exists. +func (strs Structs) ByName(name string) (Struct, bool) { + for _, str := range strs { + if str.Name == name { + return str, true + } + } + return Struct{}, false +} + +// Struct is a definition of the struct +type Struct struct { + Name string + Fields StructFields +} + +// StructFields is a group of the StructField model +type StructFields []StructField + +// ByName return struct field with matching name +func (fields StructFields) ByName(name string) (StructField, bool) { + for _, field := range fields { + if field.Name == name { + return field, true + } + } + return StructField{}, false +} + +// StructField is a definition of the struct field +type StructField struct { + Name string + Type Type + Tags map[string][]string +} + +// Interfaces is a group of Interface model +type Interfaces []Interface + +// ByName return interface by name Another return value shows whether there is an interface +// with that name exists. +func (intfs Interfaces) ByName(name string) (Interface, bool) { + for _, intf := range intfs { + if intf.Name == name { + return intf, true + } + } + return Interface{}, false +} + +// Interface is a definition of the interface +type Interface struct { + Name string + Methods []Method +} + +// Method is a definition of the method inside the interface +type Method struct { + Name string + Params []Param + Returns []Type +} + +// Param is a model of method parameter +type Param struct { + Name string + Type Type +} + +// Type is an interface for value types +type Type interface { + Code() string +} + +// SimpleType is a type that can be called directly +type SimpleType string + +// Code returns token string in code format +func (t SimpleType) Code() string { + return string(t) +} + +// ExternalType is a type that is called to another package +type ExternalType struct { + PackageAlias string + Name string +} + +// Code returns token string in code format +func (t ExternalType) Code() string { + return fmt.Sprintf("%s.%s", t.PackageAlias, t.Name) +} + +// PointerType is a model of pointer +type PointerType struct { + ContainedType Type +} + +// Code returns token string in code format +func (t PointerType) Code() string { + return fmt.Sprintf("*%s", t.ContainedType.Code()) +} + +// ArrayType is a model of array +type ArrayType struct { + ContainedType Type +} + +// Code returns token string in code format +func (t ArrayType) Code() string { + return fmt.Sprintf("[]%s", t.ContainedType.Code()) +}