diff --git a/cxxheaderparser/lexer.py b/cxxheaderparser/lexer.py index e700331..ccf6f08 100644 --- a/cxxheaderparser/lexer.py +++ b/cxxheaderparser/lexer.py @@ -149,6 +149,7 @@ class Lexer: "DBL_RBRACKET", "DBL_COLON", "DBL_AMP", + "ARROW", "SHIFT_LEFT", ] + list(keywords) @@ -217,6 +218,7 @@ class Lexer: t_DBL_RBRACKET = r"\]\]" t_DBL_COLON = r"::" t_DBL_AMP = r"&&" + t_ARROW = r"->" t_SHIFT_LEFT = r"<<" # SHIFT_RIGHT introduces ambiguity diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index 45bf168..735901f 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -1515,6 +1515,31 @@ class CxxParser: return params, vararg + _auto_return_typename = PQName([AutoSpecifier()]) + + def _parse_trailing_return_type( + self, fn: typing.Union[Function, FunctionType] + ) -> None: + # entry is "->" + return_type = fn.return_type + if not ( + isinstance(return_type, Type) + and not return_type.const + and not return_type.volatile + and return_type.typename == self._auto_return_typename + ): + raise CxxParseError( + f"function with trailing return type must specify return type of 'auto', not {return_type}" + ) + + parsed_type, mods = self._parse_type(None) + mods.validate(var_ok=False, meth_ok=False, msg="parsing trailing return type") + + dtype = self._parse_cv_ptr(parsed_type) + + fn.has_trailing_return = True + fn.return_type = dtype + def _parse_fn_end(self, fn: Function) -> None: """ Consumes the various keywords after the parameters in a function @@ -1535,6 +1560,8 @@ class CxxParser: if self.lex.token_if("{"): self._discard_contents("{", "}") fn.has_body = True + elif self.lex.token_if("ARROW"): + self._parse_trailing_return_type(fn) def _parse_method_end(self, method: Method) -> None: """ @@ -1577,6 +1604,9 @@ class CxxParser: setattr(method, tok_value, True) elif tok_value in ("&", "&&"): method.ref_qualifier = tok_value + elif tok_value == "ARROW": + self._parse_trailing_return_type(method) + break elif tok_value == "throw": tok = self._next_token_must_be("(") method.throw = self._create_value(self._consume_balanced_tokens(tok)) @@ -1667,7 +1697,7 @@ class CxxParser: self.visitor.on_class_method(state, method) - return method.has_body + return method.has_body or method.has_trailing_return else: fn = Function( @@ -1682,7 +1712,7 @@ class CxxParser: self._parse_fn_end(fn) self.visitor.on_function(state, fn) - return fn.has_body + return fn.has_body or fn.has_trailing_return # # Decorated type parsing @@ -1736,6 +1766,8 @@ class CxxParser: fn_params, vararg = self._parse_parameters() dtype = FunctionType(dtype, fn_params, vararg) + if self.lex.token_if("ARROW"): + self._parse_trailing_return_type(dtype) else: # Check to see if this is a grouping paren or something else diff --git a/cxxheaderparser/types.py b/cxxheaderparser/types.py index 21bfa5e..61cfca5 100644 --- a/cxxheaderparser/types.py +++ b/cxxheaderparser/types.py @@ -230,6 +230,11 @@ class FunctionType: #: Set to True if ends with ``...`` vararg: bool = False + #: True if function has a trailing return type (``auto foo() -> int``). + #: In this case, the 'auto' return type is removed and replaced with + #: whatever the trailing return type was + has_trailing_return: bool = False + @dataclass class Type: @@ -469,6 +474,11 @@ class Function: #: If true, the body of the function is present has_body: bool = False + #: True if function has a trailing return type (``auto foo() -> int``). + #: In this case, the 'auto' return type is removed and replaced with + #: whatever the trailing return type was + has_trailing_return: bool = False + template: typing.Optional[TemplateDecl] = None throw: typing.Optional[Value] = None diff --git a/tests/test_fn.py b/tests/test_fn.py index 9c3ad56..e5fee55 100644 --- a/tests/test_fn.py +++ b/tests/test_fn.py @@ -585,3 +585,131 @@ def test_fn_return_std_function(): assert data1 == expected assert data2 == expected + + +def test_fn_return_std_function_trailing(): + content = """ + std::functionint> fn(); + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName( + segments=[ + NameSpecifier(name="std"), + NameSpecifier( + name="function", + specialization=TemplateSpecialization( + args=[ + TemplateArgument( + arg=FunctionType( + return_type=Type( + typename=PQName( + segments=[ + FundamentalSpecifier( + name="int" + ) + ] + ) + ), + parameters=[ + Parameter( + type=Type( + typename=PQName( + segments=[ + FundamentalSpecifier( + name="int" + ) + ] + ) + ) + ) + ], + has_trailing_return=True, + ) + ) + ] + ), + ), + ] + ) + ), + name=PQName(segments=[NameSpecifier(name="fn")]), + parameters=[], + ) + ] + ) + ) + + +def test_fn_trailing_return_simple(): + content = """ + auto fn() -> int; + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="int")]) + ), + name=PQName(segments=[NameSpecifier(name="fn")]), + parameters=[], + has_trailing_return=True, + ) + ] + ) + ) + + +def test_fn_trailing_return_std_function(): + content = """ + auto fn() -> std::function; + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName( + segments=[ + NameSpecifier(name="std"), + NameSpecifier( + name="function", + specialization=TemplateSpecialization( + args=[ + TemplateArgument( + arg=FunctionType( + return_type=Type( + typename=PQName( + segments=[ + FundamentalSpecifier( + name="int" + ) + ] + ) + ), + parameters=[], + ) + ) + ] + ), + ), + ] + ) + ), + name=PQName(segments=[NameSpecifier(name="fn")]), + parameters=[], + has_trailing_return=True, + ) + ] + ) + )