diff --git a/internal/sample/controller/handlers.go b/internal/sample/controller/handlers.go index f6aac35..7cbefb9 100644 --- a/internal/sample/controller/handlers.go +++ b/internal/sample/controller/handlers.go @@ -13,6 +13,14 @@ func LoginHandler(ctx gin.Context) { } +type EmbeddedGroup struct { + gin.RouterGroup +} + +type GroupWithNonEmbedded struct { + group gin.RouterGroup +} + type MyRouterGroup struct{} func (m MyRouterGroup) Use(handlerFunc ...gin.HandlerFunc) gin.IRoutes { @@ -96,3 +104,4 @@ func (m MyRouterGroup) Group(s string, handlerFunc ...gin.HandlerFunc) *gin.Rout } var _ gin.IRouter = (*MyRouterGroup)(nil) +var _ gin.IRouter = (*EmbeddedGroup)(nil) diff --git a/type.go b/type.go index e537756..542d817 100644 --- a/type.go +++ b/type.go @@ -22,8 +22,33 @@ func AllTypes() Types { return typs } -func TypesEmbeddedWith(embeds ...string) Types { - panic("to be implemented") +func TypesEmbeddedWith(embeddedType string) Types { + eType, ok := internal.Arch().Type(embeddedType) + if !ok { + log.Fatalf("can not find interface %s", embeddedType) + } + var typMap sync.Map + lo.ForEach(internal.Arch().Packages(), func(pkg *internal.Package, index int) { + if strings.HasPrefix(pkg.ID(), internal.Arch().Module()) && + (pkg.ID() == eType.Package() || lo.Contains(pkg.Imports(), eType.Package())) { + lop.ForEach(pkg.Types(), func(typ internal.Type, index int) { + if str, ok := typ.Raw().Underlying().(*types.Struct); ok { + for i := 0; i < str.NumFields(); i++ { + if v := str.Field(i); v.Embedded() && types.Identical(v.Type(), eType.Raw()) { + typMap.Store(index, typ) + } + } + } + }) + } + }) + var typs Types + typMap.Range(func(_, value any) bool { + typs = append(typs, value.(internal.Type)) + return true + }) + return typs + } // TypesImplement return all the types implement the interface @@ -33,20 +58,19 @@ func TypesImplement(interName string) Types { log.Fatalf("can not find interface %s", interName) } var typMap sync.Map - lop.ForEach(internal.Arch().Packages(), func(pkg *internal.Package, index int) { + lo.ForEach(internal.Arch().Packages(), func(pkg *internal.Package, index int) { if strings.HasPrefix(pkg.ID(), internal.Arch().Module()) && (pkg.ID() == interType.Package() || lo.Contains(pkg.Imports(), interType.Package())) { - implementations := lo.Filter(pkg.Types(), func(typ internal.Type, _ int) bool { - return !strings.HasSuffix(typ.Name(), interName) && types.Implements(typ.Raw(), interType.Raw().Underlying().(*types.Interface)) + lop.ForEach(pkg.Types(), func(typ internal.Type, index int) { + if !strings.HasSuffix(typ.Name(), interName) && types.Implements(typ.Raw(), interType.Raw().Underlying().(*types.Interface)) { + typMap.Store(index, typ) + } }) - if len(implementations) > 0 { - typMap.Store(pkg.ID(), implementations) - } } }) var typs Types typMap.Range(func(_, value any) bool { - typs = append(typs, value.([]internal.Type)...) + typs = append(typs, value.(internal.Type)) return true }) return typs diff --git a/type_test.go b/type_test.go index 58c45d8..a33b8d0 100644 --- a/type_test.go +++ b/type_test.go @@ -45,6 +45,8 @@ func TestAllTypes(t *testing.T) { "github.com/kcmvp/archunit/internal/sample/repository.FF", "github.com/kcmvp/archunit/internal/sample/repository.UserRepository", "github.com/kcmvp/archunit/internal/sample/controller.MyRouterGroup", + "github.com/kcmvp/archunit/internal/sample/controller.EmbeddedGroup", + "github.com/kcmvp/archunit/internal/sample/controller.GroupWithNonEmbedded", } assert.ElementsMatch(t, expected, typs) } @@ -77,3 +79,29 @@ func TestTypeImplement(t *testing.T) { }) } } + +func TestTypesEmbeddedWith(t *testing.T) { + tests := []struct { + interType string + implementation []string + }{ + { + interType: "github.com/gin-gonic/gin.RouterGroup", + implementation: []string{ + "github.com/kcmvp/archunit/internal/sample/controller.EmbeddedGroup", + }, + }, + { + interType: "github.com/gin-gonic/gin.IRouter", + implementation: []string{}, + }, + } + for _, test := range tests { + t.Run(test.interType, func(t *testing.T) { + types := lo.Map(TypesEmbeddedWith(test.interType), func(item internal.Type, _ int) string { + return item.Name() + }) + assert.ElementsMatch(t, test.implementation, types) + }) + } +}