From aa73c38abf591f96e212b656495024277e63ba5c Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 11 Jun 2025 18:28:16 +0200 Subject: [PATCH 1/7] typing stubs --- threadpoolctl.pyi | 286 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 threadpoolctl.pyi diff --git a/threadpoolctl.pyi b/threadpoolctl.pyi new file mode 100644 index 00000000..e6ab4441 --- /dev/null +++ b/threadpoolctl.pyi @@ -0,0 +1,286 @@ +import abc +import ctypes +import sys +import types +from collections.abc import Iterable +from contextlib import ContextDecorator +from typing import Any, ClassVar, Final, Literal, final, type_check_only + +from _typeshed import Incomplete +from typing_extensions import Never, Self, TypeAlias, TypedDict, override + +_ThreadingLayerOpenBLAS: TypeAlias = _ThreadingLayerBLIS | Literal["unknown"] +_ThreadingLayerBLIS: TypeAlias = Literal["openmp", "pthreads", "disabled"] +_ThreadingLayerMKL: TypeAlias = Literal[ + "intel", "sequential", "pgi", "gnu", "tbb", "not specified" +] + +_ToLimits: TypeAlias = ( + Literal["sequential_blas_under_openmp"] + | int + | list[_InfoDict] + | dict[str, Any] + | ThreadpoolController +) + +@type_check_only +class _InfoDict(TypedDict): + user_api: str + internal_api: str + num_threads: int | None + prefix: str + filepath: str + version: str | None + +@final +@type_check_only +class _OpenMPSequentialBlasDict(TypedDict): + limits: Literal[1] | None + user_api: Literal["blas"] | None + +### + +__version__: Final[str] = ... +__all__ = [ + "LibController", + "ThreadpoolController", + "register", + "threadpool_info", + "threadpool_limits", +] + +_SYSTEM_UINT: Final[type[ctypes.c_uint32 | ctypes.c_uint64]] = ... +_SYSTEM_UINT_HALF: Final[type[ctypes.c_uint16 | ctypes.c_uint32]] = ... +_RTLD_NOLOAD: Final[int] = ... + +@final +class _dl_phdr_info(ctypes.Structure): + _fields_: ClassVar = ... + +class LibController(abc.ABC): + user_api: ClassVar[str] # abstract + internal_api: ClassVar[str] # abstract + filename_prefixes: ClassVar[tuple[str, ...]] # abstract + + parent: Final[LibController | None] + prefix: Final[str | None] + filepath: Final[str | None] + dynlib: Final[ctypes.CDLL] + _symbol_prefix: Final[str] + _symbol_affix: Final[str] + version: Final[str | None] + + @final + def __init__( + self, + /, + *, + filepath: str | None = None, + prefix: str | None = None, + parent: LibController | None = None, + ) -> None: ... + def info(self) -> dict[str, Any]: ... + def set_additional_attributes(self) -> None: ... + @property + def num_threads(self) -> int | None: ... + @abc.abstractmethod + def get_num_threads(self) -> int | None: ... + @abc.abstractmethod + def set_num_threads(self, /, num_threads: int) -> None: ... + @abc.abstractmethod + def get_version(self) -> str | None: ... + def _find_affixes(self) -> tuple[str, str]: ... + def _get_symbol(self, /, name: str) -> Incomplete | None: ... + +class OpenBLASController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "openblas" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + _symbol_prefixes: ClassVar[tuple[str, ...]] = ... + _symbol_suffixes: ClassVar[tuple[str, ...]] = ... + + threading_layer: Final[_ThreadingLayerOpenBLAS] + architecture: Final[str] + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_threading_layer(self) -> _ThreadingLayerOpenBLAS: ... + def _get_architecture(self) -> str | None: ... + +class BLISController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "blis" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + threading_layer: Final[_ThreadingLayerBLIS] + architecture: Final[str] + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_threading_layer(self) -> _ThreadingLayerBLIS: ... + def _get_architecture(self) -> str | None: ... + +class FlexiBLASController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "flexiblas" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + available_backends: Final[list[str]] + + @property + def loaded_backends(self) -> list[str]: ... + @property + def current_backend(self) -> str: ... + def set_additional_attributes(self) -> None: ... + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_backend_list(self, loaded: bool = False) -> list[str]: ... + def _get_current_backend(self) -> str: ... + def switch_backend(self, /, backend: str) -> None: ... + +class MKLController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "mkl" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + threading_layer: Final[_ThreadingLayerMKL] = ... + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_threading_layer(self) -> _ThreadingLayerMKL: ... + +class OpenMPController(LibController): + user_api: ClassVar[str] = "openmp" + internal_api: ClassVar[str] = "openmp" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> None: ... + +_ALL_CONTROLLERS: Final[list[LibController]] = ... +_ALL_USER_APIS: Final[list[str]] = ... +_ALL_INTERNAL_APIS: Final[list[str]] = ... +_ALL_BLAS_LIBRARIES: Final[list[str]] = ... +_ALL_OPENMP_LIBRARIES: Final[tuple[str, ...]] = ... + +def register(controller: LibController) -> None: ... +def threadpool_info() -> _InfoDict: ... + +class _ThreadpoolLimiter: + _controller: Final[LibController] + _limits: Final[dict[str, int | None]] + _user_api: Final[list[str]] + _prefixes: Final[list[str]] + _original_info: Final[_InfoDict] + + def __init__( + self, + /, + controller: LibController, + *, + limits: _ToLimits | None = None, + user_api: list[str] | None = None, + ) -> None: ... + def __enter__(self) -> Self: ... + def __exit__( + self, + /, + type: type[BaseException] | None, # noqa: A002 + value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: ... + @classmethod + def wrap( + cls, + controller: LibController, + *, + limits: _ToLimits | None = None, + user_api: list[str] | None = None, + ) -> Self: ... + def get_original_num_threads(self) -> dict[str, int | None]: ... + def restore_original_limits(self) -> None: ... + unregister = restore_original_limits + + def _check_params( + self, /, limits: _ToLimits | None, user_api: list[str] + ) -> tuple[dict[str, int | None], list[str], list[str]]: ... + def _set_threadpool_limits(self) -> None: ... + +class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator): ... + +class threadpool_limits(_ThreadpoolLimiter): + def __init__( + self, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> None: ... + @classmethod + def wrap( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + cls, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> Self: ... + +class ThreadpoolController: + _system_libraries: ClassVar[dict[str, ctypes.CDLL]] = ... + + lib_controllers: Final[list[LibController]] + + def __init__(self) -> None: ... + @classmethod + def _from_controllers(cls, lib_controllers: list[LibController]) -> Self: ... + def info(self) -> list[_InfoDict]: ... + def select(self, **kwargs: Incomplete) -> Self: ... + def _get_params_for_sequential_blas_under_openmp( + self, + ) -> _OpenMPSequentialBlasDict: ... + def limit( + self, /, *, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> _ThreadpoolLimiter: ... + def wrap( + self, /, *, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> _ThreadpoolLimiter: ... + def __len__(self) -> int: ... + def _load_libraries(self) -> None: ... + def _find_libraries_with_dl_iterate_phdr( + self, + ) -> list[Never] | Literal[0] | None: ... + def _find_libraries_with_dyld(self) -> list[Never] | None: ... + def _find_libraries_with_enum_process_module_ex(self) -> None: ... + def _find_libraries_pyodide(self) -> None: ... + def _make_controller_from_path(self, /, filepath: str) -> None: ... + def _check_prefix( + self, /, library_basename: str, filename_prefixes: Iterable[str] + ) -> str | None: ... + def _warn_if_incompatible_openmp(self) -> None: ... + @classmethod + def _get_libc(cls) -> ctypes.CDLL: ... + + if sys.platform == "win32": + @classmethod + def _get_windll(cls, dll_name: str) -> ctypes.WinDLL: ... # pyright: ignore[reportRedeclaration] + + else: + @classmethod + def _get_windll(cls, dll_name: str) -> ctypes.CDLL: ... From 4c54d8e140065d878e1b82f560292d746eda2a2c Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 11 Jun 2025 18:35:09 +0200 Subject: [PATCH 2/7] turn `threadpoolctl` into a package this is needed for static typing support --- threadpoolctl.py => threadpoolctl/__init__.py | 0 threadpoolctl.pyi => threadpoolctl/__init__.pyi | 0 threadpoolctl/py.typed | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename threadpoolctl.py => threadpoolctl/__init__.py (100%) rename threadpoolctl.pyi => threadpoolctl/__init__.pyi (100%) create mode 100644 threadpoolctl/py.typed diff --git a/threadpoolctl.py b/threadpoolctl/__init__.py similarity index 100% rename from threadpoolctl.py rename to threadpoolctl/__init__.py diff --git a/threadpoolctl.pyi b/threadpoolctl/__init__.pyi similarity index 100% rename from threadpoolctl.pyi rename to threadpoolctl/__init__.pyi diff --git a/threadpoolctl/py.typed b/threadpoolctl/py.typed new file mode 100644 index 00000000..e69de29b From 3be71d73f9d8fcd68809fc379afded6217b45d65 Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 11 Jun 2025 18:36:20 +0200 Subject: [PATCH 3/7] add `Typing :: Typed` classifier trove --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c0ea1696..a759a473 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", ] [tool.black] From 4367362626e517333c97c77f2c4441f1715f8551 Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 11 Jun 2025 18:58:08 +0200 Subject: [PATCH 4/7] fix stubtest errors --- threadpoolctl/__init__.pyi | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/threadpoolctl/__init__.pyi b/threadpoolctl/__init__.pyi index e6ab4441..893d8c52 100644 --- a/threadpoolctl/__init__.pyi +++ b/threadpoolctl/__init__.pyi @@ -6,8 +6,11 @@ from collections.abc import Iterable from contextlib import ContextDecorator from typing import Any, ClassVar, Final, Literal, final, type_check_only -from _typeshed import Incomplete -from typing_extensions import Never, Self, TypeAlias, TypedDict, override +from typing_extensions import Never, Self, TypeAlias, TypedDict, TypeVar, override + +_CDataT = TypeVar("_CDataT", bound=ctypes._CData) # pyright: ignore[reportPrivateUsage] +_PythonT = TypeVar("_PythonT") # pyright: ignore[reportPrivateUsage] +_CField: TypeAlias = ctypes._CField[_CDataT, _PythonT, _CDataT | _PythonT] # pyright: ignore[reportPrivateUsage] _ThreadingLayerOpenBLAS: TypeAlias = _ThreadingLayerBLIS | Literal["unknown"] _ThreadingLayerBLIS: TypeAlias = Literal["openmp", "pthreads", "disabled"] @@ -34,7 +37,7 @@ class _InfoDict(TypedDict): @final @type_check_only -class _OpenMPSequentialBlasDict(TypedDict): +class _OMPBlasDict(TypedDict): limits: Literal[1] | None user_api: Literal["blas"] | None @@ -49,7 +52,7 @@ __all__ = [ "threadpool_limits", ] -_SYSTEM_UINT: Final[type[ctypes.c_uint32 | ctypes.c_uint64]] = ... +_SYSTEM_UINT: Final[type[ctypes.c_size_t]] = ... _SYSTEM_UINT_HALF: Final[type[ctypes.c_uint16 | ctypes.c_uint32]] = ... _RTLD_NOLOAD: Final[int] = ... @@ -57,6 +60,11 @@ _RTLD_NOLOAD: Final[int] = ... class _dl_phdr_info(ctypes.Structure): _fields_: ClassVar = ... + dlpi_addr: _CField[ctypes.c_size_t, int] + dlpi_name: _CField[ctypes.c_char_p, bytes | None] + dlpi_phdr: _CField[ctypes.c_void_p, int | None] + dlpi_phnum: _CField[ctypes.c_uint16 | ctypes.c_uint32, int] + class LibController(abc.ABC): user_api: ClassVar[str] # abstract internal_api: ClassVar[str] # abstract @@ -90,7 +98,7 @@ class LibController(abc.ABC): @abc.abstractmethod def get_version(self) -> str | None: ... def _find_affixes(self) -> tuple[str, str]: ... - def _get_symbol(self, /, name: str) -> Incomplete | None: ... + def _get_symbol(self, /, name: str) -> Any | None: ... class OpenBLASController(LibController): user_api: ClassVar[str] = "blas" @@ -149,7 +157,7 @@ class FlexiBLASController(LibController): def set_num_threads(self, /, num_threads: int) -> None: ... @override def get_version(self) -> str | None: ... - def _get_backend_list(self, loaded: bool = False) -> list[str]: ... + def _get_backend_list(self, /, loaded: bool = False) -> list[str]: ... def _get_current_backend(self) -> str: ... def switch_backend(self, /, backend: str) -> None: ... @@ -251,10 +259,8 @@ class ThreadpoolController: @classmethod def _from_controllers(cls, lib_controllers: list[LibController]) -> Self: ... def info(self) -> list[_InfoDict]: ... - def select(self, **kwargs: Incomplete) -> Self: ... - def _get_params_for_sequential_blas_under_openmp( - self, - ) -> _OpenMPSequentialBlasDict: ... + def select(self, /, **kwargs: object) -> Self: ... + def _get_params_for_sequential_blas_under_openmp(self) -> _OMPBlasDict: ... def limit( self, /, *, limits: _ToLimits | None = None, user_api: list[str] | None = None ) -> _ThreadpoolLimiter: ... @@ -263,9 +269,7 @@ class ThreadpoolController: ) -> _ThreadpoolLimiter: ... def __len__(self) -> int: ... def _load_libraries(self) -> None: ... - def _find_libraries_with_dl_iterate_phdr( - self, - ) -> list[Never] | Literal[0] | None: ... + def _find_libraries_with_dl_iterate_phdr(self) -> list[Never] | int | None: ... def _find_libraries_with_dyld(self) -> list[Never] | None: ... def _find_libraries_with_enum_process_module_ex(self) -> None: ... def _find_libraries_pyodide(self) -> None: ... From 11f53c61be726fb2140e12c10368a60be8b7f9dd Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 11 Jun 2025 19:02:45 +0200 Subject: [PATCH 5/7] run stubtest in CI --- .github/workflows/test.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 39fc19e2..d5bedf51 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -272,3 +272,18 @@ jobs: - name: Check no test always skipped run: | python continuous_integration/check_no_test_skipped.py test_results + + stubtest: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.9" + + - name: Install dependencies + run: pip install mypy . + + - name: Run stubtest + run: stubtest threadpoolctl From d46b9902ac8bff564b84c31b05ba082c4ee707c1 Mon Sep 17 00:00:00 2001 From: jorenham Date: Fri, 13 Jun 2025 11:33:53 +0200 Subject: [PATCH 6/7] appease `black` --- threadpoolctl/__init__.pyi | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/threadpoolctl/__init__.pyi b/threadpoolctl/__init__.pyi index 893d8c52..d17561aa 100644 --- a/threadpoolctl/__init__.pyi +++ b/threadpoolctl/__init__.pyi @@ -10,7 +10,9 @@ from typing_extensions import Never, Self, TypeAlias, TypedDict, TypeVar, overri _CDataT = TypeVar("_CDataT", bound=ctypes._CData) # pyright: ignore[reportPrivateUsage] _PythonT = TypeVar("_PythonT") # pyright: ignore[reportPrivateUsage] -_CField: TypeAlias = ctypes._CField[_CDataT, _PythonT, _CDataT | _PythonT] # pyright: ignore[reportPrivateUsage] +_CField: TypeAlias = ctypes._CField[ # pyright: ignore[reportPrivateUsage] + _CDataT, _PythonT, _CDataT | _PythonT +] _ThreadingLayerOpenBLAS: TypeAlias = _ThreadingLayerBLIS | Literal["unknown"] _ThreadingLayerBLIS: TypeAlias = Literal["openmp", "pthreads", "disabled"] @@ -283,7 +285,9 @@ class ThreadpoolController: if sys.platform == "win32": @classmethod - def _get_windll(cls, dll_name: str) -> ctypes.WinDLL: ... # pyright: ignore[reportRedeclaration] + def _get_windll( # pyright: ignore[reportRedeclaration] + cls, dll_name: str + ) -> ctypes.WinDLL: ... else: @classmethod From 2bfc9c2cc86946ec487fbb293acd3094b280fd60 Mon Sep 17 00:00:00 2001 From: jorenham Date: Sat, 21 Jun 2025 01:43:52 +0200 Subject: [PATCH 7/7] move CLI logic into ``__main__`` --- threadpoolctl/__init__.py | 42 ----------------------------------- threadpoolctl/__main__.py | 46 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 42 deletions(-) create mode 100644 threadpoolctl/__main__.py diff --git a/threadpoolctl/__init__.py b/threadpoolctl/__init__.py index e6ac58d8..f0fb82d9 100644 --- a/threadpoolctl/__init__.py +++ b/threadpoolctl/__init__.py @@ -1248,45 +1248,3 @@ def _get_windll(cls, dll_name): dll = ctypes.WinDLL(f"{dll_name}.dll") cls._system_libraries[dll_name] = dll return dll - - -def _main(): - """Commandline interface to display thread-pool information and exit.""" - import argparse - import importlib - import json - import sys - - parser = argparse.ArgumentParser( - usage="python -m threadpoolctl -i numpy scipy.linalg xgboost", - description="Display thread-pool information and exit.", - ) - parser.add_argument( - "-i", - "--import", - dest="modules", - nargs="*", - default=(), - help="Python modules to import before introspecting thread-pools.", - ) - parser.add_argument( - "-c", - "--command", - help="a Python statement to execute before introspecting thread-pools.", - ) - - options = parser.parse_args(sys.argv[1:]) - for module in options.modules: - try: - importlib.import_module(module, package=None) - except ImportError: - print("WARNING: could not import", module, file=sys.stderr) - - if options.command: - exec(options.command) - - print(json.dumps(threadpool_info(), indent=2)) - - -if __name__ == "__main__": - _main() diff --git a/threadpoolctl/__main__.py b/threadpoolctl/__main__.py new file mode 100644 index 00000000..09abef99 --- /dev/null +++ b/threadpoolctl/__main__.py @@ -0,0 +1,46 @@ +"""Commandline interface to display thread-pool information and exit.""" + +__all__ = () + + +def _main() -> None: + import argparse + import importlib + import json + import sys + + from threadpoolctl import threadpool_info + + parser = argparse.ArgumentParser( + usage="python -m threadpoolctl -i numpy scipy.linalg xgboost", + description="Display thread-pool information and exit.", + ) + parser.add_argument( + "-i", + "--import", + dest="modules", + nargs="*", + default=(), + help="Python modules to import before introspecting thread-pools.", + ) + parser.add_argument( + "-c", + "--command", + help="a Python statement to execute before introspecting thread-pools.", + ) + options = parser.parse_args(sys.argv[1:]) + + for module in options.modules: + try: + importlib.import_module(module, package=None) + except ImportError: + print("WARNING: could not import", module, file=sys.stderr) + + if options.command: + exec(options.command) + + print(json.dumps(threadpool_info(), indent=2)) + + +if __name__ == "__main__": + _main()