Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,18 @@ jobs:
- name: Check no test always skipped
run: |
python continuous_integration/check_no_test_skipped.py test_results

stubtest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.9"

- name: Install dependencies
run: pip install mypy .

- name: Run stubtest
run: stubtest threadpoolctl
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
]

[tool.black]
Expand Down
42 changes: 0 additions & 42 deletions threadpoolctl.py → threadpoolctl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,45 +1248,3 @@ def _get_windll(cls, dll_name):
dll = ctypes.WinDLL(f"{dll_name}.dll")
cls._system_libraries[dll_name] = dll
return dll


def _main():
"""Commandline interface to display thread-pool information and exit."""
import argparse
import importlib
import json
import sys

parser = argparse.ArgumentParser(
usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
description="Display thread-pool information and exit.",
)
parser.add_argument(
"-i",
"--import",
dest="modules",
nargs="*",
default=(),
help="Python modules to import before introspecting thread-pools.",
)
parser.add_argument(
"-c",
"--command",
help="a Python statement to execute before introspecting thread-pools.",
)

options = parser.parse_args(sys.argv[1:])
for module in options.modules:
try:
importlib.import_module(module, package=None)
except ImportError:
print("WARNING: could not import", module, file=sys.stderr)

if options.command:
exec(options.command)

print(json.dumps(threadpool_info(), indent=2))


if __name__ == "__main__":
_main()
294 changes: 294 additions & 0 deletions threadpoolctl/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
import abc
import ctypes
import sys
import types
from collections.abc import Iterable
from contextlib import ContextDecorator
from typing import Any, ClassVar, Final, Literal, final, type_check_only

from typing_extensions import Never, Self, TypeAlias, TypedDict, TypeVar, override

_CDataT = TypeVar("_CDataT", bound=ctypes._CData) # pyright: ignore[reportPrivateUsage]
_PythonT = TypeVar("_PythonT") # pyright: ignore[reportPrivateUsage]
_CField: TypeAlias = ctypes._CField[ # pyright: ignore[reportPrivateUsage]
_CDataT, _PythonT, _CDataT | _PythonT
]

_ThreadingLayerOpenBLAS: TypeAlias = _ThreadingLayerBLIS | Literal["unknown"]
_ThreadingLayerBLIS: TypeAlias = Literal["openmp", "pthreads", "disabled"]
_ThreadingLayerMKL: TypeAlias = Literal[
"intel", "sequential", "pgi", "gnu", "tbb", "not specified"
]

_ToLimits: TypeAlias = (
Literal["sequential_blas_under_openmp"]
| int
| list[_InfoDict]
| dict[str, Any]
| ThreadpoolController
)

@type_check_only
class _InfoDict(TypedDict):
user_api: str
internal_api: str
num_threads: int | None
prefix: str
filepath: str
version: str | None

@final
@type_check_only
class _OMPBlasDict(TypedDict):
limits: Literal[1] | None
user_api: Literal["blas"] | None

###

__version__: Final[str] = ...
__all__ = [
"LibController",
"ThreadpoolController",
"register",
"threadpool_info",
"threadpool_limits",
]

_SYSTEM_UINT: Final[type[ctypes.c_size_t]] = ...
_SYSTEM_UINT_HALF: Final[type[ctypes.c_uint16 | ctypes.c_uint32]] = ...
_RTLD_NOLOAD: Final[int] = ...

@final
class _dl_phdr_info(ctypes.Structure):
_fields_: ClassVar = ...

dlpi_addr: _CField[ctypes.c_size_t, int]
dlpi_name: _CField[ctypes.c_char_p, bytes | None]
dlpi_phdr: _CField[ctypes.c_void_p, int | None]
dlpi_phnum: _CField[ctypes.c_uint16 | ctypes.c_uint32, int]

class LibController(abc.ABC):
user_api: ClassVar[str] # abstract
internal_api: ClassVar[str] # abstract
filename_prefixes: ClassVar[tuple[str, ...]] # abstract

parent: Final[LibController | None]
prefix: Final[str | None]
filepath: Final[str | None]
dynlib: Final[ctypes.CDLL]
_symbol_prefix: Final[str]
_symbol_affix: Final[str]
version: Final[str | None]

@final
def __init__(
self,
/,
*,
filepath: str | None = None,
prefix: str | None = None,
parent: LibController | None = None,
) -> None: ...
def info(self) -> dict[str, Any]: ...
def set_additional_attributes(self) -> None: ...
@property
def num_threads(self) -> int | None: ...
@abc.abstractmethod
def get_num_threads(self) -> int | None: ...
@abc.abstractmethod
def set_num_threads(self, /, num_threads: int) -> None: ...
@abc.abstractmethod
def get_version(self) -> str | None: ...
def _find_affixes(self) -> tuple[str, str]: ...
def _get_symbol(self, /, name: str) -> Any | None: ...

class OpenBLASController(LibController):
user_api: ClassVar[str] = "blas"
internal_api: ClassVar[str] = "openblas"
filename_prefixes: ClassVar[tuple[str, ...]] = ...
check_symbols: ClassVar[tuple[str, ...]] = ...
_symbol_prefixes: ClassVar[tuple[str, ...]] = ...
_symbol_suffixes: ClassVar[tuple[str, ...]] = ...

threading_layer: Final[_ThreadingLayerOpenBLAS]
architecture: Final[str]

@override
def get_num_threads(self) -> int | None: ...
@override
def set_num_threads(self, /, num_threads: int) -> None: ...
@override
def get_version(self) -> str | None: ...
def _get_threading_layer(self) -> _ThreadingLayerOpenBLAS: ...
def _get_architecture(self) -> str | None: ...

class BLISController(LibController):
user_api: ClassVar[str] = "blas"
internal_api: ClassVar[str] = "blis"
filename_prefixes: ClassVar[tuple[str, ...]] = ...
check_symbols: ClassVar[tuple[str, ...]] = ...

threading_layer: Final[_ThreadingLayerBLIS]
architecture: Final[str]

@override
def get_num_threads(self) -> int | None: ...
@override
def set_num_threads(self, /, num_threads: int) -> None: ...
@override
def get_version(self) -> str | None: ...
def _get_threading_layer(self) -> _ThreadingLayerBLIS: ...
def _get_architecture(self) -> str | None: ...

class FlexiBLASController(LibController):
user_api: ClassVar[str] = "blas"
internal_api: ClassVar[str] = "flexiblas"
filename_prefixes: ClassVar[tuple[str, ...]] = ...
check_symbols: ClassVar[tuple[str, ...]] = ...

available_backends: Final[list[str]]

@property
def loaded_backends(self) -> list[str]: ...
@property
def current_backend(self) -> str: ...
def set_additional_attributes(self) -> None: ...
@override
def get_num_threads(self) -> int | None: ...
@override
def set_num_threads(self, /, num_threads: int) -> None: ...
@override
def get_version(self) -> str | None: ...
def _get_backend_list(self, /, loaded: bool = False) -> list[str]: ...
def _get_current_backend(self) -> str: ...
def switch_backend(self, /, backend: str) -> None: ...

class MKLController(LibController):
user_api: ClassVar[str] = "blas"
internal_api: ClassVar[str] = "mkl"
filename_prefixes: ClassVar[tuple[str, ...]] = ...
check_symbols: ClassVar[tuple[str, ...]] = ...

threading_layer: Final[_ThreadingLayerMKL] = ...

@override
def get_num_threads(self) -> int | None: ...
@override
def set_num_threads(self, /, num_threads: int) -> None: ...
@override
def get_version(self) -> str | None: ...
def _get_threading_layer(self) -> _ThreadingLayerMKL: ...

class OpenMPController(LibController):
user_api: ClassVar[str] = "openmp"
internal_api: ClassVar[str] = "openmp"
filename_prefixes: ClassVar[tuple[str, ...]] = ...
check_symbols: ClassVar[tuple[str, ...]] = ...

@override
def get_num_threads(self) -> int | None: ...
@override
def set_num_threads(self, /, num_threads: int) -> None: ...
@override
def get_version(self) -> None: ...

_ALL_CONTROLLERS: Final[list[LibController]] = ...
_ALL_USER_APIS: Final[list[str]] = ...
_ALL_INTERNAL_APIS: Final[list[str]] = ...
_ALL_BLAS_LIBRARIES: Final[list[str]] = ...
_ALL_OPENMP_LIBRARIES: Final[tuple[str, ...]] = ...

def register(controller: LibController) -> None: ...
def threadpool_info() -> _InfoDict: ...

class _ThreadpoolLimiter:
_controller: Final[LibController]
_limits: Final[dict[str, int | None]]
_user_api: Final[list[str]]
_prefixes: Final[list[str]]
_original_info: Final[_InfoDict]

def __init__(
self,
/,
controller: LibController,
*,
limits: _ToLimits | None = None,
user_api: list[str] | None = None,
) -> None: ...
def __enter__(self) -> Self: ...
def __exit__(
self,
/,
type: type[BaseException] | None, # noqa: A002
value: BaseException | None,
traceback: types.TracebackType | None,
) -> None: ...
@classmethod
def wrap(
cls,
controller: LibController,
*,
limits: _ToLimits | None = None,
user_api: list[str] | None = None,
) -> Self: ...
def get_original_num_threads(self) -> dict[str, int | None]: ...
def restore_original_limits(self) -> None: ...
unregister = restore_original_limits

def _check_params(
self, /, limits: _ToLimits | None, user_api: list[str]
) -> tuple[dict[str, int | None], list[str], list[str]]: ...
def _set_threadpool_limits(self) -> None: ...

class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator): ...

class threadpool_limits(_ThreadpoolLimiter):
def __init__(
self, limits: _ToLimits | None = None, user_api: list[str] | None = None
) -> None: ...
@classmethod
def wrap( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
cls, limits: _ToLimits | None = None, user_api: list[str] | None = None
) -> Self: ...

class ThreadpoolController:
_system_libraries: ClassVar[dict[str, ctypes.CDLL]] = ...

lib_controllers: Final[list[LibController]]

def __init__(self) -> None: ...
@classmethod
def _from_controllers(cls, lib_controllers: list[LibController]) -> Self: ...
def info(self) -> list[_InfoDict]: ...
def select(self, /, **kwargs: object) -> Self: ...
def _get_params_for_sequential_blas_under_openmp(self) -> _OMPBlasDict: ...
def limit(
self, /, *, limits: _ToLimits | None = None, user_api: list[str] | None = None
) -> _ThreadpoolLimiter: ...
def wrap(
self, /, *, limits: _ToLimits | None = None, user_api: list[str] | None = None
) -> _ThreadpoolLimiter: ...
def __len__(self) -> int: ...
def _load_libraries(self) -> None: ...
def _find_libraries_with_dl_iterate_phdr(self) -> list[Never] | int | None: ...
def _find_libraries_with_dyld(self) -> list[Never] | None: ...
def _find_libraries_with_enum_process_module_ex(self) -> None: ...
def _find_libraries_pyodide(self) -> None: ...
def _make_controller_from_path(self, /, filepath: str) -> None: ...
def _check_prefix(
self, /, library_basename: str, filename_prefixes: Iterable[str]
) -> str | None: ...
def _warn_if_incompatible_openmp(self) -> None: ...
@classmethod
def _get_libc(cls) -> ctypes.CDLL: ...

if sys.platform == "win32":
@classmethod
def _get_windll( # pyright: ignore[reportRedeclaration]
cls, dll_name: str
) -> ctypes.WinDLL: ...

else:
@classmethod
def _get_windll(cls, dll_name: str) -> ctypes.CDLL: ...
Loading