Skip to content
This repository has been archived by the owner on Mar 26, 2020. It is now read-only.

Improve error messages and generate operator<< for records #417

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions example/generated-src/cpp/sort_order.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <functional>
#include <ostream>

namespace textsort {

Expand All @@ -13,6 +14,16 @@ enum class sort_order : int {
RANDOM,
};

static inline ::std::ostream& operator<<(::std::ostream& os, sort_order v) {
os << "sort_order::";
switch (v) {
case sort_order::ASCENDING: return os << "ASCENDING";
case sort_order::DESCENDING: return os << "DESCENDING";
case sort_order::RANDOM: return os << "RANDOM";
default: return os << "<Unsupported Value " << static_cast<int>(v) << ">";
}
}

} // namespace textsort

namespace std {
Expand Down
69 changes: 68 additions & 1 deletion src/source/CppGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class CppGenerator(spec: Spec) extends Generator(spec) {
if (spec.cppEnumHashWorkaround) {
refs.hpp.add("#include <functional>") // needed for std::hash
}
if (!e.flags) {
refs.hpp.add("#include <ostream>") // needed for printing to stream
}

val flagsType = "unsigned"
val enumType = "int"
Expand Down Expand Up @@ -91,6 +94,18 @@ class CppGenerator(spec: Spec) extends Generator(spec) {
w.w(s"constexpr $self operator~($self x) noexcept").braced {
w.wl(s"return static_cast<$self>(~static_cast<$flagsType>(x));")
}
} else {
w.wl
w.w(s"static inline ::std::ostream& operator<<(::std::ostream& os, $self v)").braced {
w.wl("os << \"" ++ self ++ "::\";")
w.w("switch (v)").braced {
for (o <- normalEnumOptions(e)) {
val name = idCpp.enum(o.ident.name)
w.wl("case " ++ self ++ "::" ++ name ++ ": return os << \"" ++ name ++ "\";")
}
w.wl("default: return os << \"<Unsupported Value \" << static_cast<" ++ underlyingType ++ ">(v) << \">\";")
}
}
}
},
w => {
Expand Down Expand Up @@ -189,6 +204,9 @@ class CppGenerator(spec: Spec) extends Generator(spec) {
r.fields.foreach(f => refs.find(f.ty, false))
r.consts.foreach(c => refs.find(c.ty, false))
refs.hpp.add("#include <utility>") // Add for std::move
if (r.derivingTypes.contains(DerivingType.Show)) {
refs.hpp.add("#include <ostream>") // Add for overloading operator<<
}

val self = marshal.typename(ident, r)
val (cppName, cppFinal) = if (r.ext.cpp) (ident.name + "_base", "") else (ident.name, " final")
Expand Down Expand Up @@ -231,6 +249,10 @@ class CppGenerator(spec: Spec) extends Generator(spec) {
w.wl(s"friend bool operator<=(const $actualSelf& lhs, const $actualSelf& rhs);")
w.wl(s"friend bool operator>=(const $actualSelf& lhs, const $actualSelf& rhs);")
}
if (r.derivingTypes.contains(DerivingType.Show)) {
w.wl
w.wl(s"friend ::std::ostream& operator<<(::std::ostream& os, const $actualSelf& self);")
}

// Constructor.
if(r.fields.nonEmpty) {
Expand All @@ -256,11 +278,16 @@ class CppGenerator(spec: Spec) extends Generator(spec) {
w.wl(s"$actualSelf& operator=($actualSelf&&) = default;")
}
}

if (r.derivingTypes.contains(DerivingType.Show)) {
w.wl
w.wl(s"::std::ostream& operator<<(::std::ostream& os, const $actualSelf& self);")
}
}

writeHppFile(cppName, origin, refs.hpp, refs.hppFwds, writeCppPrototype)

if (r.consts.nonEmpty || r.derivingTypes.contains(DerivingType.Eq) || r.derivingTypes.contains(DerivingType.Ord)) {
if (r.consts.nonEmpty || r.derivingTypes.contains(DerivingType.Eq) || r.derivingTypes.contains(DerivingType.Ord) || r.derivingTypes.contains(DerivingType.Show)) {
writeCppFile(cppName, origin, refs.cpp, w => {
generateCppConstants(w, r.consts, actualSelf)

Expand Down Expand Up @@ -307,6 +334,46 @@ class CppGenerator(spec: Spec) extends Generator(spec) {
w.wl("return !(lhs < rhs);")
}
}
if (r.derivingTypes.contains(DerivingType.Show)) {
def generateCollectionFormatterIfNeeded(parameters: Array[String], collection: String, open: String, close: String) {
// overload operator<< for collections that are actually used
if (refs.hpp.contains(s"#include <$collection>")) {
def chain(a: String, b: String): String = {
a ++ ", " ++ b
}

val parameterDecl = parameters.map(p => "typename " ++ p).reduceLeft(chain)
w.wl(s"template<$parameterDecl>")

val parameterDef = parameters.reduceLeft(chain)
val collectionDecl = s"::std::$collection<$parameterDef>"
w.w(s"::std::ostream& operator<<(::std::ostream& os, const $collectionDecl& c)").braced {
w.wl("auto first = true;")
w.wl("os << \"" ++ open ++ "\";")
w.w("for (auto&& element : c)").braced {
w.wl("if (first) first = false; else os << \",\" << ::std::endl;")
parameters.length match {
case 1 => w.wl("os << element;")
case 2 => w.wl("os << element.first << \": \" << element.second;")
}
}
w.wl("return os << \"" ++ close ++ "\";")
}
}
}
generateCollectionFormatterIfNeeded(Array("T"), "vector", "[", "]")
generateCollectionFormatterIfNeeded(Array("T"), "unordered_set", "{", "}")
generateCollectionFormatterIfNeeded(Array("K","V"), "unordered_map", "[", "]")

w.wl
w.w(s"::std::ostream& operator<<(::std::ostream& os, const $actualSelf& self)").braced {
w.wl("os << \"" ++ self ++ "{\";")
for(f <- r.fields) {
w.wl("os << \" " ++ idCpp.field(f.ident) ++ ":\" << self." ++ idCpp.field(f.ident) ++ ";")
}
w.wl("return os << \"}\";")
}
}
})
}

Expand Down
2 changes: 1 addition & 1 deletion src/source/ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ case class Record(ext: Ext, fields: Seq[Field], consts: Seq[Const], derivingType
object Record {
object DerivingType extends Enumeration {
type DerivingType = Value
val Eq, Ord, AndroidParcelable = Value
val Eq, Ord, Show, AndroidParcelable = Value
}
}

Expand Down
1 change: 1 addition & 0 deletions src/source/parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ private object IdlParser extends RegexParsers {
_.map(ident => ident.name match {
case "eq" => Record.DerivingType.Eq
case "ord" => Record.DerivingType.Ord
case "show" => Record.DerivingType.Show
case "parcelable" => Record.DerivingType.AndroidParcelable
case _ => return err( s"""Unrecognized deriving type "${ident.name}"""")
}).toSet
Expand Down
16 changes: 12 additions & 4 deletions src/source/resolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,15 @@ private def resolveRecord(scope: Scope, r: Record) {
throw new Error(f.ident.loc, "Interface reference cannot live in a record").toException
case DRecord =>
val record = df.body.asInstanceOf[Record]
if (!r.derivingTypes.subsetOf(record.derivingTypes))
throw new Error(f.ident.loc, s"Some deriving required is not implemented in record ${f.ident.name}").toException
checkRecordRef(r, f, record)
case DEnum =>
}
case e: MExtern => e.defType match {
case DInterface =>
throw new Error(f.ident.loc, "Interface reference cannot live in a record").toException
case DRecord =>
val record = e.body.asInstanceOf[Record]
if (!r.derivingTypes.subsetOf(record.derivingTypes))
throw new Error(f.ident.loc, s"Some deriving required is not implemented in record ${f.ident.name}").toException
checkRecordRef(r, f, record)
case DEnum =>
}
case _ => throw new AssertionError("Type cannot be resolved")
Expand All @@ -263,6 +261,16 @@ private def resolveRecord(scope: Scope, r: Record) {
}
}

private def checkRecordRef(container: Record, field: Field, reference: Record) {
// Find missing directives so we can provide a useful error message
val missingTypes = container.derivingTypes -- reference.derivingTypes
if (!missingTypes.isEmpty) {
def describeType(t: DerivingType) = { "'" ++ t.toString.map(_.toLower) ++ "'" }
val names = missingTypes.tail.foldLeft(describeType(missingTypes.head))((s, t) => s ++ ", " ++ describeType(t))
throw new Error(field.ident.loc, s"Record '${field.ty.expr.ident.name}' not deriving required operation(s): ${names}").toException
}
}

private def resolveInterface(scope: Scope, i: Interface) {
// Const and static methods are only allowed on +c (only) interfaces
if (i.ext.java || i.ext.objc) {
Expand Down