diff --git a/src/docstub-stubs/_analysis.pyi b/src/docstub-stubs/_analysis.pyi index 4e439d9..34a3cec 100644 --- a/src/docstub-stubs/_analysis.pyi +++ b/src/docstub-stubs/_analysis.pyi @@ -14,6 +14,7 @@ from typing import Any, ClassVar import libcst as cst import libcst.matchers as cstm +from ._report import Stats from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum logger: logging.Logger @@ -83,6 +84,7 @@ class TypeMatcher: types: dict[str, PyImport] | None = ..., type_prefixes: dict[str, PyImport] | None = ..., type_nicknames: dict[str, str] | None = ..., + stats: Stats | None = ..., ) -> None: ... def _resolve_nickname(self, name: str) -> str: ... def match(self, search: str) -> tuple[str | None, PyImport | None]: ... diff --git a/src/docstub-stubs/_app_generate_stubs.pyi b/src/docstub-stubs/_app_generate_stubs.pyi index 3839bac..ae4c95b 100644 --- a/src/docstub-stubs/_app_generate_stubs.pyi +++ b/src/docstub-stubs/_app_generate_stubs.pyi @@ -18,9 +18,8 @@ from ._path_utils import ( walk_source_and_targets, walk_source_package, ) -from ._report import setup_logging +from ._report import Stats, setup_logging from ._stubs import Py2StubTransformer, try_format_stub -from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger diff --git a/src/docstub-stubs/_docstrings.pyi b/src/docstub-stubs/_docstrings.pyi index cfcedf5..f116ada 100644 --- a/src/docstub-stubs/_docstrings.pyi +++ b/src/docstub-stubs/_docstrings.pyi @@ -2,11 +2,10 @@ import logging import traceback +import warnings from collections.abc import Generator, Iterable from dataclasses import dataclass, field from functools import cached_property -from pathlib import Path -from typing import Any, ClassVar import click import lark @@ -14,20 +13,15 @@ import lark.visitors import numpydoc.docscrape as npds from ._analysis import PyImport, TypeMatcher -from ._report import ContextReporter -from ._utils import DocstubError, escape_qualname +from ._doctype import BlacklistedQualname, Expr, Term, TermKind, parse_doctype +from ._report import ContextReporter, Stats +from ._utils import escape_qualname logger: logging.Logger -here: Path -grammar_path: Path - -with grammar_path.open() as file: - _grammar: str - -_lark: lark.Lark - -def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: ... +def _update_qualnames( + expr: Expr, *, _parents: tuple[Expr, ...] = ... +) -> Generator[tuple[tuple[Expr, ...], Term], str]: ... @dataclass(frozen=True, slots=True, kw_only=True) class Annotation: @@ -54,57 +48,31 @@ class Annotation: FallbackAnnotation: Annotation -class QualnameIsKeyword(DocstubError): - pass - -class DoctypeTransformer(lark.visitors.Transformer): - matcher: TypeMatcher - stats: dict[str, Any] - - blacklisted_qualnames: ClassVar[frozenset[str]] - - def __init__( - self, *, matcher: TypeMatcher | None = ..., **kwargs: dict[Any, Any] - ) -> None: ... - def doctype_to_annotation( - self, doctype: str, *, reporter: ContextReporter | None = ... - ) -> tuple[Annotation, list[tuple[str, int, int]]]: ... - def qualname(self, tree: lark.Tree) -> lark.Token: ... - def rst_role(self, tree: lark.Tree) -> lark.Token: ... - def union(self, tree: lark.Tree) -> str: ... - def subscription(self, tree: lark.Tree) -> str: ... - def param_spec(self, tree: lark.Tree) -> str: ... - def callable(self, tree: lark.Tree) -> str: ... - def natlang_literal(self, tree: lark.Tree) -> str: ... - def natlang_container(self, tree: lark.Tree) -> str: ... 
- def natlang_array(self, tree: lark.Tree) -> str: ... - def array_name(self, tree: lark.Tree) -> lark.Token: ... - def shape(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... - def optional_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... - def __default__( - self, data: lark.Token, children: list[lark.Token], meta: lark.tree.Meta - ) -> lark.Token | list[lark.Token]: ... - def _match_import(self, qualname: str, *, meta: lark.tree.Meta) -> str: ... - def _uncombine_numpydoc_params( params: list[npds.Parameter], ) -> Generator[npds.Parameter]: ... +def _red_partial_underline(doctype: str, *, start: int, stop: int) -> str: ... +def doctype_to_annotation( + doctype: str, + *, + matcher: TypeMatcher | None = ..., + reporter: ContextReporter | None = ..., + stats: Stats | None = ..., +) -> Annotation: ... class DocstringAnnotations: docstring: str - transformer: DoctypeTransformer + matcher: TypeMatcher reporter: ContextReporter def __init__( self, docstring: str, *, - transformer: DoctypeTransformer, + matcher: TypeMatcher | None = ..., reporter: ContextReporter | None = ..., + stats: Stats | None = ..., ) -> None: ... - def _doctype_to_annotation( - self, doctype: str, ds_line: int = ... - ) -> Annotation: ... @cached_property def attributes(self) -> dict[str, Annotation]: ... @cached_property diff --git a/src/docstub-stubs/_doctype.pyi b/src/docstub-stubs/_doctype.pyi new file mode 100644 index 0000000..97e462a --- /dev/null +++ b/src/docstub-stubs/_doctype.pyi @@ -0,0 +1,95 @@ +# File generated with docstub + +import enum +import keyword +import logging +from collections.abc import Generator, Iterable, Sequence +from dataclasses import dataclass +from pathlib import Path +from textwrap import indent +from typing import Any, Final, Self + +import lark +import lark.visitors +from _typeshed import Incomplete + +from ._report import ContextReporter +from ._utils import DocstubError + +logger: Final[logging.Logger] + +grammar_path: Final[Path] + +with grammar_path.open() as file: + _grammar: Final[str] + +_lark: Final[lark.Lark] + +def flatten_recursive(iterable: Iterable[Iterable | str]) -> Generator[str]: ... +def insert_between(iterable: Iterable, *, sep: Any) -> list[Any]: ... + +class TermKind(enum.StrEnum): + + NAME = enum.auto() + LITERAL = enum.auto() + SYNTAX = enum.auto() + +class Term(str): + kind: TermKind + pos: tuple[int, int] | None + + __slots__: Final[tuple[str, ...]] + + def __new__( + cls, value: str, *, kind: TermKind | str, pos: tuple[int, int] | None = ... + ) -> Self: ... + def __repr__(self) -> str: ... + def __getnewargs_ex__(self) -> tuple[tuple[Any, ...], dict[str, Any]]: ... + +@dataclass(slots=True) +class Expr: + + rule: str + children: list[Self | Term] + + @property + def terms(self) -> list[Term]: ... + @property + def names(self) -> list[Term]: ... + @property + def sub_expressions(self) -> list[Self]: ... + def __iter__(self) -> Generator[Expr | Term]: ... + def format_tree(self) -> str: ... + def print_tree(self) -> None: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def as_code(self) -> str: ... + +BLACKLISTED_QUALNAMES: Final[set[str]] + +class BlacklistedQualname(DocstubError): + pass + +class DoctypeTransformer(lark.visitors.Transformer): + def __init__(self, *, reporter: ContextReporter | None = ...) -> None: ... + def start(self, tree: lark.Tree) -> Expr: ... + def qualname(self, tree: lark.Tree) -> Term: ... + def rst_role(self, tree: lark.Tree) -> Expr: ... 
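+    # Uppercase methods such as `ELLIPSES` are lark callbacks for terminal
+    # tokens; lowercase methods transform grammar rules.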
+ def ELLIPSES(self, token: lark.Token) -> Term: ... + def union(self, tree: lark.Tree) -> Expr: ... + def subscription(self, tree: lark.Tree) -> Expr: ... + def param_spec(self, tree: lark.Tree) -> Expr: ... + def callable(self, tree: lark.Tree) -> Expr: ... + def literal(self, tree: lark.Tree) -> Expr: ... + def natlang_literal(self, tree: lark.Tree) -> Expr: ... + def literal_item(self, tree: lark.Tree) -> Term: ... + def natlang_container(self, tree: lark.Tree) -> Expr: ... + def natlang_array(self, tree: lark.Tree) -> Expr: ... + def array_name(self, tree: lark.Tree) -> Term: ... + def dtype(self, tree: lark.Tree) -> Expr: ... + def shape(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... + def optional_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... + def extra_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... + def _format_subscription(self, sequence: Sequence[str], *, rule: str) -> Expr: ... + +def parse_doctype(doctype: str, *, reporter: ContextReporter | None = ...) -> Expr: ... diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index 0ba8bf6..6430e76 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -2,6 +2,7 @@ import dataclasses import logging +from collections.abc import Hashable, Iterator, Mapping, Sequence from pathlib import Path from textwrap import indent from typing import Any, ClassVar, Literal, Self, TextIO @@ -33,23 +34,43 @@ class ContextReporter: short: str, *args: Any, log_level: int, - details: str | None = ..., + details: str | tuple[Any, ...] | None = ..., **log_kw: Any ) -> None: ... def debug( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def info( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def warn( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def error( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def critical( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def __post_init__(self) -> None: ... @staticmethod @@ -80,3 +101,22 @@ class LogCounter(logging.NullHandler): def setup_logging( *, verbosity: Literal[-2, -1, 0, 1, 2, 3], group_errors: bool ) -> tuple[ReportHandler, LogCounter]: ... +def update_with_add_values( + *mappings: Mapping[Hashable, int | Sequence], out: dict | None = ... +) -> dict: ... + +class Stats(Mapping): + class _UNSET: + pass + + def __init__(self, stats: dict[str, list[Any] | str] | None = ...) -> None: ... + def __getitem__(self, key: str) -> list[Any] | int: ... + def __iter__(self) -> Iterator: ... + def __len__(self) -> int: ... + def inc_counter(self, key: str, *, inc: int = ...) -> None: ... + def append_to_list(self, key: str, value: Any) -> None: ... + @classmethod + def merge(cls, *stats: Self) -> Self: ... + def __repr__(self) -> str: ... + def pop(self, key: str, *, default: Any = ...) -> list[Any] | int: ... 
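+    # Usage sketch (illustrative values; mirrors how `TypeMatcher` feeds
+    # these statistics):
+    #     stats = Stats()
+    #     stats.inc_counter("matched_type_names")
+    #     stats.append_to_list("unknown_type_names", "some.qualname")
+    #     stats.pop("matched_type_names", default=0)  # -> 1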
+ def pop_all(self) -> dict[str, list[Any] | int]: ... diff --git a/src/docstub-stubs/_stubs.pyi b/src/docstub-stubs/_stubs.pyi index 00d8104..590d421 100644 --- a/src/docstub-stubs/_stubs.pyi +++ b/src/docstub-stubs/_stubs.pyi @@ -13,14 +13,10 @@ import libcst.matchers as cstm from _typeshed import Incomplete from ._analysis import PyImport, TypeMatcher -from ._docstrings import ( - Annotation, - DocstringAnnotations, - DoctypeTransformer, - FallbackAnnotation, -) -from ._report import ContextReporter -from ._utils import module_name_from_path, update_with_add_values +from ._docstrings import Annotation, DocstringAnnotations, FallbackAnnotation +from ._doctype import DoctypeTransformer +from ._report import ContextReporter, Stats +from ._utils import module_name_from_path logger: logging.Logger @@ -77,9 +73,6 @@ class Py2StubTransformer(cst.CSTTransformer): @property def is_inside_function_def(self) -> bool: ... def python_to_stub(self, source: str, *, module_path: Path | None = ...) -> str: ... - def collect_stats( - self, *, reset_after: bool = ... - ) -> dict[str, int | list[str]]: ... def visit_ClassDef(self, node: cst.ClassDef) -> Literal[True]: ... def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef diff --git a/src/docstub-stubs/_utils.pyi b/src/docstub-stubs/_utils.pyi index aef1204..3072d06 100644 --- a/src/docstub-stubs/_utils.pyi +++ b/src/docstub-stubs/_utils.pyi @@ -2,7 +2,7 @@ import itertools import re -from collections.abc import Callable, Hashable, Mapping, Sequence +from collections.abc import Callable from functools import lru_cache, wraps from pathlib import Path from typing import Any @@ -15,9 +15,6 @@ def _resolve_path_before_caching( ) -> Callable[[Path], str]: ... def module_name_from_path(path: Path) -> str: ... def pyfile_checksum(path: Path) -> str: ... -def update_with_add_values( - *mappings: Mapping[Hashable, int | Sequence], out: dict | None = ... -) -> dict: ... 
class DocstubError(Exception): pass diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 32a6daf..a94e6d8 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -12,6 +12,7 @@ import libcst as cst import libcst.matchers as cstm +from ._report import Stats from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum logger: logging.Logger = logging.getLogger(__name__) @@ -492,6 +493,7 @@ def __init__( types=None, type_prefixes=None, type_nicknames=None, + stats=None, ): """ Parameters @@ -499,15 +501,13 @@ def __init__( types : dict[str, PyImport] type_prefixes : dict[str, PyImport] type_nicknames : dict[str, str] + stats : ~.Stats, optional """ self.types = common_known_types() | (types or {}) self.type_prefixes = type_prefixes or {} self.type_nicknames = type_nicknames or {} - self.stats = { - "matched_type_names": 0, - "unknown_type_names": [], - } + self.stats = stats or Stats() self.current_file = None @@ -623,8 +623,8 @@ def match(self, search): type_name = type_name[type_name.find(py_import.target) :] if type_name is not None: - self.stats["matched_type_names"] += 1 + self.stats.inc_counter("matched_type_names") else: - self.stats["unknown_type_names"].append(search) + self.stats.append_to_list("unknown_type_names", search) return type_name, py_import diff --git a/src/docstub/_app_generate_stubs.py b/src/docstub/_app_generate_stubs.py index 1800b41..616f519 100644 --- a/src/docstub/_app_generate_stubs.py +++ b/src/docstub/_app_generate_stubs.py @@ -24,9 +24,8 @@ walk_source_and_targets, walk_source_package, ) -from ._report import setup_logging +from ._report import Stats, setup_logging from ._stubs import Py2StubTransformer, try_format_stub -from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger = logging.getLogger(__name__) @@ -234,7 +233,9 @@ def _generate_single_stub(task): logger.info("Wrote %s", stub_path) fo.write(stub_content) - stats = stub_transformer.collect_stats() + stats = Stats.merge( + stub_transformer.stats.pop_all(), stub_transformer.matcher.stats.pop_all() + ) return stats @@ -350,7 +351,7 @@ def generate_stubs( stats_per_task = executor.map( _generate_single_stub, task_args, chunksize=chunk_size ) - stats = update_with_add_values(*stats_per_task) + stats = Stats.merge(*stats_per_task) py_typed_out = out_dir / "py.typed" if not py_typed_out.exists(): @@ -368,24 +369,28 @@ def generate_stubs( total_warnings = error_counter.warning_count total_errors = error_counter.error_count - logger.info("Recognized type names: %i", stats["matched_type_names"]) - logger.info("Transformed doctypes: %i", stats["transformed_doctypes"]) + logger.info("Recognized type names: %i", stats.pop("matched_type_names", default=0)) + logger.info( + "Transformed doctypes: %i", stats.pop("transformed_doctypes", default=0) + ) if total_warnings: logger.warning("Warnings: %i", total_warnings) - if stats["doctype_syntax_errors"]: + if "doctype_syntax_errors" in stats: assert total_errors - logger.warning("Syntax errors: %i", stats["doctype_syntax_errors"]) - if stats["unknown_type_names"]: + logger.warning("Syntax errors: %i", stats.pop("doctype_syntax_errors")) + if "unknown_type_names" in stats: assert total_errors logger.warning( "Unknown type names: %i (locations: %i)", len(set(stats["unknown_type_names"])), len(stats["unknown_type_names"]), - extra={"details": _format_unknown_names(stats["unknown_type_names"])}, + extra={"details": _format_unknown_names(stats.pop("unknown_type_names"))}, ) if 
total_errors: logger.error("Total errors: %i", total_errors) + assert len(stats) == 0 + total_fails = total_errors if fail_on_warning: total_fails += total_warnings diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index e8a5ad7..8a6ceba 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -2,9 +2,9 @@ import logging import traceback +import warnings from dataclasses import dataclass, field from functools import cached_property -from pathlib import Path import click import lark @@ -16,44 +16,60 @@ # types and imports. I think that could very well be done at a higher level, # e.g. in the stubs module. from ._analysis import PyImport, TypeMatcher -from ._report import ContextReporter -from ._utils import DocstubError, escape_qualname +from ._doctype import BlacklistedQualname, Term, TermKind, parse_doctype +from ._report import ContextReporter, Stats +from ._utils import escape_qualname logger: logging.Logger = logging.getLogger(__name__) -here: Path = Path(__file__).parent -grammar_path: Path = here / "doctype.lark" +def _update_qualnames(expr, *, _parents=()): + """Yield and receive names in `expr`. + This generator works as a coroutine. -with grammar_path.open() as file: - _grammar: str = file.read() + Parameters + ---------- + expr : ~.Expr + _parents : tuple of (~._doctype.Expr, ...) -_lark: lark.Lark = lark.Lark(_grammar, propagate_positions=True, strict=True) + Yields + ------ + parents : tuple of (~._doctype.Expr, ...) + name : ~._doctype.Term + Receives + -------- + new_name : str -def _find_one_token(tree, *, name): - """Find token with a specific type name in tree. + Examples + -------- + >>> from docstub._doctype import parse_doctype + >>> expr = parse_doctype("tuple of (tuple or str, ...)") + >>> updater = _update_qualnames(expr) + >>> for parents, name in updater: + ... if name == "tuple" and parents[-1].rule == "union": + ... updater.send("list") + ... if name == "str": + ... updater.send("bytes") + >>> expr.as_code() + 'tuple[list | bytes, ...]' + """ + _parents += (expr,) + children = expr.children.copy() - Parameters - ---------- - tree : lark.Tree - name : str - Name of the token to find in the children of `tree`. + for i, child in enumerate(children): + if hasattr(child, "children"): + yield from _update_qualnames(child, _parents=_parents) - Returns - ------- - token : lark.Token - """ - tokens = [ - child - for child in tree.children - if hasattr(child, "type") and child.type == name - ] - if len(tokens) != 1: - msg = f"expected exactly one Token of type {name}, found {len(tokens)}" - raise ValueError(msg) - return tokens[0] + elif child.kind == TermKind.NAME: + new_name = yield _parents, child + if new_name is not None: + new_term = Term(new_name, kind=child.kind) + expr.children[i] = new_term + # `send` was called, yield `None` to return from `send`, + # otherwise send would return the next child + yield @dataclass(frozen=True, slots=True, kw_only=True) @@ -184,396 +200,6 @@ def _aggregate_annotations(*types): ) -class QualnameIsKeyword(DocstubError): - """Raised when a qualname is a blacklisted Python keyword.""" - - -@lark.visitors.v_args(tree=True) -class DoctypeTransformer(lark.visitors.Transformer): - """Transformer for docstring type descriptions (doctypes). - - Attributes - ---------- - matcher : ~.TypeMatcher - stats : dict[str, Any] - blacklisted_qualnames : ClassVar[frozenset[str]] - All Python keywords [1]_ are blacklisted from use in qualnames except for ``True`` - ``False`` and ``None``. 
- - References - ---------- - .. [1] https://docs.python.org/3/reference/lexical_analysis.html#keywords - - Examples - -------- - >>> transformer = DoctypeTransformer() - >>> annotation, unknown_names = transformer.doctype_to_annotation( - ... "tuple of (int or ndarray)" - ... ) - >>> annotation.value - 'tuple[int | ndarray]' - >>> unknown_names - [('ndarray', 17, 24)] - """ - - blacklisted_qualnames = frozenset( - { - "await", - "else", - "import", - "pass", - "break", - "except", - "in", - "raise", - "class", - "finally", - "is", - "return", - "and", - "continue", - "for", - "lambda", - "try", - "as", - "def", - "from", - "nonlocal", - "while", - "assert", - "del", - "global", - "not", - "with", - "async", - "elif", - "if", - "or", - "yield", - } - ) - - def __init__(self, *, matcher=None, **kwargs): - """ - Parameters - ---------- - matcher : ~.TypeMatcher, optional - kwargs : dict[Any, Any], optional - Keyword arguments passed to the init of the parent class. - """ - if matcher is None: - matcher = TypeMatcher() - - self.matcher = matcher - - self._reporter = None - self._collected_imports = None - self._unknown_qualnames = None - - super().__init__(**kwargs) - - self.stats = { - "doctype_syntax_errors": 0, - "transformed_doctypes": 0, - } - - def doctype_to_annotation(self, doctype, *, reporter=None): - """Turn a type description in a docstring into a type annotation. - - Parameters - ---------- - doctype : str - The doctype to parse. - reporter : ~.ContextReporter - - Returns - ------- - annotation : Annotation - The parsed annotation. - unknown_qualnames : list[tuple[str, int, int]] - A set containing tuples. Each tuple contains a qualname, its start and its - end index relative to the given `doctype`. - """ - try: - self._reporter = reporter or ContextReporter(logger=logger) - self._collected_imports = set() - self._unknown_qualnames = [] - tree = _lark.parse(doctype) - value = super().transform(tree=tree) - annotation = Annotation( - value=value, imports=frozenset(self._collected_imports) - ) - self.stats["transformed_doctypes"] += 1 - return annotation, self._unknown_qualnames - except ( - lark.exceptions.LexError, - lark.exceptions.ParseError, - QualnameIsKeyword, - ): - self.stats["doctype_syntax_errors"] += 1 - raise - finally: - self._reporter = None - self._collected_imports = None - self._unknown_qualnames = None - - def qualname(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.Token - """ - children = tree.children - _qualname = ".".join(children) - - _qualname = self._match_import(_qualname, meta=tree.meta) - - if _qualname in self.blacklisted_qualnames: - msg = ( - f"qualname {_qualname!r} in docstring type description " - "is a reserved Python keyword and not allowed" - ) - raise QualnameIsKeyword(msg) - - _qualname = lark.Token(type="QUALNAME", value=_qualname) - return _qualname - - def rst_role(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.Token - """ - qualname = _find_one_token(tree, name="QUALNAME") - return qualname - - def union(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - out = " | ".join(tree.children) - return out - - def subscription(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - _container, *_content = tree.children - _content = ", ".join(_content) - assert _content - out = f"{_container}[{_content}]" - return out - - def param_spec(self, 
tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - _content = ", ".join(tree.children) - out = f"[{_content}]" - return out - - def callable(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - _callable, *_content = tree.children - _content = ", ".join(_content) - out = f"{_callable}[{_content}]" - return out - - def natlang_literal(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - out = ", ".join(tree.children) - out = f"Literal[{out}]" - - if len(tree.children) == 1: - self._reporter.warn( - "Natural language literal with one item: `{%s}`", - tree.children[0], - details=f"Consider using `{out}` to improve readability", - ) - - if self.matcher is not None: - _, py_import = self.matcher.match("Literal") - if py_import.has_import: - self._collected_imports.add(py_import) - return out - - def natlang_container(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - return self.subscription(tree) - - def natlang_array(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - name = _find_one_token(tree, name="ARRAY_NAME") - children = [child for child in tree.children if child != name] - if children: - name = f"{name}[{', '.join(children)}]" - return str(name) - - def array_name(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.Token - """ - # Treat `array_name` as `qualname`, but mark it as an array name, - # so we know which one to treat as the container in `array_expression` - # This currently relies on a hack that only allows specific names - # in `array_expression` (see `ARRAY_NAME` terminal in gramar) - qualname = self.qualname(tree) - qualname = lark.Token("ARRAY_NAME", str(qualname)) - return qualname - - def shape(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.visitors._DiscardType - """ - # self._reporter.debug("Dropping shape information %r", tree) - return lark.Discard - - def optional_info(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.visitors._DiscardType - """ - # self._reporter.debug("Dropping optional info %r", tree) - return lark.Discard - - def __default__(self, data, children, meta): - """Unpack children of rule nodes by default. - - Parameters - ---------- - data : lark.Token - The rule-token of the current node. - children : list[lark.Token] - The children of the current node. - meta : lark.tree.Meta - Meta information for the current node. - - Returns - ------- - out : lark.Token or list[lark.Token] - Either a token or list of tokens. - """ - if isinstance(children, list) and len(children) == 1: - out = children[0] - if hasattr(out, "type"): - out.type = data.upper() # Turn rule into "token" - else: - out = children - return out - - def _match_import(self, qualname, *, meta): - """Match `qualname` to known imports or alias to "Incomplete". - - Parameters - ---------- - qualname : str - meta : lark.tree.Meta - Location metadata for the `qualname`, used to report possible errors. - - Returns - ------- - matched_qualname : str - Possibly modified or normalized qualname. 
- """ - if self.matcher is not None: - annotation_name, py_import = self.matcher.match(qualname) - else: - annotation_name = None - py_import = None - - if py_import and py_import.has_import: - self._collected_imports.add(py_import) - - if annotation_name: - matched_qualname = annotation_name - else: - # Unknown qualname, alias to `Incomplete` - self._unknown_qualnames.append((qualname, meta.start_pos, meta.end_pos)) - matched_qualname = escape_qualname(qualname) - any_alias = PyImport( - from_="_typeshed", - import_="Incomplete", - as_=matched_qualname, - ) - self._collected_imports.add(any_alias) - return matched_qualname - - def _uncombine_numpydoc_params(params): """Split combined NumPyDoc parameters. @@ -600,13 +226,121 @@ def _uncombine_numpydoc_params(params): yield param +def _red_partial_underline(doctype, *, start, stop): + """Underline a part of a string with red '^'. + + Parameters + ---------- + doctype : str + start : int + stop : int + + Returns + ------- + underlined : str + """ + width = stop - start + assert width > 0 + underline = click.style("^" * width, fg="red", bold=True) + underlined = f"{doctype}\n{' ' * start}{underline}\n" + return underlined + + +def doctype_to_annotation(doctype, *, matcher=None, reporter=None, stats=None): + """Convert a type description to a Python-ready type. + + Parameters + ---------- + doctype : str + matcher : ~.TypeMatcher, optional + reporter : ~.ContextReporter, optional + stats : ~.Stats, optional + + Returns + ------- + annotation : Annotation + The transformed type, ready to be inserted into a stub file, with + necessary imports attached. + """ + matcher = matcher or TypeMatcher() + reporter = reporter or ContextReporter(logger=logger) + stats = Stats() if stats is None else stats + + try: + expression = parse_doctype(doctype, reporter=reporter) + stats.inc_counter("transformed_doctypes") + reporter.debug( + "Transformed doctype", details=(" %s\n-> %s", doctype, expression) + ) + + imports = set() + unknown_qualnames = set() + updater = _update_qualnames(expression) + for _, name in updater: + search_name = str(name) + matched_name, py_import = matcher.match(search_name) + if matched_name is None: + assert py_import is None + unknown_qualnames.add((search_name, *name.pos)) + matched_name = escape_qualname(search_name) + _ = updater.send(matched_name) + assert _ is None + + if py_import is None: + incomplete_alias = PyImport( + from_="_typeshed", + import_="Incomplete", + as_=matched_name, + ) + imports.add(incomplete_alias) + elif py_import.has_import: + imports.add(py_import) + + annotation = Annotation(value=str(expression), imports=frozenset(imports)) + + except ( + lark.exceptions.LexError, + lark.exceptions.ParseError, + ) as error: + details = None + if hasattr(error, "get_context"): + details = error.get_context(doctype) + details = details.replace("^", click.style("^", fg="red", bold=True)) + stats.inc_counter("doctype_syntax_errors") + reporter.error("Invalid syntax in docstring type annotation", details=details) + return FallbackAnnotation + + except lark.visitors.VisitError as error: + original_error = error.orig_exc + if isinstance(original_error, BlacklistedQualname): + msg = "Blacklisted keyword argument in doctype" + details = _red_partial_underline( + doctype, + start=error.obj.meta.start_pos, + stop=error.obj.meta.end_pos, + ) + else: + msg = "Unexpected error while parsing doctype" + tb = traceback.format_exception(original_error) + tb = "\n".join(tb) + details = f"doctype: {doctype!r}\n\n{tb}" + 
reporter.error(msg, details=details) + return FallbackAnnotation + + else: + for name, start_col, stop_col in unknown_qualnames: + details = _red_partial_underline(doctype, start=start_col, stop=stop_col) + reporter.error(f"Unknown name in doctype: {name!r}", details=details) + return annotation + + class DocstringAnnotations: """Collect annotations in a given docstring. Attributes ---------- docstring : str - transformer : DoctypeTransformer + matcher : ~.TypeMatcher reporter : ~.ContextReporter Examples @@ -618,78 +352,34 @@ class DocstringAnnotations: ... b : some invalid syntax ... c : unknown.symbol ... ''' - >>> transformer = DoctypeTransformer() - >>> annotations = DocstringAnnotations(docstring, transformer=transformer) + >>> annotations = DocstringAnnotations(docstring) >>> annotations.parameters.keys() dict_keys(['a', 'b', 'c']) """ - def __init__(self, docstring, *, transformer, reporter=None): + def __init__(self, docstring, *, matcher=None, reporter=None, stats=None): """ Parameters ---------- docstring : str - transformer : DoctypeTransformer + matcher : ~.TypeMatcher, optional reporter : ~.ContextReporter, optional + stats : ~.Stats, optional """ self.docstring = docstring - self.np_docstring = npds.NumpyDocString(docstring) - self.transformer = transformer + self.matcher = matcher or TypeMatcher() + self.stats = Stats() if stats is None else stats if reporter is None: reporter = ContextReporter(logger=logger, line=0) self.reporter = reporter.copy_with(logger=logger) - def _doctype_to_annotation(self, doctype, ds_line=0): - """Convert a type description to a Python-ready type. - - Parameters - ---------- - doctype : str - The type description of a parameter or return value, as extracted from - a docstring. - ds_line : int, optional - The line number relative to the docstring. - - Returns - ------- - annotation : Annotation - The transformed type, ready to be inserted into a stub file, with - necessary imports attached. 
- """ - reporter = self.reporter.copy_with(line_offset=ds_line) - - try: - annotation, unknown_qualnames = self.transformer.doctype_to_annotation( - doctype, reporter=reporter - ) - reporter.debug( - "Transformed doctype", details=(" %s\n-> %s", doctype, annotation) - ) - - except (lark.exceptions.LexError, lark.exceptions.ParseError) as error: - details = None - if hasattr(error, "get_context"): - details = error.get_context(doctype) - details = details.replace("^", click.style("^", fg="red", bold=True)) - reporter.error( - "Invalid syntax in docstring type annotation", details=details - ) - return FallbackAnnotation - - except lark.visitors.VisitError as e: - tb = "\n".join(traceback.format_exception(e.orig_exc)) - details = f"doctype: {doctype!r}\n\n{tb}" - reporter.error("Unexpected error while parsing doctype", details=details) - return FallbackAnnotation - - else: - for name, start_col, stop_col in unknown_qualnames: - width = stop_col - start_col - error_underline = click.style("^" * width, fg="red", bold=True) - details = f"{doctype}\n{' ' * start_col}{error_underline}\n" - reporter.error(f"Unknown name in doctype: {name!r}", details=details) - return annotation + with warnings.catch_warnings(record=True) as records: + self.np_docstring = npds.NumpyDocString(docstring) + for message in records: + short = "Warning in NumPyDoc while parsing docstring" + details = message.message.args[0] + self.reporter.warn(short, details=details) @cached_property def attributes(self): @@ -858,7 +548,13 @@ def _section_annotations(self, name): continue ds_line = self._find_docstring_line(param.name, param.type) - annotation = self._doctype_to_annotation(param.type, ds_line=ds_line) + + annotation = doctype_to_annotation( + doctype=param.type, + matcher=self.matcher, + reporter=self.reporter.copy_with(line_offset=ds_line), + stats=self.stats, + ) annotated_params[param.name.strip()] = annotation return annotated_params diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py new file mode 100644 index 0000000..e0841b4 --- /dev/null +++ b/src/docstub/_doctype.py @@ -0,0 +1,597 @@ +"""Parsing & transformation of doctypes into Python-compatible syntax.""" + +import enum +import keyword +import logging +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path +from textwrap import indent +from typing import Final, Self + +import lark +import lark.visitors + +from ._report import ContextReporter +from ._utils import DocstubError + +logger: Final[logging.Logger] = logging.getLogger(__name__) + + +grammar_path: Final[Path] = Path(__file__).parent / "doctype.lark" + +with grammar_path.open() as file: + _grammar: Final[str] = file.read() + +# TODO try passing `transformer=DoctypeTransformer()`, may be faster [1] +# [1] https://lark-parser.readthedocs.io/en/latest/classes.html#:~:text=after%20the%20parse%2C-,but%20faster,-) +_lark: Final[lark.Lark] = lark.Lark(_grammar, propagate_positions=True) + + +def flatten_recursive(iterable): + """Flatten nested iterables yield the contained strings. + + Parameters + ---------- + iterable : Iterable[Iterable or str] + An iterable containing nested iterables or strings. Only strings are + supported as "leafs" for now. 
+
+    Yields
+    ------
+    item : str
+
+    Examples
+    --------
+    >>> nested = ["only", ["strings", ("and", "iterables"), "are", ["allowed"]]]
+    >>> list(flatten_recursive(nested))
+    ['only', 'strings', 'and', 'iterables', 'are', 'allowed']
+    """
+    for item in iterable:
+        if isinstance(item, str):
+            yield item
+        elif isinstance(item, Iterable):
+            yield from flatten_recursive(item)
+        else:
+            raise ValueError(f"unexpected type: {item!r}")
+
+
+def insert_between(iterable, *, sep):
+    """Insert `sep` between elements of `iterable`.
+
+    Parameters
+    ----------
+    iterable : Iterable
+    sep : Any
+
+    Returns
+    -------
+    out : list[Any]
+
+    Examples
+    --------
+    >>> code = ["a", "b", "c"]
+    >>> list(insert_between(code, sep=" | "))
+    ['a', ' | ', 'b', ' | ', 'c']
+    """
+    out = []
+    for item in iterable:
+        out.append(item)
+        out.append(sep)
+    return out[:-1]
+
+
+class TermKind(enum.StrEnum):
+    """Encodes the different kinds of :class:`Term`."""
+
+    # docstub: off
+    NAME = enum.auto()
+    LITERAL = enum.auto()
+    SYNTAX = enum.auto()
+    # docstub: on
+
+
+class Term(str):
+    """A terminal / symbol representing an atomic part of a doctype.
+
+    Attributes
+    ----------
+    kind : TermKind
+    pos : tuple of (int, int) or None
+    __slots__ : Final[tuple[str, ...]]
+
+    Examples
+    --------
+    >>> ''.join(
+    ...     [
+    ...         Term("int", kind="name"),
+    ...         Term(" | ", kind="syntax"),
+    ...         Term("float", kind="name")
+    ...     ]
+    ... )
+    'int | float'
+    """
+
+    __slots__ = ("kind", "pos")
+
+    def __new__(cls, value, *, kind, pos=None):
+        """
+        Parameters
+        ----------
+        value : str
+        kind : TermKind or str
+        pos : tuple of (int, int), optional
+
+        Returns
+        -------
+        cls : Self
+        """
+        self = super().__new__(cls, value)
+        self.kind = TermKind(kind)
+        self.pos = pos
+        return self
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}('{self}', kind='{self.kind}')"
+
+    def __getnewargs_ex__(self):
+        """
+        Returns
+        -------
+        args : tuple of (Any, ...)
+        kwargs : dict of {str: Any}
+        """
+        kwargs = {"value": str(self), "kind": self.kind, "pos": self.pos}
+        return (), kwargs
+
+
+@dataclass(slots=True)
+class Expr:
+    """An expression that forms or is part of a doctype.
+
+    Parameters
+    ----------
+    rule : str
+        The name of the (grammar) rule corresponding to this expression.
+    children : list of (Expr or Term)
+        Sub-expressions or terms that make up this expression.
+    """
+
+    rule: str
+    children: list[Self | Term]
+
+    @property
+    def terms(self):
+        """All terms in the expression.
+
+        Returns
+        -------
+        terms : list of Term
+        """
+        return list(flatten_recursive(self))
+
+    @property
+    def names(self):
+        """Name terms in the expression.
+
+        Returns
+        -------
+        names : list of Term
+        """
+        return [term for term in self.terms if term.kind == TermKind.NAME]
+
+    @property
+    def sub_expressions(self):
+        """Iterate expressions inside the current one.
+
+        Yields
+        ------
+        expression : Self
+        """
+        cls = type(self)
+        for child in self.children:
+            if isinstance(child, cls):
+                yield child
+                yield from child.sub_expressions
+
+    def __iter__(self):
+        """Iterate over children of this expression.
+
+        Yields
+        ------
+        child : Expr or Term
+        """
+        yield from self.children
+
+    def format_tree(self):
+        """Format full hierarchy as a tree.
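+
+        Each child is rendered on its own line, indented under its parent
+        expression.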
+ + Returns + ------- + formatted : str + """ + formatted_children = ( + c.format_tree() if hasattr(c, "format_tree") else repr(c) + for c in self.children + ) + formatted_children = ",\n".join(formatted_children) + formatted_children = indent(formatted_children, prefix=" ") + return ( + f"{type(self).__name__}({self.rule!r}, children=[\n{formatted_children}])" + ) + + def print_tree(self): + """Print full hierarchy as a tree.""" + print(self.format_tree()) # noqa: T201 + + def __repr__(self) -> str: + return f"<{type(self).__name__}: '{self.as_code()}' rule='{self.rule}'>" + + def __str__(self) -> str: + return "".join(self.terms) + + def as_code(self) -> str: + return str(self) + + +BLACKLISTED_QUALNAMES: Final[set[str]] = set(keyword.kwlist) - {"None", "True", "False"} + + +class BlacklistedQualname(DocstubError): + """Raised when a qualname is a forbidden keyword.""" + + +@lark.visitors.v_args(tree=True) +class DoctypeTransformer(lark.visitors.Transformer): + """Transform parsed doctypes into Python-compatible syntax. + + Examples + -------- + >>> tree = _lark.parse("int or tuple of (int, ...)") + >>> transformer = DoctypeTransformer() + >>> str(transformer.transform(tree=tree)) + 'int | tuple[int, ...]' + """ + + def __init__(self, *, reporter=None): + """ + Parameters + ---------- + reporter : ~.ContextReporter + """ + reporter = reporter or ContextReporter(logger=logger) + self.reporter = reporter.copy_with(logger=logger) + + def start(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + return Expr(rule="start", children=tree.children) + + def qualname(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Term + """ + children = tree.children + _qualname = ".".join(children) + + if _qualname in BLACKLISTED_QUALNAMES: + raise BlacklistedQualname(_qualname) + + _qualname = Term( + _qualname, + kind=TermKind.NAME, + pos=(tree.meta.start_pos, tree.meta.end_pos), + ) + return _qualname + + def rst_role(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + # Drop rst_prefix + children = [c for c in tree.children if isinstance(c, Term)] + expr = Expr(rule="rst_role", children=children) + return expr + + def ELLIPSES(self, token): + """ + Parameters + ---------- + token : lark.Token + + Returns + ------- + out : Term + """ + return Term(token, kind=TermKind.LITERAL) + + def union(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + sep = Term(" | ", kind=TermKind.SYNTAX) + expr = Expr(rule="union", children=insert_between(tree.children, sep=sep)) + return expr + + def subscription(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + assert len(tree.children) > 1 + return self._format_subscription(tree.children, rule="subscription") + + def param_spec(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + sep = Term(", ", kind=TermKind.SYNTAX) + children = [ + Term("[", kind=TermKind.SYNTAX), + *insert_between(tree.children, sep=sep), + Term("]", kind=TermKind.SYNTAX), + ] + expr = Expr(rule="param_spec", children=children) + return expr + + def callable(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + assert len(tree.children) > 1 + return self._format_subscription(tree.children, rule="callable") + + def literal(self, tree): + """ 
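+        Transform an explicit literal expression such as
+        ``Literal[Some.ENUM]``.
+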
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : Expr
+        """
+        assert len(tree.children) > 1
+        out = self._format_subscription(tree.children, rule="literal")
+        return out
+
+    def natlang_literal(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : Expr
+        """
+        items = [
+            Term("Literal", kind=TermKind.NAME),
+            *tree.children,
+        ]
+        out = self._format_subscription(items, rule="natlang_literal")
+
+        assert len(tree.children) >= 1
+        if len(tree.children) == 1:
+            details = ("Consider using `%s` to improve readability", "".join(out))
+            self.reporter.warn(
+                "Natural language literal with one item: `{%s}`",
+                tree.children[0],
+                details=details,
+            )
+        return out
+
+    def literal_item(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : Term
+        """
+        item, *other = tree.children
+        assert not other
+        kind = TermKind.LITERAL
+        if isinstance(item, Term):
+            kind = item.kind
+        out = Term(item, kind=kind, pos=(tree.meta.start_pos, tree.meta.end_pos))
+        return out
+
+    def natlang_container(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : Expr
+        """
+        assert len(tree.children) >= 1
+        return self._format_subscription(tree.children, rule="natlang_container")
+
+    def natlang_array(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : Expr
+        """
+        return self._format_subscription(tree.children, rule="natlang_array")
+
+    def array_name(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : Term
+        """
+        # This currently relies on a hack that only allows specific names
+        # in `array_expression` (see `ARRAY_NAME` terminal in grammar)
+        qualname = self.qualname(tree)
+        return qualname
+
+    def dtype(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : Expr
+        """
+        return Expr(rule="dtype", children=tree.children)
+
+    def shape(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : lark.visitors._DiscardType
+        """
+        logger.debug("dropping shape information")
+        return lark.Discard
+
+    def optional_info(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : lark.visitors._DiscardType
+        """
+        logger.debug("dropping optional / default info")
+        return lark.Discard
+
+    def extra_info(self, tree):
+        """
+        Parameters
+        ----------
+        tree : lark.Tree
+
+        Returns
+        -------
+        out : lark.visitors._DiscardType
+        """
+        logger.debug("dropping extra info")
+        return lark.Discard
+
+    def _format_subscription(self, sequence, *, rule):
+        """Format a `name[...]` style expression.
+
+        Parameters
+        ----------
+        sequence : Sequence[str]
+        rule : str
+
+        Returns
+        -------
+        out : Expr
+        """
+        sep = Term(", ", kind=TermKind.SYNTAX)
+        container, *content = sequence
+        assert container
+
+        if content:
+            content = insert_between(content, sep=sep)
+            children = [
+                container,
+                Term("[", kind=TermKind.SYNTAX),
+                *content,
+                Term("]", kind=TermKind.SYNTAX),
+            ]
+        else:
+            children = [container]
+
+        expr = Expr(rule=rule, children=children)
+        return expr
+
+
+def parse_doctype(doctype, *, reporter=None):
+    """Parse a type description from a docstring into an expression tree.
+
+    Parameters
+    ----------
+    doctype : str
+        The doctype to parse.
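+        Both Python-style annotations and natural language forms such as
+        ``"tuple of (int, ...)"`` are supported.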
+    reporter : ~.ContextReporter, optional
+
+    Returns
+    -------
+    parsed : Expr
+
+    Raises
+    ------
+    lark.exceptions.VisitError
+        Raised when the transformation is interrupted by an exception.
+        See :class:`lark.exceptions.VisitError`.
+    BlacklistedQualname
+        Raised when a qualname is a forbidden keyword.
+
+    Examples
+    --------
+    >>> parse_doctype("tuple of (int, ...)")
+    <Expr: 'tuple[int, ...]' rule='start'>
+
+    >>> parse_doctype("ndarray of dtype (float or int)")
+    <Expr: 'ndarray[float | int]' rule='start'>
+    """
+    tree = _lark.parse(doctype)
+    transformer = DoctypeTransformer(reporter=reporter)
+    expression = transformer.transform(tree=tree)
+    return expression
diff --git a/src/docstub/_report.py b/src/docstub/_report.py
index 51f2514..831b1f3 100644
--- a/src/docstub/_report.py
+++ b/src/docstub/_report.py
@@ -2,6 +2,7 @@
 
 import dataclasses
 import logging
+from collections.abc import Mapping
 from pathlib import Path
 from textwrap import indent
 
@@ -93,7 +94,7 @@ def report(self, short, *args, log_level, details=None, **log_kw):
             Optional formatting arguments for `short`.
         log_level : int
             The logging level.
-        details : str, optional
+        details : str or tuple of (Any, ...), optional
             An optional multiline report with more details.
         **log_kw : Any
         """
@@ -116,7 +117,7 @@ def debug(self, short, *args, details=None, **log_kw):
             A short summarizing report that shouldn't wrap over multiple lines.
         *args : Any
             Optional formatting arguments for `short`.
-        details : str, optional
+        details : str or tuple of (Any, ...), optional
             An optional multiline report with more details.
         **log_kw : Any
         """
@@ -133,7 +134,7 @@ def info(self, short, *args, details=None, **log_kw):
             A short summarizing report that shouldn't wrap over multiple lines.
         *args : Any
             Optional formatting arguments for `short`.
-        details : str, optional
+        details : str or tuple of (Any, ...), optional
             An optional multiline report with more details.
         **log_kw : Any
         """
@@ -150,7 +151,7 @@ def warn(self, short, *args, details=None, **log_kw):
             A short summarizing report that shouldn't wrap over multiple lines.
         *args : Any
             Optional formatting arguments for `short`.
-        details : str, optional
+        details : str or tuple of (Any, ...), optional
             An optional multiline report with more details.
         **log_kw : Any
         """
@@ -167,7 +168,7 @@ def error(self, short, *args, details=None, **log_kw):
             A short summarizing report that shouldn't wrap over multiple lines.
         *args : Any
             Optional formatting arguments for `short`.
-        details : str, optional
+        details : str or tuple of (Any, ...), optional
             An optional multiline report with more details.
         **log_kw : Any
         """
@@ -184,7 +185,7 @@ def critical(self, short, *args, details=None, **log_kw):
             A short summarizing report that shouldn't wrap over multiple lines.
         *args : Any
             Optional formatting arguments for `short`.
-        details : str, optional
+        details : str or tuple of (Any, ...), optional
             An optional multiline report with more details.
         **log_kw : Any
         """
@@ -446,3 +447,171 @@
     logging.captureWarnings(True)
 
     return reporter, log_counter
+
+
+def update_with_add_values(*mappings, out=None):
+    """Merge mappings while adding together their values.
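+
+    Values under the same key are combined with ``+``: ints are summed and
+    sequences are concatenated. A value that does not support ``+`` raises a
+    ``TypeError``.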
+ + Parameters + ---------- + mappings : Mapping[Hashable, int or Sequence] + out : dict, optional + + Returns + ------- + out : dict, optional + + Examples + -------- + >>> stats_1 = {"errors": 2, "warnings": 0, "unknown": ["string", "integer"]} + >>> stats_2 = {"unknown": ["func"], "errors": 1} + >>> update_with_add_values(stats_1, stats_2) + {'errors': 3, 'warnings': 0, 'unknown': ['string', 'integer', 'func']} + + >>> _ = update_with_add_values(stats_1, out=stats_2) + >>> stats_2 + {'unknown': ['func', 'string', 'integer'], 'errors': 3, 'warnings': 0} + + >>> update_with_add_values({"lines": (1, 33)}, {"lines": (42,)}) + {'lines': (1, 33, 42)} + """ + if out is None: + out = {} + for m in mappings: + for key, value in m.items(): + if hasattr(value, "__add__"): + out[key] = out.setdefault(key, type(value)()) + value + else: + raise TypeError(f"Don't know how to 'add' {value!r}") + return out + + +class Stats(Mapping): + """Collect statistics + + Examples + -------- + >>> stats = Stats() + >>> stats.inc_counter("counter") + >>> stats.inc_counter("counter", inc=2) + >>> stats.append_to_list("names", "Foo") + >>> stats.append_to_list("names", "Bar") + >>> dict(stats) + {'counter': 3, 'names': ['Foo', 'Bar']} + + >>> other_stats = Stats( + ... {"counter": 3, "modules": ["pathlib"], "names": ["baz"]} + ... ) + >>> merged = stats.merge(stats, other_stats) + >>> dict(merged) + {'counter': 6, 'names': ['Foo', 'Bar', 'baz'], 'modules': ['pathlib']} + """ + + class _UNSET: + """Sentinel signaling that an argument wasn't set.""" + + def __init__(self, stats=None): + """ + Parameters + ---------- + stats : dict[str, list[Any] or str] + """ + self._stats = {} if stats is None else stats + + def __getitem__(self, key): + """Retrieve a statistic. + + Parameters + ---------- + key : str + + Returns + ------- + value : list[Any] or int + """ + return self._stats[key] + + def __iter__(self): + """ + Returns + ------- + out : Iterator + """ + yield from self._stats + + def __len__(self) -> int: + return len(self._stats) + + def inc_counter(self, key, *, inc=1): + """Increase counter of a statistic. + + Parameters + ---------- + key : str + inc : int, optional + """ + if key not in self._stats: + self._stats[key] = 0 + assert isinstance(inc, int) + self._stats[key] += inc + + def append_to_list(self, key, value): + """Append `value` to statistic. + + Parameters + ---------- + key : str + value : Any + """ + if key not in self._stats: + self._stats[key] = [] + self._stats[key].append(value) + + @classmethod + def merge(cls, *stats): + """ + + Parameters + ---------- + *stats : Self + + Returns + ------- + merged : Self + """ + out = update_with_add_values(*stats) + out = cls(out) + return out + + def __repr__(self) -> str: + keys = ", ".join(self._stats.keys()) + return f"<{type(self).__name__}: {keys}>" + + def pop(self, key, *, default=_UNSET): + """Return and remove a statistic from this container. + + Parameters + ---------- + key : str + default : Any, optional + If given, falls back to the given default value if `key` is not + found. + + Returns + ------- + value : list[Any] or int + """ + if key in self._stats or default is self._UNSET: + return self._stats.pop(key) + return default + + def pop_all(self): + """Return and remove all statistics from this container. 
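+
+        Leaves the container empty. Together with :meth:`merge`, this allows
+        draining and aggregating per-file statistics.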
+ + Returns + ------- + stats : dict[str, list[Any] or int] + """ + out = self._stats + self._stats = {} + return out diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 94ce476..5d1f68c 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -16,9 +16,9 @@ import libcst.matchers as cstm from ._analysis import PyImport -from ._docstrings import DocstringAnnotations, DoctypeTransformer, FallbackAnnotation -from ._report import ContextReporter -from ._utils import module_name_from_path, update_with_add_values +from ._docstrings import DocstringAnnotations, FallbackAnnotation +from ._report import ContextReporter, Stats +from ._utils import module_name_from_path logger: logging.Logger = logging.getLogger(__name__) @@ -328,8 +328,9 @@ def __init__(self, *, matcher=None): ---------- matcher : ~.TypeMatcher """ - self.transformer = DoctypeTransformer(matcher=matcher) + self.matcher = matcher self.reporter = ContextReporter(logger=logger) + self.stats = Stats() # Relevant docstring for the current context self._scope_stack = None # Entered module, class or function scopes self._pytypes_stack = None # Collected pytypes for each stack @@ -355,10 +356,10 @@ def current_source(self, value): value : Path """ self._current_source = value - # TODO pass current_source directly when using the transformer / matcher + # TODO pass current_source directly when using the matcher # instead of assigning it here! - if self.transformer is not None and self.transformer.matcher is not None: - self.transformer.matcher.current_file = value + if self.matcher is not None: + self.matcher.current_file = value @property def is_inside_function_def(self): @@ -408,26 +409,6 @@ def python_to_stub(self, source, *, module_path=None): self._required_imports = None self.current_source = None - def collect_stats(self, *, reset_after=True): - """Return statistics from processing files. - - Parameters - ---------- - reset_after : bool, optional - Whether to reset counters and statistics after returning. - - Returns - ------- - stats : dict of {str: int or list[str]} - """ - collected = [self.transformer.stats, self.transformer.matcher.stats] - merged = update_with_add_values(*collected) - if reset_after is True: - for stats in collected: - for key in stats: - stats[key] = type(stats[key])() - return merged - def visit_ClassDef(self, node): """Collect pytypes from class docstring and add scope to stack. @@ -931,13 +912,15 @@ def _annotations_from_node(self, node): try: annotations = DocstringAnnotations( docstring_value, - transformer=self.transformer, + matcher=self.matcher, reporter=reporter, + stats=self.stats, ) except (SystemExit, KeyboardInterrupt): raise except Exception: reporter.error("could not parse docstring", exc_info=True) + return annotations def _create_annotated_assign( diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index 5ed4fd9..9cedf7d 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -159,43 +159,6 @@ def pyfile_checksum(path): return key -def update_with_add_values(*mappings, out=None): - """Merge mappings while adding together their values. 
- - Parameters - ---------- - mappings : Mapping[Hashable, int or Sequence] - out : dict, optional - - Returns - ------- - out : dict, optional - - Examples - -------- - >>> stats_1 = {"errors": 2, "warnings": 0, "unknown": ["string", "integer"]} - >>> stats_2 = {"unknown": ["func"], "errors": 1} - >>> update_with_add_values(stats_1, stats_2) - {'errors': 3, 'warnings': 0, 'unknown': ['string', 'integer', 'func']} - - >>> _ = update_with_add_values(stats_1, out=stats_2) - >>> stats_2 - {'unknown': ['func', 'string', 'integer'], 'errors': 3, 'warnings': 0} - - >>> update_with_add_values({"lines": (1, 33)}, {"lines": (42,)}) - {'lines': (1, 33, 42)} - """ - if out is None: - out = {} - for m in mappings: - for key, value in m.items(): - if hasattr(value, "__add__"): - out[key] = out.setdefault(key, type(value)()) + value - else: - raise TypeError(f"Don't know how to 'add' {value!r}") - return out - - class DocstubError(Exception): """An error raised by docstub.""" diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 904f735..4bd6826 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -6,19 +6,20 @@ // https://lark-parser.readthedocs.io/en/latest/grammar.html -?start: annotation_with_meta +start: _annotation_with_meta // The basic structure of a full docstring annotation as it comes after the // `name : `. It includes additional meta information that is optional and // currently ignored. -?annotation_with_meta: type ("," optional_info)? +_annotation_with_meta: _type ("," optional_info)? // A type annotation. Can range from a simple qualified name to a complex // nested construct of types. -?type: qualname +_type: qualname | union + | literal | subscription | callable | natlang_literal @@ -37,7 +38,7 @@ qualname: (/~/ ".")? (NAME ".")* NAME // An union of different types, joined either by "or" or "|". -union: type (_OR type)+ +union: _type (_OR _type)+ // Operator used in unions. @@ -45,18 +46,18 @@ _OR: "or" | "|" // An expression where an object is subscribed with "A[v, ...]". -subscription: qualname "[" type ("," type)* ("," ELLIPSES)? "]" +subscription: qualname "[" _type ("," _type)* ("," ELLIPSES)? "]" // An expression describing a callable like "Callable[[int], str]" // [1] https://typing.python.org/en/latest/spec/callables.html#callable // -callable: qualname "[" ELLIPSES ("," type)? "]" - | qualname "[" param_spec "," type "]" +callable: qualname "[" ELLIPSES ("," _type)? "]" + | qualname "[" param_spec "," _type "]" // The parameter specification inside a callable expression. -param_spec: "[" type? ("," type)* ("," ELLIPSES)? "]" +param_spec: "[" _type? ("," _type)* ("," ELLIPSES)? "]" // Allow Python's ellipses object @@ -68,16 +69,23 @@ ELLIPSES: "..." natlang_literal: "{" literal_item ("," literal_item)* "}" +// A literal expression as supported by Python proper. The rule "subscription" +// isn't allowed to contain "literal_items", so we need to define this. +// Assign a higher priority so that things like `Literal[Some.ENUM]` is marked +// as a literal expression. +literal.1: qualname "[" literal_item ("," literal_item)* "]" + + // An single item in a literal expression (or `optional`). We must also allow // for qualified names, since a "class" or enum can be used as a literal too. -?literal_item: ELLIPSES | STRING | SIGNED_NUMBER | qualname +literal_item: STRING | SIGNED_NUMBER | qualname // Natural language forms of the subscription expression for containers. // These forms allow nesting with other expressions. 
 // so extensively to maintain readability.
 natlang_container: qualname "of" qualname _PLURAL_S?
-    | qualname "of" "(" type ")"
+    | qualname "of" "(" _type ")"
     | _natlang_tuple
     | _natlang_mapping
 
@@ -90,12 +98,12 @@ _PLURAL_S: /(?
-        def escape(name: str) -> str:
-            return name.replace("-", "_").replace(".", "_")
-
-        doctype = fmt.format(name=name, dtype=dtype, shape=shape)
-        expected = expected_fmt.format(
-            name=escape(name), dtype=escape(dtype), shape=shape
-        )
-
-        transformer = DoctypeTransformer()
-        annotation, _ = transformer.doctype_to_annotation(doctype)
-
-        assert annotation.value == expected
-    # fmt: on
-
-    @pytest.mark.parametrize(
-        ("doctype", "expected"),
-        [
-            ("ndarray of dtype (int or float)", "ndarray[int | float]"),
-        ],
-    )
-    def test_natlang_array_specific(self, doctype, expected):
-        transformer = DoctypeTransformer()
-        annotation, _ = transformer.doctype_to_annotation(doctype)
-        assert annotation.value == expected
-
-    @pytest.mark.parametrize("shape", ["(-1, 3)", "(1.0, 2)", "-3D", "-2-D"])
-    def test_natlang_array_invalid_shape(self, shape):
-        doctype = f"array of shape {shape}"
-        transformer = DoctypeTransformer()
-        with pytest.raises(lark.exceptions.UnexpectedInput):
-            _ = transformer.doctype_to_annotation(doctype)
-
-    def test_unknown_name(self):
+class Test_doctype_to_annotation:
+    def test_unknown_name(self, caplog):
         # Simple unknown name is aliased to typing.Any
-        transformer = DoctypeTransformer()
-        annotation, unknown_names = transformer.doctype_to_annotation("a")
+        annotation = doctype_to_annotation("a")
         assert annotation.value == "a"
         assert annotation.imports == {
             PyImport(import_="Incomplete", from_="_typeshed", as_="a")
         }
-        assert unknown_names == [("a", 0, 1)]
+        assert caplog.messages == ["Unknown name in doctype: 'a'"]
 
-    def test_unknown_qualname(self):
+    def test_unknown_qualname(self, caplog):
         # Unknown qualified name is escaped and aliased to typing.Any as well
-        transformer = DoctypeTransformer()
-        annotation, unknown_names = transformer.doctype_to_annotation("a.b")
+        annotation = doctype_to_annotation("a.b")
         assert annotation.value == "a_b"
         assert annotation.imports == {
             PyImport(import_="Incomplete", from_="_typeshed", as_="a_b")
         }
-        assert unknown_names == [("a.b", 0, 3)]
+        assert caplog.messages == ["Unknown name in doctype: 'a.b'"]
 
-    def test_multiple_unknown_names(self):
+    def test_multiple_unknown_names(self, caplog):
         # Multiple names are aliased to typing.Any
-        transformer = DoctypeTransformer()
-        annotation, unknown_names = transformer.doctype_to_annotation("a.b of c")
+        annotation = doctype_to_annotation("a.b of c")
         assert annotation.value == "a_b[c]"
         assert annotation.imports == {
             PyImport(import_="Incomplete", from_="_typeshed", as_="a_b"),
             PyImport(import_="Incomplete", from_="_typeshed", as_="c"),
         }
-        assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)]
+        assert sorted(caplog.messages) == [
+            "Unknown name in doctype: 'a.b'",
+            "Unknown name in doctype: 'c'",
+        ]
 
 
 class Test_DocstringAnnotations:
     def test_empty_docstring(self):
         docstring = dedent("""No sections in this docstring.""")
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
 
         assert annotations.attributes == {}
         assert annotations.parameters == {}
         assert annotations.returns is None
@@ -378,8 +94,7 @@ def test_parameters(self, doctype, expected):
             b :
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert len(annotations.parameters) == 1
         assert annotations.parameters["a"].value == expected
         assert "b" not in annotations.parameters
@@ -400,8 +115,7 @@ def test_returns(self, doctypes, expected):
             b : {}
             """
         ).format(*doctypes)
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
 
         assert annotations.returns is not None
         assert annotations.returns.value == expected
@@ -414,8 +128,7 @@ def test_yields(self, caplog):
             b : str
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert annotations.returns is not None
         assert annotations.returns.value == "Generator[tuple[int, str]]"
         assert annotations.returns.imports == {
@@ -436,8 +149,7 @@ def test_receives(self, caplog):
             d : bytes
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert annotations.returns is not None
         assert (
             annotations.returns.value
@@ -465,8 +177,7 @@ def test_full_generator(self, caplog):
             e : bool
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert annotations.returns is not None
         assert annotations.returns.value == (
             "Generator[tuple[int, str], tuple[float, bytes], bool]"
@@ -488,8 +199,7 @@ def test_yields_and_returns(self, caplog):
             e : bool
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert annotations.returns is not None
         assert annotations.returns.value == ("Generator[tuple[int, str], None, bool]")
         assert annotations.returns.imports == {
@@ -505,8 +215,7 @@ def test_duplicate_parameters(self, caplog):
             a : str
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
 
         assert len(annotations.parameters) == 1
         assert annotations.parameters["a"].value == "int"
@@ -519,8 +228,7 @@ def test_duplicate_returns(self, caplog):
             a : str
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert annotations.returns is not None
         assert annotations.returns is not None
         assert annotations.returns.value == "int"
@@ -534,8 +242,7 @@ def test_args_kwargs(self):
             **kwargs : str
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert "args" in annotations.parameters
         assert "*args" not in annotations.parameters
         assert "kwargs" in annotations.parameters
@@ -553,8 +260,7 @@ def test_missing_whitespace(self, caplog):
             a: int
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert annotations.parameters["a"].value == "int"
         assert len(caplog.records) == 1
         assert "Possibly missing whitespace" in caplog.text
@@ -568,8 +274,7 @@ def test_combined_numpydoc_params(self):
             d, e :
             """
         )
-        transformer = DoctypeTransformer()
-        annotations = DocstringAnnotations(docstring, transformer=transformer)
+        annotations = DocstringAnnotations(docstring)
         assert len(annotations.parameters) == 3
         assert annotations.parameters["a"].value == "bool"
         assert annotations.parameters["b"].value == "bool"
@@ -577,3 +282,23 @@
         assert "d" not in annotations.parameters
         assert "e" not in annotations.parameters
 
+    @pytest.mark.filterwarnings("default:Unknown section:UserWarning:numpydoc")
+    def test_unknown_section_logged(self, caplog):
+        docstring = dedent(
+            """
+            Parameters
+            ----------
+            a : bool
+
+            To Do
+            -----
+            An unknown section
+            """
+        )
+        annotations = DocstringAnnotations(docstring)
+        assert len(annotations.parameters) == 1
+        assert annotations.parameters["a"].value == "bool"
+
+        assert caplog.messages == ["Warning in NumPyDoc while parsing docstring"]
+        assert caplog.records[0].details == "Unknown section To Do"
diff --git a/tests/test_doctype.py b/tests/test_doctype.py
new file mode 100644
index 0000000..3d4b74f
--- /dev/null
+++ b/tests/test_doctype.py
@@ -0,0 +1,304 @@
+import logging
+
+import lark
+import lark.exceptions
+import pytest
+
+from docstub._doctype import BLACKLISTED_QUALNAMES, parse_doctype
+
+
+class Test_parse_doctype:
+    @pytest.mark.parametrize(
+        "doctype",
+        [
+            "((float))",
+            "(float,)",
+            "(, )",
+            "...",
+            "(..., ...)",
+            "{}",
+            "{:}",
+            "{a:}",
+            "{:b}",
+            "{'a',}",
+            "a or (b or c)",
+            ",, optional",
+        ],
+    )
+    def test_edge_case_errors(self, doctype):
+        with pytest.raises(lark.exceptions.UnexpectedInput):
+            parse_doctype(doctype)
+
+    @pytest.mark.parametrize("doctype", BLACKLISTED_QUALNAMES)
+    def test_reserved_keywords(self, doctype):
+        with pytest.raises(lark.exceptions.VisitError):
+            parse_doctype(doctype)
+
+    @pytest.mark.parametrize(
+        ("doctype", "expected"),
+        [
+            ("int or float", "int | float"),
+            ("int or float or str", "int | float | str"),
+        ],
+    )
+    def test_natlang_union(self, doctype, expected):
+        expr = parse_doctype(doctype)
+        assert expr.as_code() == expected
+
+    @pytest.mark.parametrize(
+        ("doctype", "expected"),
+        [
+            # Conventional
+            ("list[float]", "list[float]"),
+            ("dict[str, Union[int, str]]", "dict[str, Union[int, str]]"),
+            ("tuple[int, ...]", "tuple[int, ...]"),
+            ("Sequence[int | float]", "Sequence[int | float]"),
+            # Natural language variant with "of" and optional plural "(s)"
+            ("list of int", "list[int]"),
+            ("list of int(s)", "list[int]"),
+            # Natural tuple variant
+            ("tuple of (float, int, str)", "tuple[float, int, str]"),
+            ("tuple of (float, ...)", "tuple[float, ...]"),
+            # Natural dict variant
+            ("dict of {str: int}", "dict[str, int]"),
+            ("dict of {str: int | float}", "dict[str, int | float]"),
+            ("dict of {str: int or float}", "dict[str, int | float]"),
+            ("dict[list of str]", "dict[list[str]]"),
+        ],
+    )
+    def test_subscription(self, doctype, expected):
+        expr = parse_doctype(doctype)
+        assert expr.as_code() == expected
+
+    @pytest.mark.parametrize(
+        ("doctype", "expected"),
+        [
+            # Natural language variant with "of" and optional plural "(s)"
+            ("list of int", "list[int]"),
+            ("list of int(s)", "list[int]"),
+            ("list of (int or float)", "list[int | float]"),
+            ("list of (list of int)", "list[list[int]]"),
+            # Natural tuple variant
+            ("tuple of (float, int, str)", "tuple[float, int, str]"),
+            ("tuple of (float, ...)", "tuple[float, ...]"),
+            # Natural dict variant
+            ("dict of {str: int}", "dict[str, int]"),
+            ("dict of {str: int | float}", "dict[str, int | float]"),
+            ("dict of {str: int or float}", "dict[str, int | float]"),
+            # Nesting is possible but probably rarely a good idea
+            ("list of (list of int(s))", "list[list[int]]"),
int(s))", "list[list[int]]"), + ("tuple of (tuple of (float, ...), ...)", "tuple[tuple[float, ...], ...]"), + ("dict of {str: dict of {str: float}}", "dict[str, dict[str, float]]"), + ("dict of {str: list of (list of int(s))}", "dict[str, list[list[int]]]"), + ], + ) + def test_natlang_container(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + @pytest.mark.parametrize( + "doctype", + [ + "list of int (s)", + "list of ((float))", + "list of (float,)", + "list of (, )", + "list of ...", + "list of (..., ...)", + "dict of {}", + "dict of {:}", + "dict of {a:}", + "dict of {:b}", + ], + ) + def test_subscription_error(self, doctype): + with pytest.raises(lark.exceptions.UnexpectedInput): + parse_doctype(doctype) + + @pytest.mark.parametrize( + ("doctype"), + [ + "Literal[0]", + "Literal[-1, 1]", + "Literal[None]", + "Literal[True, False]", + """Literal['a', "bar"]""", + # Enum + "Literal[SomeEnum.FIRST]", + "Literal[SomeEnum.FIRST, 1]", + "Literal[SomeEnum.FIRST, 2]", + "Literal[SomeEnum.FIRST, 3]", + # Nesting + "dict[Literal['a', 'b'], int]", + # Custom qualname for literal + "MyLiteral[0]", + "MyLiteral[SomeEnum.FIRST]", + ], + ) + def test_literals(self, doctype): + expr = parse_doctype(doctype) + assert expr.as_code() == doctype + assert "literal" in [e.rule for e in expr.sub_expressions] + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("{0}", "Literal[0]"), + ("{-1, 1}", "Literal[-1, 1]"), + ("{None}", "Literal[None]"), + ("{True, False}", "Literal[True, False]"), + ("""{'a', "bar"}""", """Literal['a', "bar"]"""), + # Enum + ("{SomeEnum.FIRST}", "Literal[SomeEnum.FIRST]"), + ("{`SomeEnum.FIRST`, 1}", "Literal[SomeEnum.FIRST, 1]"), + ("{:ref:`SomeEnum.FIRST`, 2}", "Literal[SomeEnum.FIRST, 2]"), + ("{:py:ref:`SomeEnum.FIRST`, 3}", "Literal[SomeEnum.FIRST, 3]"), + # Nesting + ("dict[{'a', 'b'}, int]", "dict[Literal['a', 'b'], int]"), + # These aren't officially valid as an argument to `Literal` (yet) + # https://typing.python.org/en/latest/spec/literal.html + # TODO figure out how docstub should deal with these + ("{-2., 1.}", "Literal[-2., 1.]"), + pytest.param( + "{-inf, inf, nan}", + "Literal[, 1.]", + marks=pytest.mark.xfail(reason="unsure how to support"), + ), + ], + ) + def test_natlang_literals(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + assert "natlang_literal" in [e.rule for e in expr.sub_expressions] + + def test_single_natlang_literal_warning(self, caplog): + expr = parse_doctype("{True}") + assert expr.as_code() == "Literal[True]" + assert caplog.messages == ["Natural language literal with one item: `{True}`"] + assert caplog.records[0].levelno == logging.WARNING + assert caplog.records[0].details == ( + "Consider using `%s` to improve readability", + "Literal[True]", + ) + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("int", "int"), + ("int | None", "int | None"), + ("tuple of (int, float)", "tuple[int, float]"), + ("{'a', 'b'}", "Literal['a', 'b']"), + ], + ) + @pytest.mark.parametrize( + "optional_info", + [ + "", + ", optional", + ", default -1", + ", default: -1", + ", default = 1", + ", in range (0, 1), optional", + ", optional, in range [0, 1]", + ", see parameter `image`, optional", + ], + ) + def test_optional_info(self, doctype, expected, optional_info): + doctype_with_optional = doctype + optional_info + expr = parse_doctype(doctype_with_optional) + assert expr.as_code() == expected + + @pytest.mark.parametrize( + "doctype", + [ + 
"Callable[[int], str]", + "some_func[[int], str]", + "Callable[[int, float, byte], list[str]]", + "Callable[..., str]", + "Callable[[], str]", + "Callback[...]", + ], + ) + def test_callable(self, doctype): + expr = parse_doctype(doctype) + assert expr.as_code() == doctype + assert "callable" in [e.rule for e in expr.sub_expressions] + + @pytest.mark.parametrize( + "doctype", + [ + "Callable[Concatenate[int, float], str]", + "Callable[Concatenate[int, ...], str]", + "Callable[P, str]", + ], + ) + def test_callable_subscriptions_form(self, doctype): + expr = parse_doctype(doctype) + assert expr.as_code() == doctype + assert "callable" not in [e.rule for e in expr.sub_expressions] + + @pytest.mark.parametrize( + "doctype", + [ + "Callable[[...], int]", + "Callable[[..., str], int]", + "Callable[[float, str], int, byte]", + ], + ) + def test_callable_error(self, doctype): + with pytest.raises(lark.exceptions.UnexpectedInput): + parse_doctype(doctype) + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("`Generator`", "Generator"), + (":class:`Generator`", "Generator"), + (":py:class:`Generator`", "Generator"), + (":py:class:`Generator`[int]", "Generator[int]"), + (":py:ref:`~.Foo`[int]", "~.Foo[int]"), + ("list[:py:class:`Generator`]", "list[Generator]"), + ], + ) + def test_rst_role(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + # fmt: off + @pytest.mark.parametrize( + ("fmt", "expected_fmt"), + [ + ("{name} of shape {shape} and dtype {dtype}", "{name}[{dtype}]"), + ("{name} of dtype {dtype} and shape {shape}", "{name}[{dtype}]"), + ], + ) + @pytest.mark.parametrize("name", ["array", "ndarray", "array-like", "array_like"]) + @pytest.mark.parametrize("dtype", ["int", "np.int8"]) + @pytest.mark.parametrize("shape", + ["(2, 3)", "(N, m)", "3D", "2-D", "(N, ...)", "([P,] M, N)"] + ) + def test_natlang_array(self, fmt, expected_fmt, name, dtype, shape): + doctype = fmt.format(name=name, dtype=dtype, shape=shape) + expected = expected_fmt.format(name=name, dtype=dtype, shape=shape) + expr = parse_doctype(doctype) + assert expr.as_code() == expected + assert "natlang_array" in [e.rule for e in expr.sub_expressions] + # fmt: on + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("ndarray of dtype (int or float)", "ndarray[int | float]"), + ("ndarray of shape (M, N)", "ndarray"), + ], + ) + def test_natlang_array_specific(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + assert "natlang_array" in [e.rule for e in expr.sub_expressions] + + @pytest.mark.parametrize("shape", ["(-1, 3)", "(1.0, 2)", "-3D", "-2-D"]) + def test_natlang_array_invalid_shape(self, shape): + doctype = f"array of shape {shape}" + with pytest.raises(lark.exceptions.UnexpectedInput): + _ = parse_doctype(doctype) diff --git a/tests/test_stubs.py b/tests/test_stubs.py index f0a18e1..ad4ead2 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -717,7 +717,7 @@ class Foo: c: str = None _: KW_ONLY d: dict[str, Any] = field(default_factory=dict) - e: InitVar[tuple] = tuple() + e: InitVar[tuple] = () f: ClassVar g: ClassVar[float] h: Final[ClassVar[int]] = 1