diff --git a/cxxheaderparser/gentest.py b/cxxheaderparser/gentest.py index ace24d7..a738298 100644 --- a/cxxheaderparser/gentest.py +++ b/cxxheaderparser/gentest.py @@ -63,7 +63,7 @@ def gentest( popt = "" options = ParserOptions(verbose=verbose) - if options: + if pcpp: options.preprocessor = make_pcpp_preprocessor() maybe_options = "options = ParserOptions(preprocessor=make_pcpp_preprocessor())" popt = ", options=options" diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index 0fed1c2..1fd0fa8 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -58,7 +58,7 @@ from .types import ( Value, Variable, ) -from .visitor import CxxVisitor +from .visitor import CxxVisitor, null_visitor LexTokenList = typing.List[LexToken] T = typing.TypeVar("T") @@ -114,6 +114,7 @@ class CxxParser: def _push_state(self, cls: typing.Type[ST], *args) -> ST: state = cls(self.state, *args) + state._prior_visitor = self.visitor if isinstance(state, NamespaceBlockState): self.current_namespace = state.namespace self.state = state @@ -122,6 +123,7 @@ class CxxParser: def _pop_state(self) -> State: prev_state = self.state prev_state._finish(self.visitor) + self.visitor = prev_state._prior_visitor state = prev_state.parent if state is None: raise CxxParseError("INTERNAL ERROR: unbalanced state") @@ -454,7 +456,8 @@ class CxxParser: ns = NamespaceDecl(names, inline, doxygen) state = self._push_state(NamespaceBlockState, ns) state.location = location - self.visitor.on_namespace_start(state) + if self.visitor.on_namespace_start(state) is False: + self.visitor = null_visitor def _parse_extern(self, tok: LexToken, doxygen: typing.Optional[str]) -> None: etok = self.lex.token_if("STRING_LITERAL", "template") @@ -463,7 +466,8 @@ class CxxParser: if self.lex.token_if("{"): state = self._push_state(ExternBlockState, etok.value) state.location = tok.location - self.visitor.on_extern_block_start(state) + if self.visitor.on_extern_block_start(state) is False: + self.visitor = null_visitor return # an extern variable/function with specific linkage @@ -508,7 +512,8 @@ class CxxParser: self, tok: LexToken, doxygen: typing.Optional[str] ) -> None: state = self._push_state(EmptyBlockState) - self.visitor.on_empty_block_start(state) + if self.visitor.on_empty_block_start(state) is False: + self.visitor = null_visitor def _on_block_end(self, tok: LexToken, doxygen: typing.Optional[str]) -> None: old_state = self._pop_state() @@ -1143,7 +1148,8 @@ class CxxParser: ClassBlockState, clsdecl, default_access, typedef, mods ) state.location = location - self.visitor.on_class_start(state) + if self.visitor.on_class_start(state) is False: + self.visitor = null_visitor def _finish_class_decl(self, state: ClassBlockState) -> None: self._finish_class_or_enum( diff --git a/cxxheaderparser/parserstate.py b/cxxheaderparser/parserstate.py index d2f1a95..86ed2f6 100644 --- a/cxxheaderparser/parserstate.py +++ b/cxxheaderparser/parserstate.py @@ -46,6 +46,9 @@ class State(typing.Generic[T, PT]): #: Approximate location that the parsed element was found at location: Location + #: internal detail used by parser + _prior_visitor: "CxxVisitor" + def __init__(self, parent: typing.Optional["State[PT, typing.Any]"]) -> None: self.parent = parent diff --git a/cxxheaderparser/simple.py b/cxxheaderparser/simple.py index f0ec240..a58191e 100644 --- a/cxxheaderparser/simple.py +++ b/cxxheaderparser/simple.py @@ -209,22 +209,24 @@ class SimpleCxxVisitor: def on_include(self, state: SState, filename: str) -> None: self.data.includes.append(Include(filename)) - def on_empty_block_start(self, state: SEmptyBlockState) -> None: + def on_empty_block_start(self, state: SEmptyBlockState) -> typing.Optional[bool]: # this matters for some scope/resolving purposes, but you're # probably going to want to use clang if you care about that # level of detail state.user_data = state.parent.user_data + return None def on_empty_block_end(self, state: SEmptyBlockState) -> None: pass - def on_extern_block_start(self, state: SExternBlockState) -> None: + def on_extern_block_start(self, state: SExternBlockState) -> typing.Optional[bool]: state.user_data = state.parent.user_data + return None def on_extern_block_end(self, state: SExternBlockState) -> None: pass - def on_namespace_start(self, state: SNamespaceBlockState) -> None: + def on_namespace_start(self, state: SNamespaceBlockState) -> typing.Optional[bool]: parent_ns = state.parent.user_data ns = None @@ -247,6 +249,7 @@ class SimpleCxxVisitor: ns.doxygen = state.namespace.doxygen state.user_data = ns + return None def on_namespace_end(self, state: SNamespaceBlockState) -> None: pass @@ -299,11 +302,12 @@ class SimpleCxxVisitor: # Class/union/struct # - def on_class_start(self, state: SClassBlockState) -> None: + def on_class_start(self, state: SClassBlockState) -> typing.Optional[bool]: parent = state.parent.user_data block = ClassScope(state.class_decl) parent.classes.append(block) state.user_data = block + return None def on_class_field(self, state: SClassBlockState, f: Field) -> None: state.user_data.fields.append(f) diff --git a/cxxheaderparser/visitor.py b/cxxheaderparser/visitor.py index 6088139..5c17a46 100644 --- a/cxxheaderparser/visitor.py +++ b/cxxheaderparser/visitor.py @@ -52,7 +52,7 @@ class CxxVisitor(Protocol): Called once for each ``#include`` directive encountered """ - def on_empty_block_start(self, state: EmptyBlockState) -> None: + def on_empty_block_start(self, state: EmptyBlockState) -> typing.Optional[bool]: """ Called when a ``{`` is encountered that isn't associated with or consumed by other declarations. @@ -62,6 +62,9 @@ class CxxVisitor(Protocol): { // stuff } + + If this function returns False, the visitor will not be called for any + items inside this block (including on_empty_block_end) """ def on_empty_block_end(self, state: EmptyBlockState) -> None: @@ -69,7 +72,7 @@ class CxxVisitor(Protocol): Called when an empty block ends """ - def on_extern_block_start(self, state: ExternBlockState) -> None: + def on_extern_block_start(self, state: ExternBlockState) -> typing.Optional[bool]: """ .. code-block:: c++ @@ -77,6 +80,8 @@ class CxxVisitor(Protocol): } + If this function returns False, the visitor will not be called for any + items inside this block (including on_extern_block_end) """ def on_extern_block_end(self, state: ExternBlockState) -> None: @@ -84,9 +89,12 @@ class CxxVisitor(Protocol): Called when an extern block ends """ - def on_namespace_start(self, state: NamespaceBlockState) -> None: + def on_namespace_start(self, state: NamespaceBlockState) -> typing.Optional[bool]: """ Called when a ``namespace`` directive is encountered + + If this function returns False, the visitor will not be called for any + items inside this namespace (including on_namespace_end) """ def on_namespace_end(self, state: NamespaceBlockState) -> None: @@ -186,7 +194,7 @@ class CxxVisitor(Protocol): # Class/union/struct # - def on_class_start(self, state: ClassBlockState) -> None: + def on_class_start(self, state: ClassBlockState) -> typing.Optional[bool]: """ Called when a class/struct/union is encountered @@ -199,6 +207,9 @@ class CxxVisitor(Protocol): This is called first, followed by on_typedef for each typedef instance encountered. The compound type object is passed as the type to the typedef. + + If this function returns False, the visitor will not be called for any + items inside this class (including on_class_end) """ def on_class_field(self, state: ClassBlockState, f: Field) -> None: @@ -231,3 +242,87 @@ class CxxVisitor(Protocol): Then ``on_class_start``, .. ``on_class_end`` are emitted, along with ``on_variable`` for each instance declared. """ + + +class NullVisitor: + """ + This visitor does nothing + """ + + def on_parse_start(self, state: NamespaceBlockState) -> None: + return None + + def on_pragma(self, state: State, content: Value) -> None: + return None + + def on_include(self, state: State, filename: str) -> None: + return None + + def on_empty_block_start(self, state: EmptyBlockState) -> typing.Optional[bool]: + return None + + def on_empty_block_end(self, state: EmptyBlockState) -> None: + return None + + def on_extern_block_start(self, state: ExternBlockState) -> typing.Optional[bool]: + return None + + def on_extern_block_end(self, state: ExternBlockState) -> None: + return None + + def on_namespace_start(self, state: NamespaceBlockState) -> typing.Optional[bool]: + return None + + def on_namespace_end(self, state: NamespaceBlockState) -> None: + return None + + def on_namespace_alias(self, state: State, alias: NamespaceAlias) -> None: + return None + + def on_forward_decl(self, state: State, fdecl: ForwardDecl) -> None: + return None + + def on_template_inst(self, state: State, inst: TemplateInst) -> None: + return None + + def on_variable(self, state: State, v: Variable) -> None: + return None + + def on_function(self, state: State, fn: Function) -> None: + return None + + def on_method_impl(self, state: State, method: Method) -> None: + return None + + def on_typedef(self, state: State, typedef: Typedef) -> None: + return None + + def on_using_namespace(self, state: State, namespace: typing.List[str]) -> None: + return None + + def on_using_alias(self, state: State, using: UsingAlias) -> None: + return None + + def on_using_declaration(self, state: State, using: UsingDecl) -> None: + return None + + def on_enum(self, state: State, enum: EnumDecl) -> None: + return None + + def on_class_start(self, state: ClassBlockState) -> typing.Optional[bool]: + return None + + def on_class_field(self, state: ClassBlockState, f: Field) -> None: + return None + + def on_class_friend(self, state: ClassBlockState, friend: FriendDecl) -> None: + return None + + def on_class_method(self, state: ClassBlockState, method: Method) -> None: + return None + + def on_class_end(self, state: ClassBlockState) -> None: + return None + + +null_visitor = NullVisitor() diff --git a/tests/test_skip.py b/tests/test_skip.py new file mode 100644 index 0000000..2415e03 --- /dev/null +++ b/tests/test_skip.py @@ -0,0 +1,269 @@ +# Note: testcases generated via `python -m cxxheaderparser.gentest` +# .. and modified + +import inspect +import typing + +from cxxheaderparser.parser import CxxParser +from cxxheaderparser.simple import ( + ClassScope, + NamespaceScope, + ParsedData, + SClassBlockState, + SEmptyBlockState, + SExternBlockState, + SNamespaceBlockState, + SimpleCxxVisitor, +) + +from cxxheaderparser.types import ( + ClassDecl, + Function, + FundamentalSpecifier, + Method, + NameSpecifier, + PQName, + Type, +) + +# +# ensure extern block is skipped +# + + +class SkipExtern(SimpleCxxVisitor): + def on_extern_block_start(self, state: SExternBlockState) -> typing.Optional[bool]: + return False + + +def test_skip_extern(): + content = """ + void fn1(); + + extern "C" { + void fn2(); + } + + void fn3(); + """ + + v = SkipExtern() + parser = CxxParser("", inspect.cleandoc(content), v) + parser.parse() + data = v.data + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn1")]), + parameters=[], + ), + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn3")]), + parameters=[], + ), + ] + ) + ) + + +# +# ensure class block is skipped +# + + +class SkipClass(SimpleCxxVisitor): + def on_class_start(self, state: SClassBlockState) -> typing.Optional[bool]: + if getattr(state.class_decl.typename.segments[0], "name", None) == "Skip": + return False + return super().on_class_start(state) + + +def test_skip_class() -> None: + content = """ + void fn1(); + + class Skip { + void fn2(); + }; + + class Yup { + void fn3(); + }; + + void fn5(); + """ + v = SkipClass() + parser = CxxParser("", inspect.cleandoc(content), v) + parser.parse() + data = v.data + + assert data == ParsedData( + namespace=NamespaceScope( + classes=[ + ClassScope( + class_decl=ClassDecl( + typename=PQName( + segments=[NameSpecifier(name="Yup")], classkey="class" + ) + ), + methods=[ + Method( + return_type=Type( + typename=PQName( + segments=[FundamentalSpecifier(name="void")] + ) + ), + name=PQName(segments=[NameSpecifier(name="fn3")]), + parameters=[], + access="private", + ) + ], + ), + ], + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn1")]), + parameters=[], + ), + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn5")]), + parameters=[], + ), + ], + ) + ) + + +# +# ensure empty block is skipped +# + + +class SkipEmptyBlock(SimpleCxxVisitor): + def on_empty_block_start(self, state: SEmptyBlockState) -> typing.Optional[bool]: + return False + + +def test_skip_empty_block() -> None: + content = """ + void fn1(); + + { + void fn2(); + } + + void fn3(); + """ + v = SkipEmptyBlock() + parser = CxxParser("", inspect.cleandoc(content), v) + parser.parse() + data = v.data + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn1")]), + parameters=[], + ), + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn3")]), + parameters=[], + ), + ] + ) + ) + + +# +# ensure namespace 'skip' is skipped +# + + +class SkipNamespace(SimpleCxxVisitor): + def on_namespace_start(self, state: SNamespaceBlockState) -> typing.Optional[bool]: + if "skip" in state.namespace.names[0]: + return False + + return super().on_namespace_start(state) + + +def test_skip_namespace(): + content = """ + void fn1(); + + namespace skip { + void fn2(); + + namespace thistoo { + void fn3(); + } + } + + namespace ok { + void fn4(); + } + + void fn5(); + """ + v = SkipNamespace() + parser = CxxParser("", inspect.cleandoc(content), v) + parser.parse() + data = v.data + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn1")]), + parameters=[], + ), + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="void")]) + ), + name=PQName(segments=[NameSpecifier(name="fn5")]), + parameters=[], + ), + ], + namespaces={ + "ok": NamespaceScope( + name="ok", + functions=[ + Function( + return_type=Type( + typename=PQName( + segments=[FundamentalSpecifier(name="void")] + ) + ), + name=PQName(segments=[NameSpecifier(name="fn4")]), + parameters=[], + ) + ], + ), + }, + ) + )