Scan directory to find source code in multiple files in the same package (#30)
This commit is contained in:
parent
0a1d5c8545
commit
737c1a4044
15 changed files with 444 additions and 177 deletions
28
internal/code/errors.go
Normal file
28
internal/code/errors.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package code
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAmbiguousPackageName = errors.New("code: ambiguous package name")
|
||||
)
|
||||
|
||||
type DuplicateStructError string
|
||||
|
||||
func (err DuplicateStructError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"code: duplicate implementation of struct '%s'",
|
||||
string(err),
|
||||
)
|
||||
}
|
||||
|
||||
type DuplicateInterfaceError string
|
||||
|
||||
func (err DuplicateInterfaceError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"code: duplicate implementation of interface '%s'",
|
||||
string(err),
|
||||
)
|
||||
}
|
36
internal/code/errors_test.go
Normal file
36
internal/code/errors_test.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package code_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
type ErrorTestCase struct {
|
||||
Name string
|
||||
Error error
|
||||
ExpectedString string
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
testTable := []ErrorTestCase{
|
||||
{
|
||||
Name: "DuplicateStructError",
|
||||
Error: code.DuplicateStructError("User"),
|
||||
ExpectedString: "code: duplicate implementation of struct 'User'",
|
||||
},
|
||||
{
|
||||
Name: "DuplicateInterfaceError",
|
||||
Error: code.DuplicateInterfaceError("UserRepository"),
|
||||
ExpectedString: "code: duplicate implementation of interface 'UserRepository'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testTable {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
if testCase.Error.Error() != testCase.ExpectedString {
|
||||
t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedString, testCase.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -64,8 +64,8 @@ type UserModel struct {
|
|||
}`,
|
||||
ExpectedOutput: code.File{
|
||||
PackageName: "user",
|
||||
Structs: code.Structs{
|
||||
code.Struct{
|
||||
Structs: []code.Struct{
|
||||
{
|
||||
Name: "UserModel",
|
||||
Fields: code.StructFields{
|
||||
code.StructField{
|
||||
|
@ -108,8 +108,8 @@ type UserRepository interface {
|
|||
}`,
|
||||
ExpectedOutput: code.File{
|
||||
PackageName: "user",
|
||||
Interfaces: code.Interfaces{
|
||||
code.InterfaceType{
|
||||
Interfaces: []code.InterfaceType{
|
||||
{
|
||||
Name: "UserRepository",
|
||||
Methods: []code.Method{
|
||||
{
|
||||
|
@ -229,8 +229,8 @@ type UserRepository interface {
|
|||
{Path: "context"},
|
||||
{Path: "go.mongodb.org/mongo-driver/bson/primitive"},
|
||||
},
|
||||
Structs: code.Structs{
|
||||
code.Struct{
|
||||
Structs: []code.Struct{
|
||||
{
|
||||
Name: "UserModel",
|
||||
Fields: code.StructFields{
|
||||
code.StructField{
|
||||
|
@ -252,8 +252,8 @@ type UserRepository interface {
|
|||
},
|
||||
},
|
||||
},
|
||||
Interfaces: code.Interfaces{
|
||||
code.InterfaceType{
|
||||
Interfaces: []code.InterfaceType{
|
||||
{
|
||||
Name: "UserRepository",
|
||||
Methods: []code.Method{
|
||||
{
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
type File struct {
|
||||
PackageName string
|
||||
Imports []Import
|
||||
Structs Structs
|
||||
Interfaces Interfaces
|
||||
Structs []Struct
|
||||
Interfaces []InterfaceType
|
||||
}
|
||||
|
||||
// Import is a model for package imports
|
||||
|
@ -18,20 +18,6 @@ type Import struct {
|
|||
Path string
|
||||
}
|
||||
|
||||
// Structs is a group of Struct model
|
||||
type Structs []Struct
|
||||
|
||||
// ByName return struct with matching name. Another return value shows whether there is a struct
|
||||
// with that name exists.
|
||||
func (strs Structs) ByName(name string) (Struct, bool) {
|
||||
for _, str := range strs {
|
||||
if str.Name == name {
|
||||
return str, true
|
||||
}
|
||||
}
|
||||
return Struct{}, false
|
||||
}
|
||||
|
||||
// Struct is a definition of the struct
|
||||
type Struct struct {
|
||||
Name string
|
||||
|
@ -63,20 +49,6 @@ type StructField struct {
|
|||
Tags map[string][]string
|
||||
}
|
||||
|
||||
// Interfaces is a group of Interface model
|
||||
type Interfaces []InterfaceType
|
||||
|
||||
// ByName return interface by name Another return value shows whether there is an interface
|
||||
// with that name exists.
|
||||
func (intfs Interfaces) ByName(name string) (InterfaceType, bool) {
|
||||
for _, intf := range intfs {
|
||||
if intf.Name == name {
|
||||
return intf, true
|
||||
}
|
||||
}
|
||||
return InterfaceType{}, false
|
||||
}
|
||||
|
||||
// InterfaceType is a definition of the interface
|
||||
type InterfaceType struct {
|
||||
Name string
|
||||
|
|
|
@ -7,36 +7,6 @@ import (
|
|||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
func TestStructsByName(t *testing.T) {
|
||||
userStruct := code.Struct{
|
||||
Name: "UserModel",
|
||||
Fields: code.StructFields{
|
||||
code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
|
||||
code.StructField{Name: "Username", Type: code.SimpleType("string")},
|
||||
},
|
||||
}
|
||||
structs := code.Structs{userStruct}
|
||||
|
||||
t.Run("struct found", func(t *testing.T) {
|
||||
structModel, ok := structs.ByName("UserModel")
|
||||
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
if !reflect.DeepEqual(structModel, userStruct) {
|
||||
t.Errorf("Expected = %+v\nReceived = %+v", userStruct, structModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("struct not found", func(t *testing.T) {
|
||||
_, ok := structs.ByName("ProductModel")
|
||||
|
||||
if ok {
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructFieldsByName(t *testing.T) {
|
||||
idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}
|
||||
usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")}
|
||||
|
@ -62,30 +32,6 @@ func TestStructFieldsByName(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestInterfacesByName(t *testing.T) {
|
||||
userRepoIntf := code.InterfaceType{Name: "UserRepository"}
|
||||
interfaces := code.Interfaces{userRepoIntf}
|
||||
|
||||
t.Run("struct field found", func(t *testing.T) {
|
||||
intf, ok := interfaces.ByName("UserRepository")
|
||||
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
if !reflect.DeepEqual(intf, userRepoIntf) {
|
||||
t.Errorf("Expected = %+v\nReceived = %+v", userRepoIntf, intf)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("struct field not found", func(t *testing.T) {
|
||||
_, ok := interfaces.ByName("Password")
|
||||
|
||||
if ok {
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type TypeCodeTestCase struct {
|
||||
Name string
|
||||
Type code.Type
|
||||
|
|
66
internal/code/package.go
Normal file
66
internal/code/package.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package code
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParsePackage extracts package name, struct and interface implementations from
|
||||
// map[string]*ast.Package. Test files will be ignored.
|
||||
func ParsePackage(pkgs map[string]*ast.Package) (Package, error) {
|
||||
pkg := NewPackage()
|
||||
for _, astPkg := range pkgs {
|
||||
for fileName, file := range astPkg.Files {
|
||||
if strings.HasSuffix(fileName, "_test.go") {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pkg.addFile(ExtractComponents(file)); err != nil {
|
||||
return Package{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// Package stores package name, struct and interface implementations as a result
|
||||
// from ParsePackage
|
||||
type Package struct {
|
||||
Name string
|
||||
Structs map[string]Struct
|
||||
Interfaces map[string]InterfaceType
|
||||
}
|
||||
|
||||
// NewPackage is a constructor function for Package.
|
||||
func NewPackage() Package {
|
||||
return Package{
|
||||
Structs: map[string]Struct{},
|
||||
Interfaces: map[string]InterfaceType{},
|
||||
}
|
||||
}
|
||||
|
||||
// addFile alters the Package by adding struct and interface implementations in
|
||||
// the extracted file. If the package name conflicts, it will return error.
|
||||
func (pkg *Package) addFile(file File) error {
|
||||
if pkg.Name == "" {
|
||||
pkg.Name = file.PackageName
|
||||
} else if pkg.Name != file.PackageName {
|
||||
return ErrAmbiguousPackageName
|
||||
}
|
||||
|
||||
for _, structImpl := range file.Structs {
|
||||
if _, ok := pkg.Structs[structImpl.Name]; ok {
|
||||
return DuplicateStructError(structImpl.Name)
|
||||
}
|
||||
pkg.Structs[structImpl.Name] = structImpl
|
||||
}
|
||||
|
||||
for _, interfaceImpl := range file.Interfaces {
|
||||
if _, ok := pkg.Interfaces[interfaceImpl.Name]; ok {
|
||||
return DuplicateInterfaceError(interfaceImpl.Name)
|
||||
}
|
||||
pkg.Interfaces[interfaceImpl.Name] = interfaceImpl
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
238
internal/code/package_test.go
Normal file
238
internal/code/package_test.go
Normal file
|
@ -0,0 +1,238 @@
|
|||
package code_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"testing"
|
||||
|
||||
"github.com/sunboyy/repogen/internal/code"
|
||||
)
|
||||
|
||||
const goImplFile1Data = `
|
||||
package codepkgsuccess
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type Gender string
|
||||
|
||||
const (
|
||||
GenderMale Gender = "MALE"
|
||||
GenderFemale Gender = "FEMALE"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID primitive.ObjectID ` + "`json:\"id\"`" + `
|
||||
Name string ` + "`json:\"name\"`" + `
|
||||
Gender Gender ` + "`json:\"gender\"`" + `
|
||||
Birthday time.Time ` + "`json:\"birthday\"`" + `
|
||||
}
|
||||
|
||||
func (u User) Age() int {
|
||||
return int(math.Floor(time.Since(u.Birthday).Hours() / 24 / 365))
|
||||
}
|
||||
|
||||
type (
|
||||
Product struct {
|
||||
ID primitive.ObjectID ` + "`json:\"id\"`" + `
|
||||
Name string ` + "`json:\"name\"`" + `
|
||||
Price float64 ` + "`json:\"price\"`" + `
|
||||
}
|
||||
|
||||
Order struct {
|
||||
ID primitive.ObjectID ` + "`json:\"id\"`" + `
|
||||
ItemIDs map[primitive.ObjectID]int ` + "`json:\"itemIds\"`" + `
|
||||
TotalPrice float64 ` + "`json:\"totalPrice\"`" + `
|
||||
UserID primitive.ObjectID ` + "`json:\"userId\"`" + `
|
||||
CreatedAt time.Time ` + "`json:\"createdAt\"`" + `
|
||||
}
|
||||
)
|
||||
`
|
||||
|
||||
const goImplFile2Data = `
|
||||
package codepkgsuccess
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type OrderService interface {
|
||||
CreateOrder(u User, products map[Product]int) Order
|
||||
}
|
||||
|
||||
type OrderServiceImpl struct{}
|
||||
|
||||
func (s *OrderServiceImpl) CreateOrder(u User, products map[Product]int) Order {
|
||||
itemIDs := map[primitive.ObjectID]int{}
|
||||
var totalPrice float64
|
||||
for product, amount := range products {
|
||||
itemIDs[product.ID] = amount
|
||||
totalPrice += product.Price * float64(amount)
|
||||
}
|
||||
|
||||
return Order{
|
||||
ID: primitive.NewObjectID(),
|
||||
ItemIDs: map[primitive.ObjectID]int{},
|
||||
TotalPrice: totalPrice,
|
||||
UserID: u.ID,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
const goImplFile3Data = `
|
||||
package success
|
||||
`
|
||||
|
||||
const goImplFile4Data = `
|
||||
package codepkgsuccess
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
}
|
||||
`
|
||||
|
||||
const goImplFile5Data = `
|
||||
package codepkgsuccess
|
||||
|
||||
import "go.mongodb.org/mongo-driver/bson/primitive"
|
||||
|
||||
type OrderService interface {
|
||||
CancelOrder(orderID primitive.ObjectID) error
|
||||
}
|
||||
`
|
||||
|
||||
const goTestFileData = `
|
||||
package codepkgsuccess
|
||||
|
||||
type TestCase struct {
|
||||
Name string
|
||||
Params []interface{}
|
||||
Expected string
|
||||
Actual string
|
||||
}
|
||||
`
|
||||
|
||||
var (
|
||||
goImplFile1 *ast.File
|
||||
goImplFile2 *ast.File
|
||||
goImplFile3 *ast.File
|
||||
goImplFile4 *ast.File
|
||||
goImplFile5 *ast.File
|
||||
goTestFile *ast.File
|
||||
)
|
||||
|
||||
func init() {
|
||||
fset := token.NewFileSet()
|
||||
goImplFile1, _ = parser.ParseFile(fset, "", goImplFile1Data, parser.ParseComments)
|
||||
goImplFile2, _ = parser.ParseFile(fset, "", goImplFile2Data, parser.ParseComments)
|
||||
goImplFile3, _ = parser.ParseFile(fset, "", goImplFile3Data, parser.ParseComments)
|
||||
goImplFile4, _ = parser.ParseFile(fset, "", goImplFile4Data, parser.ParseComments)
|
||||
goImplFile5, _ = parser.ParseFile(fset, "", goImplFile5Data, parser.ParseComments)
|
||||
goTestFile, _ = parser.ParseFile(fset, "", goTestFileData, parser.ParseComments)
|
||||
}
|
||||
|
||||
func TestParsePackage_Success(t *testing.T) {
|
||||
pkg, err := code.ParsePackage(map[string]*ast.Package{
|
||||
"codepkgsuccess": {
|
||||
Files: map[string]*ast.File{
|
||||
"file1.go": goImplFile1,
|
||||
"file2.go": goImplFile2,
|
||||
"file1_test.go": goTestFile,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if pkg.Name != "codepkgsuccess" {
|
||||
t.Errorf("expected package name 'codepkgsuccess', got '%s'", pkg.Name)
|
||||
}
|
||||
if _, ok := pkg.Structs["User"]; !ok {
|
||||
t.Error("struct 'User' not found")
|
||||
}
|
||||
if _, ok := pkg.Structs["Product"]; !ok {
|
||||
t.Error("struct 'Product' not found")
|
||||
}
|
||||
if _, ok := pkg.Structs["Order"]; !ok {
|
||||
t.Error("struct 'Order' not found")
|
||||
}
|
||||
if _, ok := pkg.Structs["OrderServiceImpl"]; !ok {
|
||||
t.Error("struct 'OrderServiceImpl' not found")
|
||||
}
|
||||
if _, ok := pkg.Interfaces["OrderService"]; !ok {
|
||||
t.Error("interface 'OrderService' not found")
|
||||
}
|
||||
if _, ok := pkg.Structs["TestCase"]; ok {
|
||||
t.Error("unexpected struct 'TestCase' in test file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePackage_AmbiguousPackageName(t *testing.T) {
|
||||
_, err := code.ParsePackage(map[string]*ast.Package{
|
||||
"codepkgsuccess": {
|
||||
Files: map[string]*ast.File{
|
||||
"file1.go": goImplFile1,
|
||||
"file2.go": goImplFile2,
|
||||
"file3.go": goImplFile3,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if !errors.Is(err, code.ErrAmbiguousPackageName) {
|
||||
t.Errorf(
|
||||
"expected error '%s', got '%s'",
|
||||
code.ErrAmbiguousPackageName.Error(),
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePackage_DuplicateStructs(t *testing.T) {
|
||||
_, err := code.ParsePackage(map[string]*ast.Package{
|
||||
"codepkgsuccess": {
|
||||
Files: map[string]*ast.File{
|
||||
"file1.go": goImplFile1,
|
||||
"file2.go": goImplFile2,
|
||||
"file4.go": goImplFile4,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if !errors.Is(err, code.DuplicateStructError("User")) {
|
||||
t.Errorf(
|
||||
"expected error '%s', got '%s'",
|
||||
code.ErrAmbiguousPackageName.Error(),
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePackage_DuplicateInterfaces(t *testing.T) {
|
||||
_, err := code.ParsePackage(map[string]*ast.Package{
|
||||
"codepkgsuccess": {
|
||||
Files: map[string]*ast.File{
|
||||
"file1.go": goImplFile1,
|
||||
"file2.go": goImplFile2,
|
||||
"file5.go": goImplFile5,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if !errors.Is(err, code.DuplicateInterfaceError("OrderService")) {
|
||||
t.Errorf(
|
||||
"expected error '%s', got '%s'",
|
||||
code.ErrAmbiguousPackageName.Error(),
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -24,7 +24,7 @@ func (r FieldReference) ReferencingCode() string {
|
|||
}
|
||||
|
||||
type fieldResolver struct {
|
||||
Structs code.Structs
|
||||
Structs map[string]code.Struct
|
||||
}
|
||||
|
||||
func (r fieldResolver) ResolveStructField(structModel code.Struct, tokens []string) (FieldReference, bool) {
|
||||
|
@ -46,7 +46,7 @@ func (r fieldResolver) ResolveStructField(structModel code.Struct, tokens []stri
|
|||
continue
|
||||
}
|
||||
|
||||
childStruct, ok := r.Structs.ByName(fieldSimpleType.Code())
|
||||
childStruct, ok := r.Structs[fieldSimpleType.Code()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
// ParseInterfaceMethod returns repository method spec from declared interface method
|
||||
func ParseInterfaceMethod(structs code.Structs, structModel code.Struct, method code.Method) (MethodSpec, error) {
|
||||
func ParseInterfaceMethod(structs map[string]code.Struct, structModel code.Struct, method code.Method) (MethodSpec, error) {
|
||||
parser := interfaceMethodParser{
|
||||
fieldResolver: fieldResolver{
|
||||
Structs: structs,
|
||||
|
|
|
@ -91,9 +91,9 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
var structs = code.Structs{
|
||||
nameStruct,
|
||||
structModel,
|
||||
var structs = map[string]code.Struct{
|
||||
nameStruct.Name: nameStruct,
|
||||
structModel.Name: structModel,
|
||||
}
|
||||
|
||||
type ParseInterfaceMethodTestCase struct {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue