diff --git a/cxxheaderparser/lexer.py b/cxxheaderparser/lexer.py index 3c58593..e700331 100644 --- a/cxxheaderparser/lexer.py +++ b/cxxheaderparser/lexer.py @@ -8,7 +8,7 @@ from ._ply import lex if sys.version_info >= (3, 8): - Protocol = typing.Protocol + from typing import Protocol else: Protocol = object @@ -43,7 +43,7 @@ class LexToken(Protocol): location: Location -PhonyEnding = lex.LexToken() +PhonyEnding: LexToken = lex.LexToken() # type: ignore PhonyEnding.type = "PLACEHOLDER" PhonyEnding.value = "" PhonyEnding.lineno = 0 @@ -278,7 +278,7 @@ class Lexer: self.lookahead = typing.Deque[LexToken]() # For 'set_group_of_tokens' support - self._get_token = self.lex.token + self._get_token: typing.Callable[[], LexToken] = self.lex.token self.lookahead_stack = typing.Deque[typing.Deque[LexToken]]() def current_location(self) -> Location: @@ -462,7 +462,7 @@ class Lexer: def return_token(self, tok: LexToken) -> None: self.lookahead.appendleft(tok) - def return_tokens(self, toks: typing.Iterable[LexToken]) -> None: + def return_tokens(self, toks: typing.Sequence[LexToken]) -> None: self.lookahead.extendleft(reversed(toks)) diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index 11f86ab..45bf168 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -8,7 +8,6 @@ from .errors import CxxParseError from .lexer import Lexer, LexToken, Location, PhonyEnding from .options import ParserOptions from .parserstate import ( - BlockState, ClassBlockState, EmptyBlockState, ExternBlockState, @@ -45,6 +44,7 @@ from .types import ( TemplateArgument, TemplateDecl, TemplateNonTypeParam, + TemplateParam, TemplateSpecialization, TemplateTypeParam, Token, @@ -60,6 +60,9 @@ from .visitor import CxxVisitor LexTokenList = typing.List[LexToken] T = typing.TypeVar("T") +ST = typing.TypeVar("ST", bound=State) +PT = typing.TypeVar("PT", Parameter, TemplateNonTypeParam) + class CxxParser: """ @@ -82,7 +85,7 @@ class CxxParser: global_ns = NamespaceDecl([], False) self.current_namespace = global_ns - self.state: BlockState = NamespaceBlockState(None, global_ns) + self.state: State = NamespaceBlockState(None, global_ns) self.anon_id = 0 self.options = options if options else ParserOptions() @@ -103,7 +106,7 @@ class CxxParser: # State management # - def _push_state(self, cls: typing.Type[T], *args) -> T: + def _push_state(self, cls: typing.Type[ST], *args) -> ST: state = cls(self.state, *args) if isinstance(state, NamespaceBlockState): self.current_namespace = state.namespace @@ -133,7 +136,7 @@ class CxxParser: def _parse_error( self, tok: typing.Optional[LexToken], expected="" - ) -> typing.NoReturn: + ) -> CxxParseError: if not tok: # common case after a failed token_if tok = self.lex.token() @@ -270,7 +273,9 @@ class CxxParser: """ # non-ambiguous parsing functions for each token type - _translation_unit_tokens = { + _translation_unit_tokens: typing.Dict[ + str, typing.Callable[[LexToken, typing.Optional[str]], typing.Any] + ] = { "__attribute__": self._consume_gcc_attribute, "__declspec": self._consume_declspec, "alignas": self._consume_attribute_specifier_seq, @@ -339,16 +344,16 @@ class CxxParser: self, tok: LexToken, doxygen: typing.Optional[str] ) -> None: value = self._preprocessor_compress_re.sub("#", tok.value) - value = self._preprocessor_split_re.split(value, 1) - if len(value) == 2: + svalue = self._preprocessor_split_re.split(value, 1) + if len(svalue) == 2: self.state.location = tok.location - macro = value[0].lower().replace(" ", "") + macro = svalue[0].lower().replace(" ", "") if macro.startswith("#include"): - self.visitor.on_include(self.state, value[1]) + self.visitor.on_include(self.state, svalue[1]) elif macro.startswith("#define"): - self.visitor.on_define(self.state, value[1]) + self.visitor.on_define(self.state, svalue[1]) elif macro.startswith("#pragma"): - self.visitor.on_pragma(self.state, value[1]) + self.visitor.on_pragma(self.state, svalue[1]) # # Various @@ -453,7 +458,7 @@ class CxxParser: # def _parse_template_type_parameter( - self, tok: LexToken, template: TemplateDecl + self, tok: LexToken, template: typing.Optional[TemplateDecl] ) -> TemplateTypeParam: """ type_parameter: "class" ["..."] [IDENTIFIER] @@ -469,12 +474,12 @@ class CxxParser: name = None default = None - tok = self.lex.token_if("NAME") - if tok: - name = tok.value + otok = self.lex.token_if("NAME") + if otok: + name = otok.value - tok = self.lex.token_if("=") - if tok: + otok = self.lex.token_if("=") + if otok: default = self._create_value(self._consume_value_until([], ",", ">")) return TemplateTypeParam(typekey, name, param_pack, default, template) @@ -492,7 +497,7 @@ class CxxParser: | parameter_declaration """ tok = self._next_token_must_be("<") - params = [] + params: typing.List[TemplateParam] = [] lex = self.lex @@ -502,6 +507,8 @@ class CxxParser: tok = lex.token() tok_type = tok.type + param: TemplateParam + if tok_type == "template": template = self._parse_template_decl() tok = self._next_token_must_be("class", "typename") @@ -528,9 +535,7 @@ class CxxParser: return TemplateDecl(params) - def _parse_template( - self, tok: LexToken, doxygen: typing.Optional[str] - ) -> TemplateDecl: + def _parse_template(self, tok: LexToken, doxygen: typing.Optional[str]): template = self._parse_template_decl() @@ -908,7 +913,7 @@ class CxxParser: # might start with attributes if tok.type in self._attribute_specifier_seq_start_types: - self._parse_attribute_specifier_seq(tok) + self._consume_attribute_specifier_seq(tok) tok = self.lex.token() tok_type = tok.type @@ -1155,6 +1160,8 @@ class CxxParser: doxygen = self.lex.get_doxygen() if is_typedef: + if not name: + raise self._parse_error(None) typedef = Typedef(dtype, name, self._current_access) self.visitor.on_typedef(state, typedef) else: @@ -1423,8 +1430,8 @@ class CxxParser: # def _parse_parameter( - self, tok: typing.Optional[LexToken], cls: typing.Type[T], end: str = ")" - ) -> T: + self, tok: typing.Optional[LexToken], cls: typing.Type[PT], end: str = ")" + ) -> PT: """ Parses a single parameter (excluding vararg parameters). Also used to parse template non-type parameters @@ -1516,13 +1523,13 @@ class CxxParser: if self.lex.token_if("throw"): tok = self._next_token_must_be("(") - fn.throw = self._create_value(self._consume_balanced_items(tok)) + fn.throw = self._create_value(self._consume_balanced_tokens(tok)) elif self.lex.token_if("noexcept"): toks = [] - tok = self.lex.token_if("(") - if tok: - toks = self._consume_balanced_tokens(tok)[1:-1] + otok = self.lex.token_if("(") + if otok: + toks = self._consume_balanced_tokens(otok)[1:-1] fn.noexcept = self._create_value(toks) if self.lex.token_if("{"): @@ -1575,9 +1582,9 @@ class CxxParser: method.throw = self._create_value(self._consume_balanced_tokens(tok)) elif tok_value == "noexcept": toks = [] - tok = self.lex.token_if("(") - if tok: - toks = self._consume_balanced_tokens(tok)[1:-1] + otok = self.lex.token_if("(") + if otok: + toks = self._consume_balanced_tokens(otok)[1:-1] method.noexcept = self._create_value(toks) else: self.lex.return_token(tok) @@ -1620,6 +1627,8 @@ class CxxParser: if is_class_block: props.update(dict.fromkeys(mods.meths.keys(), True)) + method: Method + if op: method = Operator( return_type, @@ -1684,18 +1693,18 @@ class CxxParser: assert tok.type == "[" toks = self._consume_balanced_tokens(tok) - tok = self.lex.token_if("[") - if tok: + otok = self.lex.token_if("[") + if otok: # recurses because array types are right to left - dtype = self._parse_array_type(tok, dtype) + dtype = self._parse_array_type(otok, dtype) toks = toks[1:-1] - if toks: - value = self._create_value(toks) - else: - value = None + size = None - return Array(dtype, value) + if toks: + size = self._create_value(toks) + + return Array(dtype, size) def _parse_cv_ptr( self, @@ -1805,7 +1814,7 @@ class CxxParser: if not tok: tok = get_token() - pqname: PQName = None + pqname: typing.Optional[PQName] = None _pqname_start_tokens = self._pqname_start_tokens @@ -1837,12 +1846,13 @@ class CxxParser: elif tok_type == "volatile": volatile = True else: - if pqname is None: - raise self._parse_error(tok) break tok = get_token() + if pqname is None: + raise self._parse_error(tok) + self.lex.return_token(tok) # Construct a type from the parsed name @@ -1931,6 +1941,9 @@ class CxxParser: if is_typedef: raise self._parse_error(None) + if not pqname: + raise self._parse_error(None) + return self._parse_function( mods, dtype, @@ -2051,7 +2064,9 @@ class CxxParser: # enum cannot be forward declared, but "enum class" can # -> but `friend enum X` is fine - if classkey == "enum" and not is_friend: + if not classkey: + raise self._parse_error(None) + elif classkey == "enum" and not is_friend: raise self._parse_error(None) elif template and classkey.startswith("enum"): # enum class cannot have a template diff --git a/cxxheaderparser/parserstate.py b/cxxheaderparser/parserstate.py index 46abe33..0f0722b 100644 --- a/cxxheaderparser/parserstate.py +++ b/cxxheaderparser/parserstate.py @@ -33,6 +33,9 @@ class State: #: parent state parent: typing.Optional["State"] + #: Approximate location that the parsed element was found at + location: Location + def __init__(self, parent: typing.Optional["State"]) -> None: self.parent = parent @@ -40,18 +43,12 @@ class State: pass -class BlockState(State): - - #: Approximate location that the parsed element was found at - location: Location - - -class EmptyBlockState(BlockState): +class EmptyBlockState(State): def _finish(self, visitor: "CxxVisitor") -> None: visitor.on_empty_block_end(self) -class ExternBlockState(BlockState): +class ExternBlockState(State): #: The linkage for this extern block linkage: str @@ -64,7 +61,7 @@ class ExternBlockState(BlockState): visitor.on_extern_block_end(self) -class NamespaceBlockState(BlockState): +class NamespaceBlockState(State): #: The incremental namespace for this block namespace: NamespaceDecl @@ -79,7 +76,7 @@ class NamespaceBlockState(BlockState): visitor.on_namespace_end(self) -class ClassBlockState(BlockState): +class ClassBlockState(State): #: class decl block being processed class_decl: ClassDecl diff --git a/cxxheaderparser/simple.py b/cxxheaderparser/simple.py index 67d35e0..5e84e8b 100644 --- a/cxxheaderparser/simple.py +++ b/cxxheaderparser/simple.py @@ -98,7 +98,7 @@ class NamespaceScope: forward_decls: typing.List[ForwardDecl] = field(default_factory=list) using: typing.List[UsingDecl] = field(default_factory=list) - using_ns: typing.List[UsingDecl] = field(default_factory=list) + using_ns: typing.List["UsingNamespace"] = field(default_factory=list) using_alias: typing.List[UsingAlias] = field(default_factory=list) #: Child namespaces diff --git a/cxxheaderparser/tokfmt.py b/cxxheaderparser/tokfmt.py index 2072f48..ada6790 100644 --- a/cxxheaderparser/tokfmt.py +++ b/cxxheaderparser/tokfmt.py @@ -1,6 +1,6 @@ import typing -from .lexer import Lexer +from .lexer import LexToken, Lexer from .types import Token # key: token type, value: (left spacing, right spacing) @@ -58,7 +58,7 @@ if __name__ == "__main__": with open(lexer.filename) as fp: lexer.input(fp.read()) - toks = [] + toks: typing.List[Token] = [] while True: tok = lexer.token_eof_ok() if not tok: @@ -68,7 +68,7 @@ if __name__ == "__main__": print(tokfmt(toks)) toks = [] else: - toks.append(tok) + toks.append(Token(tok.value, tok.type)) print(toks) print(tokfmt(toks)) diff --git a/cxxheaderparser/types.py b/cxxheaderparser/types.py index d834c19..21bfa5e 100644 --- a/cxxheaderparser/types.py +++ b/cxxheaderparser/types.py @@ -95,7 +95,7 @@ class NameSpecifier: name: str - specialization: typing.Optional[typing.List["TemplateSpecialization"]] = None + specialization: typing.Optional["TemplateSpecialization"] = None @dataclass @@ -233,16 +233,15 @@ class FunctionType: @dataclass class Type: - """""" + """ + A type with a name associated with it + """ typename: PQName const: bool = False volatile: bool = False - def get_type(self) -> "Type": - return self - @dataclass class Array: @@ -262,9 +261,6 @@ class Array: #: ~~ size: typing.Optional[Value] - def get_type(self) -> Type: - return self.array_of.get_type() - @dataclass class Pointer: @@ -278,9 +274,6 @@ class Pointer: const: bool = False volatile: bool = False - def get_type(self) -> Type: - return self.ptr_to.get_type() - @dataclass class Reference: @@ -290,9 +283,6 @@ class Reference: ref_to: typing.Union[Array, Pointer, Type] - def get_type(self) -> Type: - return self.ref_to.get_type() - @dataclass class MoveReference: @@ -302,9 +292,6 @@ class MoveReference: moveref_to: typing.Union[Array, Pointer, Type] - def get_type(self) -> Type: - return self.moveref_to.get_type() - #: A type or function type that is decorated with various things #: @@ -443,7 +430,7 @@ class ClassDecl: access: typing.Optional[str] = None @property - def classkey(self) -> str: + def classkey(self) -> typing.Optional[str]: return self.typename.classkey diff --git a/cxxheaderparser/visitor.py b/cxxheaderparser/visitor.py index f9243ab..8afeb45 100644 --- a/cxxheaderparser/visitor.py +++ b/cxxheaderparser/visitor.py @@ -2,10 +2,11 @@ import sys import typing if sys.version_info >= (3, 8): - Protocol = typing.Protocol + from typing import Protocol else: Protocol = object + from .types import ( EnumDecl, Field,