Parse C++20 requirement constraints for functions/classes

Co-authored-by: David Vo <auscompgeek@users.noreply.github.com>
This commit is contained in:
Dustin Spicuzza 2023-10-12 01:08:30 -04:00
parent 2957e70823
commit 37cd3abee9
4 changed files with 524 additions and 1 deletions

View File

@ -188,6 +188,7 @@ class PlyLexer:
"DBL_RBRACKET",
"DBL_COLON",
"DBL_AMP",
"DBL_PIPE",
"ARROW",
"SHIFT_LEFT",
] + list(keywords)
@ -473,6 +474,7 @@ class PlyLexer:
# Regex patterns for the two-character punctuation tokens. The names after
# the ``t_`` prefix match entries of the token-name list declared earlier in
# this class (presumably resolved by the lexer through that prefix -- confirm).
t_DBL_RBRACKET = r"\]\]"
t_DBL_COLON = r"::"
t_DBL_AMP = r"&&"
# Matches a literal "||"; both pipes are escaped because "|" is regex alternation.
t_DBL_PIPE = r"\|\|"
t_ARROW = r"->"
t_SHIFT_LEFT = r"<<"
# SHIFT_RIGHT introduces ambiguity

View File

@ -647,6 +647,9 @@ class CxxParser:
self._parse_friend_decl(tok, doxygen, template)
elif tok.type == "concept":
self._parse_concept(tok, doxygen, template)
elif tok.type == "requires":
template.raw_requires_pre = self._parse_requires(tok)
self._parse_declarations(self.lex.token(), doxygen, template)
else:
self._parse_declarations(tok, doxygen, template)
@ -783,6 +786,91 @@ class CxxParser:
),
)
# fmt: off
_expr_operators = {
"<", ">", "|", "%", "^", "!", "*", "-", "+", "&", "=",
"&&", "||", "<<"
}
# fmt: on
def _parse_requires(
self,
tok: LexToken,
) -> Value:
tok = self.lex.token()
rawtoks: typing.List[LexToken] = []
# The easier case -- requires requires
if tok.type == "requires":
rawtoks.append(tok)
for tt in ("(", "{"):
tok = self._next_token_must_be(tt)
rawtoks.extend(self._consume_balanced_tokens(tok))
# .. and that's it?
# this is either a parenthesized expression or a primary clause
elif tok.type == "(":
rawtoks.extend(self._consume_balanced_tokens(tok))
else:
while True:
if tok.type == "(":
rawtoks.extend(self._consume_balanced_tokens(tok))
else:
tok = self._parse_requires_segment(tok, rawtoks)
# If this is not an operator of some kind, we don't know how
# to proceed so let the next parser figure it out
if tok.value not in self._expr_operators:
break
rawtoks.append(tok)
# check once more for compound operator?
tok = self.lex.token()
if tok.value in self._expr_operators:
rawtoks.append(tok)
tok = self.lex.token()
self.lex.return_token(tok)
return self._create_value(rawtoks)
def _parse_requires_segment(
self, tok: LexToken, rawtoks: typing.List[LexToken]
) -> LexToken:
# first token could be a name or ::
if tok.type == "DBL_COLON":
rawtoks.append(tok)
tok = self.lex.token()
while True:
# This token has to be a name or some other valid name-like thing
if tok.value == "decltype":
rawtoks.append(tok)
tok = self._next_token_must_be("(")
rawtoks.extend(self._consume_balanced_tokens(tok))
elif tok.type == "NAME":
rawtoks.append(tok)
else:
# not sure what I expected, but I didn't find it
raise self._parse_error(tok)
tok = self.lex.token()
# Maybe there's a specialization
if tok.value == "<":
rawtoks.extend(self._consume_balanced_tokens(tok))
tok = self.lex.token()
# Maybe we keep trying to parse this name
if tok.type == "DBL_COLON":
tok = self.lex.token()
continue
# Let the caller decide
return tok
#
# Attributes
#
@ -1816,6 +1904,15 @@ class CxxParser:
if otok:
toks = self._consume_balanced_tokens(otok)[1:-1]
fn.noexcept = self._create_value(toks)
else:
rtok = self.lex.token_if("requires")
if rtok:
fn_template = fn.template
if fn_template is None:
raise self._parse_error(rtok)
elif isinstance(fn_template, list):
fn_template = fn_template[0]
fn_template.raw_requires_post = self._parse_requires(rtok)
if self.lex.token_if("{"):
self._discard_contents("{", "}")
@ -1876,6 +1973,13 @@ class CxxParser:
if otok:
toks = self._consume_balanced_tokens(otok)[1:-1]
method.noexcept = self._create_value(toks)
elif tok_value == "requires":
method_template = method.template
if method_template is None:
raise self._parse_error(tok)
elif isinstance(method_template, list):
method_template = method_template[0]
method_template.raw_requires_post = self._parse_requires(tok)
else:
self.lex.return_token(tok)
break

View File

@ -520,6 +520,15 @@ class TemplateDecl:
params: typing.List[TemplateParam] = field(default_factory=list)
# Currently we don't interpret requires; if that changes in the future,
# this API will change.
#: template <typename T> requires ...
raw_requires_pre: typing.Optional[Value] = None
#: template <typename T> int main() requires ...
raw_requires_post: typing.Optional[Value] = None
#: If no template, this is None. This is a TemplateDecl if there is a single
#: declaration:

View File

@ -1,9 +1,12 @@
from cxxheaderparser.simple import NamespaceScope, ParsedData, parse_string
from cxxheaderparser.simple import ClassScope, NamespaceScope, ParsedData, parse_string
from cxxheaderparser.tokfmt import Token
from cxxheaderparser.types import (
AutoSpecifier,
ClassDecl,
Concept,
Function,
FundamentalSpecifier,
MoveReference,
NameSpecifier,
PQName,
Parameter,
@ -401,3 +404,408 @@ def test_concept_nested_requirements() -> None:
]
)
)
def test_concept_requires_class() -> None:
    """A requires-clause after the template header is kept on the class template."""
    content = """
// clang-format off
template <typename T>
concept Number = std::integral<T> || std::floating_point<T>;
template <typename T>
requires Number<T>
struct WrappedNumber {};
"""
    data = parse_string(content, cleandoc=True)

    # template <typename T> requires Number<T> struct WrappedNumber {};
    class_constraint = Value(
        tokens=[Token(value=v) for v in ("Number", "<", "T", ">")]
    )
    wrapped_number = ClassScope(
        class_decl=ClassDecl(
            typename=PQName(
                segments=[NameSpecifier(name="WrappedNumber")],
                classkey="struct",
            ),
            template=TemplateDecl(
                params=[TemplateTypeParam(typekey="typename", name="T")],
                raw_requires_pre=class_constraint,
            ),
        )
    )

    # concept Number = std::integral<T> || std::floating_point<T>;
    concept_constraint = Value(
        tokens=[
            Token(value=v)
            for v in (
                "std", "::", "integral", "<", "T", ">",
                "||",
                "std", "::", "floating_point", "<", "T", ">",
            )
        ]
    )
    number = Concept(
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="typename", name="T")]
        ),
        name="Number",
        raw_constraint=concept_constraint,
    )

    expected = ParsedData(
        namespace=NamespaceScope(classes=[wrapped_number], concepts=[number])
    )
    assert data == expected
def test_requires_last_elem() -> None:
    """A trailing requires-clause on a declarator is stored as raw_requires_post."""
    content = """
template<typename T>
void f(T&&) requires Eq<T>; // can appear as the last element of a function declarator
"""
    data = parse_string(content, cleandoc=True)

    # the single unnamed ``T&&`` parameter
    rvalue_t = Parameter(
        type=MoveReference(
            moveref_to=Type(typename=PQName(segments=[NameSpecifier(name="T")]))
        )
    )
    trailing = Value(tokens=[Token(value=v) for v in ("Eq", "<", "T", ">")])

    fn = Function(
        return_type=Type(
            typename=PQName(segments=[FundamentalSpecifier(name="void")])
        ),
        name=PQName(segments=[NameSpecifier(name="f")]),
        parameters=[rvalue_t],
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="typename", name="T")],
            raw_requires_post=trailing,
        ),
    )

    expected = ParsedData(namespace=NamespaceScope(functions=[fn]))
    assert data == expected
def test_requires_first_elem1() -> None:
    """A requires-clause right after the template parameter list is raw_requires_pre."""
    content = """
template<typename T> requires Addable<T> // or right after a template parameter list
T add(T a, T b) { return a + b; }
"""
    data = parse_string(content, cleandoc=True)

    def type_t() -> Type:
        return Type(typename=PQName(segments=[NameSpecifier(name="T")]))

    leading = Value(tokens=[Token(value=v) for v in ("Addable", "<", "T", ">")])

    fn = Function(
        return_type=type_t(),
        name=PQName(segments=[NameSpecifier(name="add")]),
        parameters=[
            Parameter(type=type_t(), name="a"),
            Parameter(type=type_t(), name="b"),
        ],
        has_body=True,
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="typename", name="T")],
            raw_requires_pre=leading,
        ),
    )

    expected = ParsedData(namespace=NamespaceScope(functions=[fn]))
    assert data == expected
def test_requires_first_elem2() -> None:
    """A namespace-qualified constraint keeps every token, including ``::``."""
    content = """
template<typename T> requires std::is_arithmetic_v<T>
T add(T a, T b) { return a + b; }
"""
    data = parse_string(content, cleandoc=True)

    def type_t() -> Type:
        return Type(typename=PQName(segments=[NameSpecifier(name="T")]))

    # The raw value must reproduce ``std::is_arithmetic_v<T>``; without the
    # "::" token it would read as two unrelated names.
    leading = Value(
        tokens=[
            Token(value=v)
            for v in ("std", "::", "is_arithmetic_v", "<", "T", ">")
        ]
    )

    fn = Function(
        return_type=type_t(),
        name=PQName(segments=[NameSpecifier(name="add")]),
        parameters=[
            Parameter(type=type_t(), name="a"),
            Parameter(type=type_t(), name="b"),
        ],
        has_body=True,
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="typename", name="T")],
            raw_requires_pre=leading,
        ),
    )

    expected = ParsedData(namespace=NamespaceScope(functions=[fn]))
    assert data == expected
def test_requires_compound() -> None:
    """Two constraints joined by ``||`` are captured as a single raw value."""
    content = """
template<typename T> requires Addable<T> || Subtractable<T>
T add(T a, T b) { return a + b; }
"""
    data = parse_string(content, cleandoc=True)

    def type_t() -> Type:
        return Type(typename=PQName(segments=[NameSpecifier(name="T")]))

    leading = Value(
        tokens=[
            Token(value=v)
            for v in (
                "Addable", "<", "T", ">",
                "||",
                "Subtractable", "<", "T", ">",
            )
        ]
    )

    fn = Function(
        return_type=type_t(),
        name=PQName(segments=[NameSpecifier(name="add")]),
        parameters=[
            Parameter(type=type_t(), name="a"),
            Parameter(type=type_t(), name="b"),
        ],
        has_body=True,
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="typename", name="T")],
            raw_requires_pre=leading,
        ),
    )

    expected = ParsedData(namespace=NamespaceScope(functions=[fn]))
    assert data == expected
def test_requires_ad_hoc() -> None:
    """``requires requires (...) {...}`` keeps the inner keyword and both groups."""
    content = """
template<typename T>
requires requires (T x) { x + x; } // ad-hoc constraint, note keyword used twice
T add(T a, T b) { return a + b; }
"""
    data = parse_string(content, cleandoc=True)

    def type_t() -> Type:
        return Type(typename=PQName(segments=[NameSpecifier(name="T")]))

    # Only the second ``requires`` is part of the stored value
    leading = Value(
        tokens=[
            Token(value=v)
            for v in (
                "requires",
                "(", "T", "x", ")",
                "{", "x", "+", "x", ";", "}",
            )
        ]
    )

    fn = Function(
        return_type=type_t(),
        name=PQName(segments=[NameSpecifier(name="add")]),
        parameters=[
            Parameter(type=type_t(), name="a"),
            Parameter(type=type_t(), name="b"),
        ],
        has_body=True,
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="typename", name="T")],
            raw_requires_pre=leading,
        ),
    )

    expected = ParsedData(namespace=NamespaceScope(functions=[fn]))
    assert data == expected
def test_requires_both() -> None:
    """Leading and trailing requires-clauses are stored in separate fields."""
    content = """
// clang-format off
template<typename T>
requires Addable<T>
auto f1(T a, T b) requires Subtractable<T>;
"""
    data = parse_string(content, cleandoc=True)

    def type_t() -> Type:
        return Type(typename=PQName(segments=[NameSpecifier(name="T")]))

    leading = Value(tokens=[Token(value=v) for v in ("Addable", "<", "T", ">")])
    trailing = Value(
        tokens=[Token(value=v) for v in ("Subtractable", "<", "T", ">")]
    )

    fn = Function(
        return_type=Type(typename=PQName(segments=[AutoSpecifier()])),
        name=PQName(segments=[NameSpecifier(name="f1")]),
        parameters=[
            Parameter(type=type_t(), name="a"),
            Parameter(type=type_t(), name="b"),
        ],
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="typename", name="T")],
            raw_requires_pre=leading,
            raw_requires_post=trailing,
        ),
    )

    expected = ParsedData(namespace=NamespaceScope(functions=[fn]))
    assert data == expected
def test_requires_paren() -> None:
    """A parenthesized trailing constraint is stored with its outer parentheses."""
    content = """
// clang-format off
template<class T>
void h(T) requires (is_purrable<T>());
"""
    data = parse_string(content, cleandoc=True)

    trailing = Value(
        tokens=[
            Token(value=v)
            for v in ("(", "is_purrable", "<", "T", ">", "(", ")", ")")
        ]
    )
    # the single unnamed ``T`` parameter
    unnamed_t = Parameter(
        type=Type(typename=PQName(segments=[NameSpecifier(name="T")]))
    )

    fn = Function(
        return_type=Type(
            typename=PQName(segments=[FundamentalSpecifier(name="void")])
        ),
        name=PQName(segments=[NameSpecifier(name="h")]),
        parameters=[unnamed_t],
        template=TemplateDecl(
            params=[TemplateTypeParam(typekey="class", name="T")],
            raw_requires_post=trailing,
        ),
    )

    expected = ParsedData(namespace=NamespaceScope(functions=[fn]))
    assert data == expected