diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 39fc19e2..d5bedf51 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -272,3 +272,18 @@ jobs: - name: Check no test always skipped run: | python continuous_integration/check_no_test_skipped.py test_results + + stubtest: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.9" + + - name: Install dependencies + run: pip install mypy . + + - name: Run stubtest + run: stubtest threadpoolctl diff --git a/pyproject.toml b/pyproject.toml index c0ea1696..a759a473 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", ] [tool.black] diff --git a/threadpoolctl.py b/threadpoolctl/__init__.py similarity index 97% rename from threadpoolctl.py rename to threadpoolctl/__init__.py index e6ac58d8..f0fb82d9 100644 --- a/threadpoolctl.py +++ b/threadpoolctl/__init__.py @@ -1248,45 +1248,3 @@ def _get_windll(cls, dll_name): dll = ctypes.WinDLL(f"{dll_name}.dll") cls._system_libraries[dll_name] = dll return dll - - -def _main(): - """Commandline interface to display thread-pool information and exit.""" - import argparse - import importlib - import json - import sys - - parser = argparse.ArgumentParser( - usage="python -m threadpoolctl -i numpy scipy.linalg xgboost", - description="Display thread-pool information and exit.", - ) - parser.add_argument( - "-i", - "--import", - dest="modules", - nargs="*", - default=(), - help="Python modules to import before introspecting thread-pools.", - ) - parser.add_argument( - "-c", - "--command", - help="a Python statement to execute before introspecting thread-pools.", - ) - - options = parser.parse_args(sys.argv[1:]) - for module in options.modules: - try: - importlib.import_module(module, package=None) - except ImportError: - print("WARNING: could not import", module, file=sys.stderr) - - if options.command: - exec(options.command) - - print(json.dumps(threadpool_info(), indent=2)) - - -if __name__ == "__main__": - _main() diff --git a/threadpoolctl/__init__.pyi b/threadpoolctl/__init__.pyi new file mode 100644 index 00000000..d17561aa --- /dev/null +++ b/threadpoolctl/__init__.pyi @@ -0,0 +1,294 @@ +import abc +import ctypes +import sys +import types +from collections.abc import Iterable +from contextlib import ContextDecorator +from typing import Any, ClassVar, Final, Literal, final, type_check_only + +from typing_extensions import Never, Self, TypeAlias, TypedDict, TypeVar, override + +_CDataT = TypeVar("_CDataT", bound=ctypes._CData) # pyright: ignore[reportPrivateUsage] +_PythonT = TypeVar("_PythonT") # pyright: ignore[reportPrivateUsage] +_CField: TypeAlias = ctypes._CField[ # pyright: ignore[reportPrivateUsage] + _CDataT, _PythonT, _CDataT | _PythonT +] + +_ThreadingLayerOpenBLAS: TypeAlias = _ThreadingLayerBLIS | Literal["unknown"] +_ThreadingLayerBLIS: TypeAlias = Literal["openmp", "pthreads", "disabled"] +_ThreadingLayerMKL: TypeAlias = Literal[ + "intel", "sequential", "pgi", "gnu", "tbb", "not specified" +] + +_ToLimits: TypeAlias = ( + Literal["sequential_blas_under_openmp"] + | int + | list[_InfoDict] + | dict[str, Any] + | ThreadpoolController +) + +@type_check_only +class _InfoDict(TypedDict): + user_api: str + internal_api: str + num_threads: int | None + prefix: str + filepath: str + version: str | None + +@final +@type_check_only +class _OMPBlasDict(TypedDict): + limits: Literal[1] | None + user_api: Literal["blas"] | None + +### + +__version__: Final[str] = ... +__all__ = [ + "LibController", + "ThreadpoolController", + "register", + "threadpool_info", + "threadpool_limits", +] + +_SYSTEM_UINT: Final[type[ctypes.c_size_t]] = ... +_SYSTEM_UINT_HALF: Final[type[ctypes.c_uint16 | ctypes.c_uint32]] = ... +_RTLD_NOLOAD: Final[int] = ... + +@final +class _dl_phdr_info(ctypes.Structure): + _fields_: ClassVar = ... + + dlpi_addr: _CField[ctypes.c_size_t, int] + dlpi_name: _CField[ctypes.c_char_p, bytes | None] + dlpi_phdr: _CField[ctypes.c_void_p, int | None] + dlpi_phnum: _CField[ctypes.c_uint16 | ctypes.c_uint32, int] + +class LibController(abc.ABC): + user_api: ClassVar[str] # abstract + internal_api: ClassVar[str] # abstract + filename_prefixes: ClassVar[tuple[str, ...]] # abstract + + parent: Final[LibController | None] + prefix: Final[str | None] + filepath: Final[str | None] + dynlib: Final[ctypes.CDLL] + _symbol_prefix: Final[str] + _symbol_affix: Final[str] + version: Final[str | None] + + @final + def __init__( + self, + /, + *, + filepath: str | None = None, + prefix: str | None = None, + parent: LibController | None = None, + ) -> None: ... + def info(self) -> dict[str, Any]: ... + def set_additional_attributes(self) -> None: ... + @property + def num_threads(self) -> int | None: ... + @abc.abstractmethod + def get_num_threads(self) -> int | None: ... + @abc.abstractmethod + def set_num_threads(self, /, num_threads: int) -> None: ... + @abc.abstractmethod + def get_version(self) -> str | None: ... + def _find_affixes(self) -> tuple[str, str]: ... + def _get_symbol(self, /, name: str) -> Any | None: ... + +class OpenBLASController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "openblas" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + _symbol_prefixes: ClassVar[tuple[str, ...]] = ... + _symbol_suffixes: ClassVar[tuple[str, ...]] = ... + + threading_layer: Final[_ThreadingLayerOpenBLAS] + architecture: Final[str] + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_threading_layer(self) -> _ThreadingLayerOpenBLAS: ... + def _get_architecture(self) -> str | None: ... + +class BLISController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "blis" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + threading_layer: Final[_ThreadingLayerBLIS] + architecture: Final[str] + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_threading_layer(self) -> _ThreadingLayerBLIS: ... + def _get_architecture(self) -> str | None: ... + +class FlexiBLASController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "flexiblas" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + available_backends: Final[list[str]] + + @property + def loaded_backends(self) -> list[str]: ... + @property + def current_backend(self) -> str: ... + def set_additional_attributes(self) -> None: ... + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_backend_list(self, /, loaded: bool = False) -> list[str]: ... + def _get_current_backend(self) -> str: ... + def switch_backend(self, /, backend: str) -> None: ... + +class MKLController(LibController): + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "mkl" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + threading_layer: Final[_ThreadingLayerMKL] = ... + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> str | None: ... + def _get_threading_layer(self) -> _ThreadingLayerMKL: ... + +class OpenMPController(LibController): + user_api: ClassVar[str] = "openmp" + internal_api: ClassVar[str] = "openmp" + filename_prefixes: ClassVar[tuple[str, ...]] = ... + check_symbols: ClassVar[tuple[str, ...]] = ... + + @override + def get_num_threads(self) -> int | None: ... + @override + def set_num_threads(self, /, num_threads: int) -> None: ... + @override + def get_version(self) -> None: ... + +_ALL_CONTROLLERS: Final[list[LibController]] = ... +_ALL_USER_APIS: Final[list[str]] = ... +_ALL_INTERNAL_APIS: Final[list[str]] = ... +_ALL_BLAS_LIBRARIES: Final[list[str]] = ... +_ALL_OPENMP_LIBRARIES: Final[tuple[str, ...]] = ... + +def register(controller: LibController) -> None: ... +def threadpool_info() -> _InfoDict: ... + +class _ThreadpoolLimiter: + _controller: Final[LibController] + _limits: Final[dict[str, int | None]] + _user_api: Final[list[str]] + _prefixes: Final[list[str]] + _original_info: Final[_InfoDict] + + def __init__( + self, + /, + controller: LibController, + *, + limits: _ToLimits | None = None, + user_api: list[str] | None = None, + ) -> None: ... + def __enter__(self) -> Self: ... + def __exit__( + self, + /, + type: type[BaseException] | None, # noqa: A002 + value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: ... + @classmethod + def wrap( + cls, + controller: LibController, + *, + limits: _ToLimits | None = None, + user_api: list[str] | None = None, + ) -> Self: ... + def get_original_num_threads(self) -> dict[str, int | None]: ... + def restore_original_limits(self) -> None: ... + unregister = restore_original_limits + + def _check_params( + self, /, limits: _ToLimits | None, user_api: list[str] + ) -> tuple[dict[str, int | None], list[str], list[str]]: ... + def _set_threadpool_limits(self) -> None: ... + +class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator): ... + +class threadpool_limits(_ThreadpoolLimiter): + def __init__( + self, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> None: ... + @classmethod + def wrap( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + cls, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> Self: ... + +class ThreadpoolController: + _system_libraries: ClassVar[dict[str, ctypes.CDLL]] = ... + + lib_controllers: Final[list[LibController]] + + def __init__(self) -> None: ... + @classmethod + def _from_controllers(cls, lib_controllers: list[LibController]) -> Self: ... + def info(self) -> list[_InfoDict]: ... + def select(self, /, **kwargs: object) -> Self: ... + def _get_params_for_sequential_blas_under_openmp(self) -> _OMPBlasDict: ... + def limit( + self, /, *, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> _ThreadpoolLimiter: ... + def wrap( + self, /, *, limits: _ToLimits | None = None, user_api: list[str] | None = None + ) -> _ThreadpoolLimiter: ... + def __len__(self) -> int: ... + def _load_libraries(self) -> None: ... + def _find_libraries_with_dl_iterate_phdr(self) -> list[Never] | int | None: ... + def _find_libraries_with_dyld(self) -> list[Never] | None: ... + def _find_libraries_with_enum_process_module_ex(self) -> None: ... + def _find_libraries_pyodide(self) -> None: ... + def _make_controller_from_path(self, /, filepath: str) -> None: ... + def _check_prefix( + self, /, library_basename: str, filename_prefixes: Iterable[str] + ) -> str | None: ... + def _warn_if_incompatible_openmp(self) -> None: ... + @classmethod + def _get_libc(cls) -> ctypes.CDLL: ... + + if sys.platform == "win32": + @classmethod + def _get_windll( # pyright: ignore[reportRedeclaration] + cls, dll_name: str + ) -> ctypes.WinDLL: ... + + else: + @classmethod + def _get_windll(cls, dll_name: str) -> ctypes.CDLL: ... diff --git a/threadpoolctl/__main__.py b/threadpoolctl/__main__.py new file mode 100644 index 00000000..09abef99 --- /dev/null +++ b/threadpoolctl/__main__.py @@ -0,0 +1,46 @@ +"""Commandline interface to display thread-pool information and exit.""" + +__all__ = () + + +def _main() -> None: + import argparse + import importlib + import json + import sys + + from threadpoolctl import threadpool_info + + parser = argparse.ArgumentParser( + usage="python -m threadpoolctl -i numpy scipy.linalg xgboost", + description="Display thread-pool information and exit.", + ) + parser.add_argument( + "-i", + "--import", + dest="modules", + nargs="*", + default=(), + help="Python modules to import before introspecting thread-pools.", + ) + parser.add_argument( + "-c", + "--command", + help="a Python statement to execute before introspecting thread-pools.", + ) + options = parser.parse_args(sys.argv[1:]) + + for module in options.modules: + try: + importlib.import_module(module, package=None) + except ImportError: + print("WARNING: could not import", module, file=sys.stderr) + + if options.command: + exec(options.command) + + print(json.dumps(threadpool_info(), indent=2)) + + +if __name__ == "__main__": + _main() diff --git a/threadpoolctl/py.typed b/threadpoolctl/py.typed new file mode 100644 index 00000000..e69de29b