From 37cd3abee92317501489e9c1ed83759327dbcd6c Mon Sep 17 00:00:00 2001 From: Dustin Spicuzza Date: Thu, 12 Oct 2023 01:08:30 -0400 Subject: [PATCH] Parse C++20 requirement constraints for functions/classes Co-authored-by: David Vo --- cxxheaderparser/lexer.py | 2 + cxxheaderparser/parser.py | 104 ++++++++++ cxxheaderparser/types.py | 9 + tests/test_concepts.py | 410 +++++++++++++++++++++++++++++++++++++- 4 files changed, 524 insertions(+), 1 deletion(-) diff --git a/cxxheaderparser/lexer.py b/cxxheaderparser/lexer.py index 38c0b59..15dec34 100644 --- a/cxxheaderparser/lexer.py +++ b/cxxheaderparser/lexer.py @@ -188,6 +188,7 @@ class PlyLexer: "DBL_RBRACKET", "DBL_COLON", "DBL_AMP", + "DBL_PIPE", "ARROW", "SHIFT_LEFT", ] + list(keywords) @@ -473,6 +474,7 @@ class PlyLexer: t_DBL_RBRACKET = r"\]\]" t_DBL_COLON = r"::" t_DBL_AMP = r"&&" + t_DBL_PIPE = r"\|\|" t_ARROW = r"->" t_SHIFT_LEFT = r"<<" # SHIFT_RIGHT introduces ambiguity diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index 24b940f..a4217a2 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -647,6 +647,9 @@ class CxxParser: self._parse_friend_decl(tok, doxygen, template) elif tok.type == "concept": self._parse_concept(tok, doxygen, template) + elif tok.type == "requires": + template.raw_requires_pre = self._parse_requires(tok) + self._parse_declarations(self.lex.token(), doxygen, template) else: self._parse_declarations(tok, doxygen, template) @@ -783,6 +786,91 @@ class CxxParser: ), ) + # fmt: off + _expr_operators = { + "<", ">", "|", "%", "^", "!", "*", "-", "+", "&", "=", + "&&", "||", "<<" + } + # fmt: on + + def _parse_requires( + self, + tok: LexToken, + ) -> Value: + tok = self.lex.token() + + rawtoks: typing.List[LexToken] = [] + + # The easier case -- requires requires + if tok.type == "requires": + rawtoks.append(tok) + for tt in ("(", "{"): + tok = self._next_token_must_be(tt) + rawtoks.extend(self._consume_balanced_tokens(tok)) + # .. and that's it? + + # this is either a parenthesized expression or a primary clause + elif tok.type == "(": + rawtoks.extend(self._consume_balanced_tokens(tok)) + else: + while True: + if tok.type == "(": + rawtoks.extend(self._consume_balanced_tokens(tok)) + else: + tok = self._parse_requires_segment(tok, rawtoks) + + # If this is not an operator of some kind, we don't know how + # to proceed so let the next parser figure it out + if tok.value not in self._expr_operators: + break + + rawtoks.append(tok) + + # check once more for compound operator? + tok = self.lex.token() + if tok.value in self._expr_operators: + rawtoks.append(tok) + tok = self.lex.token() + + self.lex.return_token(tok) + + return self._create_value(rawtoks) + + def _parse_requires_segment( + self, tok: LexToken, rawtoks: typing.List[LexToken] + ) -> LexToken: + # first token could be a name or :: + if tok.type == "DBL_COLON": + rawtoks.append(tok) + tok = self.lex.token() + + while True: + # This token has to be a name or some other valid name-like thing + if tok.value == "decltype": + rawtoks.append(tok) + tok = self._next_token_must_be("(") + rawtoks.extend(self._consume_balanced_tokens(tok)) + elif tok.type == "NAME": + rawtoks.append(tok) + else: + # not sure what I expected, but I didn't find it + raise self._parse_error(tok) + + tok = self.lex.token() + + # Maybe there's a specialization + if tok.value == "<": + rawtoks.extend(self._consume_balanced_tokens(tok)) + tok = self.lex.token() + + # Maybe we keep trying to parse this name + if tok.type == "DBL_COLON": + tok = self.lex.token() + continue + + # Let the caller decide + return tok + # # Attributes # @@ -1816,6 +1904,15 @@ class CxxParser: if otok: toks = self._consume_balanced_tokens(otok)[1:-1] fn.noexcept = self._create_value(toks) + else: + rtok = self.lex.token_if("requires") + if rtok: + fn_template = fn.template + if fn_template is None: + raise self._parse_error(rtok) + elif isinstance(fn_template, list): + fn_template = fn_template[0] + fn_template.raw_requires_post = self._parse_requires(rtok) if self.lex.token_if("{"): self._discard_contents("{", "}") @@ -1876,6 +1973,13 @@ class CxxParser: if otok: toks = self._consume_balanced_tokens(otok)[1:-1] method.noexcept = self._create_value(toks) + elif tok_value == "requires": + method_template = method.template + if method_template is None: + raise self._parse_error(tok) + elif isinstance(method_template, list): + method_template = method_template[0] + method_template.raw_requires_post = self._parse_requires(tok) else: self.lex.return_token(tok) break diff --git a/cxxheaderparser/types.py b/cxxheaderparser/types.py index 674398e..ebc15f9 100644 --- a/cxxheaderparser/types.py +++ b/cxxheaderparser/types.py @@ -520,6 +520,15 @@ class TemplateDecl: params: typing.List[TemplateParam] = field(default_factory=list) + # Currently don't interpret requires, if that changes in the future + # then this API will change. + + #: template requires ... + raw_requires_pre: typing.Optional[Value] = None + + #: template int main() requires ... + raw_requires_post: typing.Optional[Value] = None + #: If no template, this is None. This is a TemplateDecl if this there is a single #: declaration: diff --git a/tests/test_concepts.py b/tests/test_concepts.py index ae29aec..2be5a8c 100644 --- a/tests/test_concepts.py +++ b/tests/test_concepts.py @@ -1,9 +1,12 @@ -from cxxheaderparser.simple import NamespaceScope, ParsedData, parse_string +from cxxheaderparser.simple import ClassScope, NamespaceScope, ParsedData, parse_string from cxxheaderparser.tokfmt import Token from cxxheaderparser.types import ( + AutoSpecifier, + ClassDecl, Concept, Function, FundamentalSpecifier, + MoveReference, NameSpecifier, PQName, Parameter, @@ -401,3 +404,408 @@ def test_concept_nested_requirements() -> None: ] ) ) + + +def test_concept_requires_class() -> None: + content = """ + // clang-format off + template + concept Number = std::integral || std::floating_point; + + template + requires Number + struct WrappedNumber {}; + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + classes=[ + ClassScope( + class_decl=ClassDecl( + typename=PQName( + segments=[NameSpecifier(name="WrappedNumber")], + classkey="struct", + ), + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")], + raw_requires_pre=Value( + tokens=[ + Token(value="Number"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + ), + ) + ) + ], + concepts=[ + Concept( + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")] + ), + name="Number", + raw_constraint=Value( + tokens=[ + Token(value="std"), + Token(value="::"), + Token(value="integral"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + Token(value="||"), + Token(value="std"), + Token(value="::"), + Token(value="floating_point"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + ) + ], + ) + ) + + +def test_requires_last_elem() -> None: + content = """ + template + void f(T&&) requires Eq; // can appear as the last element of a function declarator + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="f")]), + parameters=[ + Parameter( + type=MoveReference( + moveref_to=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ) + ) + ) + ], + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")], + raw_requires_post=Value( + tokens=[ + Token(value="Eq"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + ), + ) + ] + ) + ) + + +def test_requires_first_elem1() -> None: + content = """ + template requires Addable // or right after a template parameter list + T add(T a, T b) { return a + b; } + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name=PQName(segments=[NameSpecifier(name="add")]), + parameters=[ + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="a", + ), + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="b", + ), + ], + has_body=True, + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")], + raw_requires_pre=Value( + tokens=[ + Token(value="Addable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + ), + ) + ] + ) + ) + + +def test_requires_first_elem2() -> None: + content = """ + template requires std::is_arithmetic_v + T add(T a, T b) { return a + b; } + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name=PQName(segments=[NameSpecifier(name="add")]), + parameters=[ + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="a", + ), + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="b", + ), + ], + has_body=True, + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")], + raw_requires_pre=Value( + tokens=[ + Token(value="std"), + Token(value="is_arithmetic_v"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + ), + ) + ] + ) + ) + + +def test_requires_compound() -> None: + content = """ + template requires Addable || Subtractable + T add(T a, T b) { return a + b; } + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name=PQName(segments=[NameSpecifier(name="add")]), + parameters=[ + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="a", + ), + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="b", + ), + ], + has_body=True, + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")], + raw_requires_pre=Value( + tokens=[ + Token(value="Addable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + Token(value="||"), + Token(value="Subtractable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + ), + ) + ] + ) + ) + + +def test_requires_ad_hoc() -> None: + content = """ + template + requires requires (T x) { x + x; } // ad-hoc constraint, note keyword used twice + T add(T a, T b) { return a + b; } + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name=PQName(segments=[NameSpecifier(name="add")]), + parameters=[ + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="a", + ), + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="b", + ), + ], + has_body=True, + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")], + raw_requires_pre=Value( + tokens=[ + Token(value="requires"), + Token(value="("), + Token(value="T"), + Token(value="x"), + Token(value=")"), + Token(value="{"), + Token(value="x"), + Token(value="+"), + Token(value="x"), + Token(value=";"), + Token(value="}"), + ] + ), + ), + ) + ] + ) + ) + + +def test_requires_both() -> None: + content = """ + // clang-format off + template + requires Addable + auto f1(T a, T b) requires Subtractable; + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type(typename=PQName(segments=[AutoSpecifier()])), + name=PQName(segments=[NameSpecifier(name="f1")]), + parameters=[ + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="a", + ), + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ), + name="b", + ), + ], + template=TemplateDecl( + params=[TemplateTypeParam(typekey="typename", name="T")], + raw_requires_pre=Value( + tokens=[ + Token(value="Addable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + raw_requires_post=Value( + tokens=[ + Token(value="Subtractable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] + ), + ), + ) + ] + ) + ) + + +def test_requires_paren() -> None: + content = """ + // clang-format off + template + void h(T) requires (is_purrable()); + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="h")]), + parameters=[ + Parameter( + type=Type( + typename=PQName(segments=[NameSpecifier(name="T")]) + ) + ) + ], + template=TemplateDecl( + params=[TemplateTypeParam(typekey="class", name="T")], + raw_requires_post=Value( + tokens=[ + Token(value="("), + Token(value="is_purrable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + Token(value="("), + Token(value=")"), + Token(value=")"), + ] + ), + ), + ) + ] + ) + )