Added support for template deduction guides

* Added DeductionGuide as a language element
This commit is contained in:
Justin Boswell 2023-11-30 23:07:03 -08:00 committed by Dustin Spicuzza
parent 64c5290318
commit 88a7048513
5 changed files with 164 additions and 11 deletions

View File

@ -25,6 +25,7 @@ from .types import (
Concept, Concept,
DecltypeSpecifier, DecltypeSpecifier,
DecoratedType, DecoratedType,
DeductionGuide,
EnumDecl, EnumDecl,
Enumerator, Enumerator,
Field, Field,
@ -1868,10 +1869,9 @@ class CxxParser:
_auto_return_typename = PQName([AutoSpecifier()]) _auto_return_typename = PQName([AutoSpecifier()])
def _parse_trailing_return_type( def _parse_trailing_return_type(
self, fn: typing.Union[Function, FunctionType] self, return_type: typing.Optional[DecoratedType]
) -> None: ) -> DecoratedType:
# entry is "->" # entry is "->"
return_type = fn.return_type
if not ( if not (
isinstance(return_type, Type) isinstance(return_type, Type)
and not return_type.const and not return_type.const
@ -1890,8 +1890,7 @@ class CxxParser:
dtype = self._parse_cv_ptr(parsed_type) dtype = self._parse_cv_ptr(parsed_type)
fn.has_trailing_return = True return dtype
fn.return_type = dtype
def _parse_fn_end(self, fn: Function) -> None: def _parse_fn_end(self, fn: Function) -> None:
""" """
@ -1918,7 +1917,9 @@ class CxxParser:
fn.raw_requires = self._parse_requires(rtok) fn.raw_requires = self._parse_requires(rtok)
if self.lex.token_if("ARROW"): if self.lex.token_if("ARROW"):
self._parse_trailing_return_type(fn) return_type = self._parse_trailing_return_type(fn.return_type)
fn.has_trailing_return = True
fn.return_type = return_type
if self.lex.token_if("{"): if self.lex.token_if("{"):
self._discard_contents("{", "}") self._discard_contents("{", "}")
@ -1966,7 +1967,9 @@ class CxxParser:
elif tok_value in ("&", "&&"): elif tok_value in ("&", "&&"):
method.ref_qualifier = tok_value method.ref_qualifier = tok_value
elif tok_value == "->": elif tok_value == "->":
self._parse_trailing_return_type(method) return_type = self._parse_trailing_return_type(method.return_type)
method.has_trailing_return = True
method.return_type = return_type
if self.lex.token_if("{"): if self.lex.token_if("{"):
self._discard_contents("{", "}") self._discard_contents("{", "}")
method.has_body = True method.has_body = True
@ -2000,6 +2003,7 @@ class CxxParser:
is_friend: bool, is_friend: bool,
is_typedef: bool, is_typedef: bool,
msvc_convention: typing.Optional[LexToken], msvc_convention: typing.Optional[LexToken],
is_guide: bool = False,
) -> bool: ) -> bool:
""" """
Assumes the caller has already consumed the return type and name, this consumes the Assumes the caller has already consumed the return type and name, this consumes the
@ -2076,7 +2080,21 @@ class CxxParser:
self.visitor.on_method_impl(state, method) self.visitor.on_method_impl(state, method)
return method.has_body or method.has_trailing_return return method.has_body or method.has_trailing_return
elif is_guide:
assert isinstance(state, (ExternBlockState, NamespaceBlockState))
if not self.lex.token_if("ARROW"):
raise self._parse_error(None, expected="Trailing return type")
return_type = self._parse_trailing_return_type(
Type(PQName([AutoSpecifier()]))
)
guide = DeductionGuide(
return_type,
name=pqname,
parameters=params,
doxygen=doxygen,
)
self.visitor.on_deduction_guide(state, guide)
return False
else: else:
assert return_type is not None assert return_type is not None
fn = Function( fn = Function(
@ -2210,7 +2228,9 @@ class CxxParser:
assert not isinstance(dtype, FunctionType) assert not isinstance(dtype, FunctionType)
dtype = dtype_fn = FunctionType(dtype, fn_params, vararg) dtype = dtype_fn = FunctionType(dtype, fn_params, vararg)
if self.lex.token_if("ARROW"): if self.lex.token_if("ARROW"):
self._parse_trailing_return_type(dtype_fn) return_type = self._parse_trailing_return_type(dtype_fn.return_type)
dtype_fn.has_trailing_return = True
dtype_fn.return_type = return_type
else: else:
msvc_convention = None msvc_convention = None
@ -2391,6 +2411,7 @@ class CxxParser:
destructor = False destructor = False
op = None op = None
msvc_convention = None msvc_convention = None
is_guide = False
# If we have a leading (, that's either an obnoxious grouping # If we have a leading (, that's either an obnoxious grouping
# paren or it's a constructor # paren or it's a constructor
@ -2441,6 +2462,13 @@ class CxxParser:
# grouping paren like "void (name(int x));" # grouping paren like "void (name(int x));"
toks = self._consume_balanced_tokens(tok) toks = self._consume_balanced_tokens(tok)
# check to see if the next token is an arrow, and thus a trailing return
if self.lex.token_peek_if("ARROW"):
self.lex.return_tokens(toks)
# the leading name of the class/ctor has been parsed as a type before the parens
pqname = parsed_type.typename
is_guide = True
else:
# .. not sure what it's grouping, so put it back? # .. not sure what it's grouping, so put it back?
self.lex.return_tokens(toks[1:-1]) self.lex.return_tokens(toks[1:-1])
@ -2473,6 +2501,7 @@ class CxxParser:
is_friend, is_friend,
is_typedef, is_typedef,
msvc_convention, msvc_convention,
is_guide,
) )
elif msvc_convention: elif msvc_convention:
raise self._parse_error(msvc_convention) raise self._parse_error(msvc_convention)

View File

@ -35,6 +35,7 @@ from dataclasses import dataclass, field
from .types import ( from .types import (
ClassDecl, ClassDecl,
Concept, Concept,
DeductionGuide,
EnumDecl, EnumDecl,
Field, Field,
ForwardDecl, ForwardDecl,
@ -123,6 +124,9 @@ class NamespaceScope:
#: Child namespaces #: Child namespaces
namespaces: typing.Dict[str, "NamespaceScope"] = field(default_factory=dict) namespaces: typing.Dict[str, "NamespaceScope"] = field(default_factory=dict)
#: Deduction guides
deduction_guides: typing.List[DeductionGuide] = field(default_factory=list)
Block = typing.Union[ClassScope, NamespaceScope] Block = typing.Union[ClassScope, NamespaceScope]
@ -317,6 +321,11 @@ class SimpleCxxVisitor:
def on_class_end(self, state: SClassBlockState) -> None: def on_class_end(self, state: SClassBlockState) -> None:
pass pass
def on_deduction_guide(
self, state: SNonClassBlockState, guide: DeductionGuide
) -> None:
state.user_data.deduction_guides.append(guide)
def parse_string( def parse_string(
content: str, content: str,

View File

@ -896,3 +896,21 @@ class UsingAlias:
#: Documentation if present #: Documentation if present
doxygen: typing.Optional[str] = None doxygen: typing.Optional[str] = None
@dataclass
class DeductionGuide:
"""
.. code-block:: c++
template <class T>
MyClass(T) -> MyClass(int);
"""
#: Only constructors and destructors don't have a return type
result_type: typing.Optional[DecoratedType]
name: PQName
parameters: typing.List[Parameter]
doxygen: typing.Optional[str] = None

View File

@ -9,6 +9,7 @@ else:
from .types import ( from .types import (
Concept, Concept,
DeductionGuide,
EnumDecl, EnumDecl,
Field, Field,
ForwardDecl, ForwardDecl,
@ -236,6 +237,13 @@ class CxxVisitor(Protocol):
``on_variable`` for each instance declared. ``on_variable`` for each instance declared.
""" """
def on_deduction_guide(
self, state: NonClassBlockState, guide: DeductionGuide
) -> None:
"""
Called when a deduction guide is encountered
"""
class NullVisitor: class NullVisitor:
""" """
@ -318,5 +326,10 @@ class NullVisitor:
def on_class_end(self, state: ClassBlockState) -> None: def on_class_end(self, state: ClassBlockState) -> None:
return None return None
def on_deduction_guide(
self, state: NonClassBlockState, guide: DeductionGuide
) -> None:
return None
null_visitor = NullVisitor() null_visitor = NullVisitor()

View File

@ -5,6 +5,7 @@ from cxxheaderparser.types import (
BaseClass, BaseClass,
ClassDecl, ClassDecl,
DecltypeSpecifier, DecltypeSpecifier,
DeductionGuide,
Field, Field,
ForwardDecl, ForwardDecl,
Function, Function,
@ -2163,3 +2164,86 @@ def test_member_class_template_specialization() -> None:
] ]
) )
) )
def test_template_deduction_guide() -> None:
content = """
template <class CharT, class Traits = std::char_traits<CharT>>
Error(std::basic_string_view<CharT, Traits>) -> Error<std::string>;
"""
data = parse_string(content, cleandoc=True)
assert data == ParsedData(
namespace=NamespaceScope(
deduction_guides=[
DeductionGuide(
result_type=Type(
typename=PQName(
segments=[
NameSpecifier(
name="Error",
specialization=TemplateSpecialization(
args=[
TemplateArgument(
arg=Type(
typename=PQName(
segments=[
NameSpecifier(name="std"),
NameSpecifier(
name="string"
),
]
)
)
)
]
),
)
]
)
),
name=PQName(segments=[NameSpecifier(name="Error")]),
parameters=[
Parameter(
type=Type(
typename=PQName(
segments=[
NameSpecifier(name="std"),
NameSpecifier(
name="basic_string_view",
specialization=TemplateSpecialization(
args=[
TemplateArgument(
arg=Type(
typename=PQName(
segments=[
NameSpecifier(
name="CharT"
)
]
)
)
),
TemplateArgument(
arg=Type(
typename=PQName(
segments=[
NameSpecifier(
name="Traits"
)
]
)
)
),
]
),
),
]
)
)
)
],
)
]
)
)