Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cyclic definition for CollectionOf #2517

Merged
merged 3 commits into from
Mar 25, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions codegen/service/testdata/views_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,168 @@ func ValidateRTViewTiny(result *RTView) (err error) {
return
}
`

const ResultWithRecursiveCollectionOfResultTypeCode = `// SomeRT is the viewed result type that is projected based on a view.
type SomeRT struct {
// Type to project
Projected *SomeRTView
// View to render
View string
}

// AnotherResult is the viewed result type that is projected based on a view.
type AnotherResult struct {
// Type to project
Projected *AnotherResultView
// View to render
View string
}

// SomeRTView is a type that runs validations on a projected type.
type SomeRTView struct {
A SomeRTCollectionView
}

// SomeRTCollectionView is a type that runs validations on a projected type.
type SomeRTCollectionView []*SomeRTView

// AnotherResultView is a type that runs validations on a projected type.
type AnotherResultView struct {
A AnotherResultCollectionView
}

// AnotherResultCollectionView is a type that runs validations on a projected
// type.
type AnotherResultCollectionView []*AnotherResultView

var (
// SomeRTMap is a map of attribute names in result type SomeRT indexed by view
// name.
SomeRTMap = map[string][]string{
"default": []string{
"a",
},
"tiny": []string{
"a",
},
}
// AnotherResultMap is a map of attribute names in result type AnotherResult
// indexed by view name.
AnotherResultMap = map[string][]string{
"default": []string{
"a",
},
}
// SomeRTCollectionMap is a map of attribute names in result type
// SomeRTCollection indexed by view name.
SomeRTCollectionMap = map[string][]string{
"default": []string{
"a",
},
"tiny": []string{
"a",
},
}
// AnotherResultCollectionMap is a map of attribute names in result type
// AnotherResultCollection indexed by view name.
AnotherResultCollectionMap = map[string][]string{
"default": []string{
"a",
},
}
)

// ValidateSomeRT runs the validations defined on the viewed result type SomeRT.
func ValidateSomeRT(result *SomeRT) (err error) {
switch result.View {
case "default", "":
err = ValidateSomeRTView(result.Projected)
case "tiny":
err = ValidateSomeRTViewTiny(result.Projected)
default:
err = goa.InvalidEnumValueError("view", result.View, []interface{}{"default", "tiny"})
}
return
}

// ValidateAnotherResult runs the validations defined on the viewed result type
// AnotherResult.
func ValidateAnotherResult(result *AnotherResult) (err error) {
switch result.View {
case "default", "":
err = ValidateAnotherResultView(result.Projected)
default:
err = goa.InvalidEnumValueError("view", result.View, []interface{}{"default"})
}
return
}

// ValidateSomeRTView runs the validations defined on SomeRTView using the
// "default" view.
func ValidateSomeRTView(result *SomeRTView) (err error) {

if result.A != nil {
if err2 := ValidateSomeRTCollectionViewTiny(result.A); err2 != nil {
err = goa.MergeErrors(err, err2)
}
}
return
}

// ValidateSomeRTViewTiny runs the validations defined on SomeRTView using the
// "tiny" view.
func ValidateSomeRTViewTiny(result *SomeRTView) (err error) {

if result.A != nil {
if err2 := ValidateSomeRTCollectionView(result.A); err2 != nil {
err = goa.MergeErrors(err, err2)
}
}
return
}

// ValidateSomeRTCollectionView runs the validations defined on
// SomeRTCollectionView using the "default" view.
func ValidateSomeRTCollectionView(result SomeRTCollectionView) (err error) {
for _, item := range result {
if err2 := ValidateSomeRTView(item); err2 != nil {
err = goa.MergeErrors(err, err2)
}
}
return
}

// ValidateSomeRTCollectionViewTiny runs the validations defined on
// SomeRTCollectionView using the "tiny" view.
func ValidateSomeRTCollectionViewTiny(result SomeRTCollectionView) (err error) {
for _, item := range result {
if err2 := ValidateSomeRTViewTiny(item); err2 != nil {
err = goa.MergeErrors(err, err2)
}
}
return
}

// ValidateAnotherResultView runs the validations defined on AnotherResultView
// using the "default" view.
func ValidateAnotherResultView(result *AnotherResultView) (err error) {

if result.A != nil {
if err2 := ValidateAnotherResultCollectionView(result.A); err2 != nil {
err = goa.MergeErrors(err, err2)
}
}
return
}

// ValidateAnotherResultCollectionView runs the validations defined on
// AnotherResultCollectionView using the "default" view.
func ValidateAnotherResultCollectionView(result AnotherResultCollectionView) (err error) {
for _, item := range result {
if err2 := ValidateAnotherResultView(item); err2 != nil {
err = goa.MergeErrors(err, err2)
}
}
return
}
`
32 changes: 32 additions & 0 deletions codegen/service/testdata/views_dsls.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,38 @@ var ResultWithRecursiveResultTypeDSL = func() {
})
}

var ResultWithRecursiveCollectionOfResultTypeDSL = func() {
var SomeRT = ResultType("application/vnd.some_result", func() {
TypeName("SomeRT")
Attributes(func() {
Attribute("a", CollectionOf("SomeRT"))
Required("a")
})
View("default", func() {
Attribute("a", func() {
View("tiny")
})
})
View("tiny", func() {
Attribute("a")
})
})
var AnotherRT = ResultType("application/vnd.another_result", func() {
Attributes(func() {
Attribute("a", CollectionOf("application/vnd.another_result"))
Required("a")
})
})
Service("ResultWithRecursiveCollectionOfResultType", func() {
Method("A", func() {
Result(SomeRT)
})
Method("B", func() {
Result(AnotherRT)
})
})
}

var ResultWithCustomFieldsDSL = func() {
var RT = ResultType("application/vnd.result", func() {
TypeName("RT")
Expand Down
1 change: 1 addition & 0 deletions codegen/service/views_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func TestViews(t *testing.T) {
{"result-with-result-type", testdata.ResultWithResultTypeDSL, testdata.ResultWithResultTypeCode},
{"result-with-recursive-result-type", testdata.ResultWithRecursiveResultTypeDSL, testdata.ResultWithRecursiveResultTypeCode},
{"result-type-with-custom-fields", testdata.ResultWithCustomFieldsDSL, testdata.ResultWithCustomFieldsCode},
{"result-with-recursive-collection-of-result-type", testdata.ResultWithRecursiveCollectionOfResultTypeDSL, testdata.ResultWithRecursiveCollectionOfResultTypeCode},
}
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
Expand Down
68 changes: 45 additions & 23 deletions dsl/result_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,13 @@ func ResultType(identifier string, fn func()) *expr.ResultTypeExpr {
return nil
}

identifier, typeName, err := mediaTypeToResultType(identifier)
// Validate Result Type
identifier, params, err := mime.ParseMediaType(identifier)
if err != nil {
eval.ReportError("invalid result type identifier %#v: %s",
identifier, err)
// We don't return so that other errors may be captured in this
// one run.
identifier = "text/plain"
}
canonicalID := expr.CanonicalIdentifier(identifier)
// Validate that result type identifier doesn't clash
Expand All @@ -86,26 +85,6 @@ func ResultType(identifier string, fn func()) *expr.ResultTypeExpr {
return nil
}
}
identifier = mime.FormatMediaType(identifier, params)
lastPart := identifier
lastPartIndex := strings.LastIndex(identifier, "/")
if lastPartIndex > -1 {
lastPart = identifier[lastPartIndex+1:]
}
plusIndex := strings.Index(lastPart, "+")
if plusIndex > 0 {
lastPart = lastPart[:plusIndex]
}
lastPart = strings.TrimPrefix(lastPart, "vnd.")
elems := strings.Split(lastPart, ".")
for i, e := range elems {
elems[i] = strings.Title(e)
}
typeName := strings.Join(elems, "")
if typeName == "" {
resultTypeCount++
typeName = fmt.Sprintf("ResultType%d", resultTypeCount)
}
// Now save the type in the API result types map
mt := expr.NewResultTypeExpr(typeName, identifier, fn)
expr.Root.ResultTypes = append(expr.Root.ResultTypes, mt)
Expand Down Expand Up @@ -282,10 +261,22 @@ func CollectionOf(v interface{}, adsl ...func()) *expr.ResultTypeExpr {
m, ok = v.(*expr.ResultTypeExpr)
if !ok {
if id, ok := v.(string); ok {
if dt := expr.Root.UserType(expr.CanonicalIdentifier(id)); dt != nil {
// Check if a result type exists with the given type name
if dt := expr.Root.UserType(id); dt != nil {
if mt, ok := dt.(*expr.ResultTypeExpr); ok {
m = mt
}
} else {
// Check if a result type exists with the given identifier
id, typeName, err := mediaTypeToResultType(id)
if err != nil {
eval.ReportError("invalid result type identifier %#v in CollectionOf: %s", id, err)
}
if dt := expr.Root.UserType(typeName); dt != nil {
nitinmohan87 marked this conversation as resolved.
Show resolved Hide resolved
if mt, ok := dt.(*expr.ResultTypeExpr); ok {
m = mt
}
}
}
}
}
Expand Down Expand Up @@ -447,6 +438,37 @@ func Attributes(fn func()) {
eval.Execute(fn, mt)
}

// mediaTypeToResultType returns the formatted identifier and the result type
// name from the given identifier string. If the given identifier is invalid it
// returns text/plain as the identifier and an error.
func mediaTypeToResultType(identifier string) (string, string, error) {
identifier, params, err := mime.ParseMediaType(identifier)
if err != nil {
identifier = "text/plain"
}
identifier = mime.FormatMediaType(identifier, params)
lastPart := identifier
lastPartIndex := strings.LastIndex(identifier, "/")
if lastPartIndex > -1 {
lastPart = identifier[lastPartIndex+1:]
}
plusIndex := strings.Index(lastPart, "+")
if plusIndex > 0 {
lastPart = lastPart[:plusIndex]
}
lastPart = strings.TrimPrefix(lastPart, "vnd.")
elems := strings.Split(lastPart, ".")
for i, e := range elems {
elems[i] = strings.Title(e)
}
typeName := strings.Join(elems, "")
if typeName == "" {
resultTypeCount++
typeName = fmt.Sprintf("ResultType%d", resultTypeCount)
}
return identifier, typeName, err
}

// buildView builds a view expression given an attribute and a corresponding
// result type. The attribute must be an object listing the child attributes
// that make up the view.
Expand Down