Scan directory to find source code in multiple files in the same package ()

This commit is contained in:
sunboyy 2022-10-17 17:57:49 +07:00 committed by GitHub
parent 0a1d5c8545
commit 737c1a4044
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 444 additions and 177 deletions

28
internal/code/errors.go Normal file
View file

@ -0,0 +1,28 @@
package code
import (
"errors"
"fmt"
)
var (
ErrAmbiguousPackageName = errors.New("code: ambiguous package name")
)
type DuplicateStructError string
func (err DuplicateStructError) Error() string {
return fmt.Sprintf(
"code: duplicate implementation of struct '%s'",
string(err),
)
}
type DuplicateInterfaceError string
func (err DuplicateInterfaceError) Error() string {
return fmt.Sprintf(
"code: duplicate implementation of interface '%s'",
string(err),
)
}

View file

@ -0,0 +1,36 @@
package code_test
import (
"testing"
"github.com/sunboyy/repogen/internal/code"
)
type ErrorTestCase struct {
Name string
Error error
ExpectedString string
}
func TestError(t *testing.T) {
testTable := []ErrorTestCase{
{
Name: "DuplicateStructError",
Error: code.DuplicateStructError("User"),
ExpectedString: "code: duplicate implementation of struct 'User'",
},
{
Name: "DuplicateInterfaceError",
Error: code.DuplicateInterfaceError("UserRepository"),
ExpectedString: "code: duplicate implementation of interface 'UserRepository'",
},
}
for _, testCase := range testTable {
t.Run(testCase.Name, func(t *testing.T) {
if testCase.Error.Error() != testCase.ExpectedString {
t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedString, testCase.Error.Error())
}
})
}
}

View file

@ -64,8 +64,8 @@ type UserModel struct {
}`,
ExpectedOutput: code.File{
PackageName: "user",
Structs: code.Structs{
code.Struct{
Structs: []code.Struct{
{
Name: "UserModel",
Fields: code.StructFields{
code.StructField{
@ -108,8 +108,8 @@ type UserRepository interface {
}`,
ExpectedOutput: code.File{
PackageName: "user",
Interfaces: code.Interfaces{
code.InterfaceType{
Interfaces: []code.InterfaceType{
{
Name: "UserRepository",
Methods: []code.Method{
{
@ -229,8 +229,8 @@ type UserRepository interface {
{Path: "context"},
{Path: "go.mongodb.org/mongo-driver/bson/primitive"},
},
Structs: code.Structs{
code.Struct{
Structs: []code.Struct{
{
Name: "UserModel",
Fields: code.StructFields{
code.StructField{
@ -252,8 +252,8 @@ type UserRepository interface {
},
},
},
Interfaces: code.Interfaces{
code.InterfaceType{
Interfaces: []code.InterfaceType{
{
Name: "UserRepository",
Methods: []code.Method{
{

View file

@ -8,8 +8,8 @@ import (
type File struct {
PackageName string
Imports []Import
Structs Structs
Interfaces Interfaces
Structs []Struct
Interfaces []InterfaceType
}
// Import is a model for package imports
@ -18,20 +18,6 @@ type Import struct {
Path string
}
// Structs is a group of Struct model
type Structs []Struct
// ByName return struct with matching name. Another return value shows whether there is a struct
// with that name exists.
func (strs Structs) ByName(name string) (Struct, bool) {
for _, str := range strs {
if str.Name == name {
return str, true
}
}
return Struct{}, false
}
// Struct is a definition of the struct
type Struct struct {
Name string
@ -63,20 +49,6 @@ type StructField struct {
Tags map[string][]string
}
// Interfaces is a group of Interface model
type Interfaces []InterfaceType
// ByName return interface by name Another return value shows whether there is an interface
// with that name exists.
func (intfs Interfaces) ByName(name string) (InterfaceType, bool) {
for _, intf := range intfs {
if intf.Name == name {
return intf, true
}
}
return InterfaceType{}, false
}
// InterfaceType is a definition of the interface
type InterfaceType struct {
Name string

View file

@ -7,36 +7,6 @@ import (
"github.com/sunboyy/repogen/internal/code"
)
func TestStructsByName(t *testing.T) {
userStruct := code.Struct{
Name: "UserModel",
Fields: code.StructFields{
code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}},
code.StructField{Name: "Username", Type: code.SimpleType("string")},
},
}
structs := code.Structs{userStruct}
t.Run("struct found", func(t *testing.T) {
structModel, ok := structs.ByName("UserModel")
if !ok {
t.Fail()
}
if !reflect.DeepEqual(structModel, userStruct) {
t.Errorf("Expected = %+v\nReceived = %+v", userStruct, structModel)
}
})
t.Run("struct not found", func(t *testing.T) {
_, ok := structs.ByName("ProductModel")
if ok {
t.Fail()
}
})
}
func TestStructFieldsByName(t *testing.T) {
idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}
usernameField := code.StructField{Name: "Username", Type: code.SimpleType("string")}
@ -62,30 +32,6 @@ func TestStructFieldsByName(t *testing.T) {
})
}
func TestInterfacesByName(t *testing.T) {
userRepoIntf := code.InterfaceType{Name: "UserRepository"}
interfaces := code.Interfaces{userRepoIntf}
t.Run("struct field found", func(t *testing.T) {
intf, ok := interfaces.ByName("UserRepository")
if !ok {
t.Fail()
}
if !reflect.DeepEqual(intf, userRepoIntf) {
t.Errorf("Expected = %+v\nReceived = %+v", userRepoIntf, intf)
}
})
t.Run("struct field not found", func(t *testing.T) {
_, ok := interfaces.ByName("Password")
if ok {
t.Fail()
}
})
}
type TypeCodeTestCase struct {
Name string
Type code.Type

66
internal/code/package.go Normal file
View file

@ -0,0 +1,66 @@
package code
import (
"go/ast"
"strings"
)
// ParsePackage extracts package name, struct and interface implementations from
// map[string]*ast.Package. Test files will be ignored.
func ParsePackage(pkgs map[string]*ast.Package) (Package, error) {
pkg := NewPackage()
for _, astPkg := range pkgs {
for fileName, file := range astPkg.Files {
if strings.HasSuffix(fileName, "_test.go") {
continue
}
if err := pkg.addFile(ExtractComponents(file)); err != nil {
return Package{}, err
}
}
}
return pkg, nil
}
// Package stores package name, struct and interface implementations as a result
// from ParsePackage
type Package struct {
Name string
Structs map[string]Struct
Interfaces map[string]InterfaceType
}
// NewPackage is a constructor function for Package.
func NewPackage() Package {
return Package{
Structs: map[string]Struct{},
Interfaces: map[string]InterfaceType{},
}
}
// addFile alters the Package by adding struct and interface implementations in
// the extracted file. If the package name conflicts, it will return error.
func (pkg *Package) addFile(file File) error {
if pkg.Name == "" {
pkg.Name = file.PackageName
} else if pkg.Name != file.PackageName {
return ErrAmbiguousPackageName
}
for _, structImpl := range file.Structs {
if _, ok := pkg.Structs[structImpl.Name]; ok {
return DuplicateStructError(structImpl.Name)
}
pkg.Structs[structImpl.Name] = structImpl
}
for _, interfaceImpl := range file.Interfaces {
if _, ok := pkg.Interfaces[interfaceImpl.Name]; ok {
return DuplicateInterfaceError(interfaceImpl.Name)
}
pkg.Interfaces[interfaceImpl.Name] = interfaceImpl
}
return nil
}

View file

@ -0,0 +1,238 @@
package code_test
import (
"errors"
"go/ast"
"go/parser"
"go/token"
"testing"
"github.com/sunboyy/repogen/internal/code"
)
const goImplFile1Data = `
package codepkgsuccess
import (
"math"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type Gender string
const (
GenderMale Gender = "MALE"
GenderFemale Gender = "FEMALE"
)
type User struct {
ID primitive.ObjectID ` + "`json:\"id\"`" + `
Name string ` + "`json:\"name\"`" + `
Gender Gender ` + "`json:\"gender\"`" + `
Birthday time.Time ` + "`json:\"birthday\"`" + `
}
func (u User) Age() int {
return int(math.Floor(time.Since(u.Birthday).Hours() / 24 / 365))
}
type (
Product struct {
ID primitive.ObjectID ` + "`json:\"id\"`" + `
Name string ` + "`json:\"name\"`" + `
Price float64 ` + "`json:\"price\"`" + `
}
Order struct {
ID primitive.ObjectID ` + "`json:\"id\"`" + `
ItemIDs map[primitive.ObjectID]int ` + "`json:\"itemIds\"`" + `
TotalPrice float64 ` + "`json:\"totalPrice\"`" + `
UserID primitive.ObjectID ` + "`json:\"userId\"`" + `
CreatedAt time.Time ` + "`json:\"createdAt\"`" + `
}
)
`
const goImplFile2Data = `
package codepkgsuccess
import (
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type OrderService interface {
CreateOrder(u User, products map[Product]int) Order
}
type OrderServiceImpl struct{}
func (s *OrderServiceImpl) CreateOrder(u User, products map[Product]int) Order {
itemIDs := map[primitive.ObjectID]int{}
var totalPrice float64
for product, amount := range products {
itemIDs[product.ID] = amount
totalPrice += product.Price * float64(amount)
}
return Order{
ID: primitive.NewObjectID(),
ItemIDs: map[primitive.ObjectID]int{},
TotalPrice: totalPrice,
UserID: u.ID,
CreatedAt: time.Now(),
}
}
`
const goImplFile3Data = `
package success
`
const goImplFile4Data = `
package codepkgsuccess
type User struct {
Name string
}
`
const goImplFile5Data = `
package codepkgsuccess
import "go.mongodb.org/mongo-driver/bson/primitive"
type OrderService interface {
CancelOrder(orderID primitive.ObjectID) error
}
`
const goTestFileData = `
package codepkgsuccess
type TestCase struct {
Name string
Params []interface{}
Expected string
Actual string
}
`
var (
goImplFile1 *ast.File
goImplFile2 *ast.File
goImplFile3 *ast.File
goImplFile4 *ast.File
goImplFile5 *ast.File
goTestFile *ast.File
)
func init() {
fset := token.NewFileSet()
goImplFile1, _ = parser.ParseFile(fset, "", goImplFile1Data, parser.ParseComments)
goImplFile2, _ = parser.ParseFile(fset, "", goImplFile2Data, parser.ParseComments)
goImplFile3, _ = parser.ParseFile(fset, "", goImplFile3Data, parser.ParseComments)
goImplFile4, _ = parser.ParseFile(fset, "", goImplFile4Data, parser.ParseComments)
goImplFile5, _ = parser.ParseFile(fset, "", goImplFile5Data, parser.ParseComments)
goTestFile, _ = parser.ParseFile(fset, "", goTestFileData, parser.ParseComments)
}
func TestParsePackage_Success(t *testing.T) {
pkg, err := code.ParsePackage(map[string]*ast.Package{
"codepkgsuccess": {
Files: map[string]*ast.File{
"file1.go": goImplFile1,
"file2.go": goImplFile2,
"file1_test.go": goTestFile,
},
},
})
if err != nil {
t.Fatal(err)
}
if pkg.Name != "codepkgsuccess" {
t.Errorf("expected package name 'codepkgsuccess', got '%s'", pkg.Name)
}
if _, ok := pkg.Structs["User"]; !ok {
t.Error("struct 'User' not found")
}
if _, ok := pkg.Structs["Product"]; !ok {
t.Error("struct 'Product' not found")
}
if _, ok := pkg.Structs["Order"]; !ok {
t.Error("struct 'Order' not found")
}
if _, ok := pkg.Structs["OrderServiceImpl"]; !ok {
t.Error("struct 'OrderServiceImpl' not found")
}
if _, ok := pkg.Interfaces["OrderService"]; !ok {
t.Error("interface 'OrderService' not found")
}
if _, ok := pkg.Structs["TestCase"]; ok {
t.Error("unexpected struct 'TestCase' in test file")
}
}
func TestParsePackage_AmbiguousPackageName(t *testing.T) {
_, err := code.ParsePackage(map[string]*ast.Package{
"codepkgsuccess": {
Files: map[string]*ast.File{
"file1.go": goImplFile1,
"file2.go": goImplFile2,
"file3.go": goImplFile3,
},
},
})
if !errors.Is(err, code.ErrAmbiguousPackageName) {
t.Errorf(
"expected error '%s', got '%s'",
code.ErrAmbiguousPackageName.Error(),
err.Error(),
)
}
}
func TestParsePackage_DuplicateStructs(t *testing.T) {
_, err := code.ParsePackage(map[string]*ast.Package{
"codepkgsuccess": {
Files: map[string]*ast.File{
"file1.go": goImplFile1,
"file2.go": goImplFile2,
"file4.go": goImplFile4,
},
},
})
if !errors.Is(err, code.DuplicateStructError("User")) {
t.Errorf(
"expected error '%s', got '%s'",
code.ErrAmbiguousPackageName.Error(),
err.Error(),
)
}
}
func TestParsePackage_DuplicateInterfaces(t *testing.T) {
_, err := code.ParsePackage(map[string]*ast.Package{
"codepkgsuccess": {
Files: map[string]*ast.File{
"file1.go": goImplFile1,
"file2.go": goImplFile2,
"file5.go": goImplFile5,
},
},
})
if !errors.Is(err, code.DuplicateInterfaceError("OrderService")) {
t.Errorf(
"expected error '%s', got '%s'",
code.ErrAmbiguousPackageName.Error(),
err.Error(),
)
}
}

View file

@ -24,7 +24,7 @@ func (r FieldReference) ReferencingCode() string {
}
type fieldResolver struct {
Structs code.Structs
Structs map[string]code.Struct
}
func (r fieldResolver) ResolveStructField(structModel code.Struct, tokens []string) (FieldReference, bool) {
@ -46,7 +46,7 @@ func (r fieldResolver) ResolveStructField(structModel code.Struct, tokens []stri
continue
}
childStruct, ok := r.Structs.ByName(fieldSimpleType.Code())
childStruct, ok := r.Structs[fieldSimpleType.Code()]
if !ok {
continue
}

View file

@ -6,7 +6,7 @@ import (
)
// ParseInterfaceMethod returns repository method spec from declared interface method
func ParseInterfaceMethod(structs code.Structs, structModel code.Struct, method code.Method) (MethodSpec, error) {
func ParseInterfaceMethod(structs map[string]code.Struct, structModel code.Struct, method code.Method) (MethodSpec, error) {
parser := interfaceMethodParser{
fieldResolver: fieldResolver{
Structs: structs,

View file

@ -91,9 +91,9 @@ var (
}
)
var structs = code.Structs{
nameStruct,
structModel,
var structs = map[string]code.Struct{
nameStruct.Name: nameStruct,
structModel.Name: structModel,
}
type ParseInterfaceMethodTestCase struct {