diff --git a/.gitignore b/.gitignore index de2d5e08..918ca89d 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +.aider* diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..3bdb31d9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +#FROM python:3.12-slim-bullseye +FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime + +# Set the working directory in the container +WORKDIR /app + +# Copy the sources +COPY . /src + +# Install & upgrade software +RUN apt-get update && \ + apt-get install -y git libqt5gui5 && \ + rm -rf /var/lib/apt/lists/* && \ + python3 -m pip install --root-user-action=ignore --no-cache-dir --upgrade pip && \ + python3 -m pip install --root-user-action=ignore --no-cache-dir /src[globus,gui,ptychi] && \ + rm -rf /src + +# Run ptychodus when the container launches +#ENV QT_DEBUG_PLUGINS=1 +CMD ["python3", "-m", "ptychodus"] diff --git a/README.rst b/README.rst index 01c4e823..d52c2bda 100644 --- a/README.rst +++ b/README.rst @@ -67,11 +67,11 @@ Developer Installation $ conda activate ptychodus $ pip install -e ./ptychodus -* To install the `tike`_ backend: +* To install the `pty-chi`_ backend: .. code-block:: shell - $ conda install -n ptychodus -c conda-forge tike + $ pip install ptychi * To install the `PtychoNN`_ backend: @@ -93,6 +93,6 @@ Reporting Bugs Open a bug at https://github.com/AdvancedPhotonSource/ptychodus/issues. .. _`ptychodus`: https://github.com/AdvancedPhotonSource/ptychodus -.. _`tike`: https://github.com/tomography/tike +.. _`pty-chi`: https://github.com/AdvancedPhotonSource/pty-chi .. _`PtychoNN`: https://github.com/mcherukara/PtychoNN .. _`PvaPy`: https://github.com/epics-base/pvaPy diff --git a/apptainer/ptychodus.def b/apptainer/ptychodus.def deleted file mode 100644 index 6f4bb435..00000000 --- a/apptainer/ptychodus.def +++ /dev/null @@ -1,26 +0,0 @@ -Bootstrap: docker -From: registry.fedoraproject.org/fedora-minimal:40-{{ target_arch }} - -%arguments -target_arch=x86_64 -cuda_version=12.0 -pkg_version=master - -%post -curl -L -o conda-installer.sh https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-{{ target_arch }}.sh -bash conda-installer.sh -b -p "/opt/miniconda" -rm conda-installer.sh -/opt/miniconda/bin/conda install unzip --yes -curl -L -o source.zip https://github.com/AdvancedPhotonSource/ptychodus/archive/{{ pkg_version }}.zip -/opt/miniconda/bin/unzip source.zip -rm source.zip -cd ptychodus* -CONDA_OVERRIDE_CUDA={{ cuda_version }} /opt/miniconda/bin/conda install cuda-version={{ cuda_version }} --file requirements.txt -c conda-forge --yes -/opt/miniconda/bin/pip install . --no-deps --no-build-isolation -/opt/miniconda/bin/pip check -cd .. -rm ptychodus* -rf -/opt/miniconda/bin/conda clean --all --yes - -%runscript -/opt/miniconda/bin/python -m ptychodus "$@" diff --git a/doc/api.rst b/doc/api.rst new file mode 100644 index 00000000..7c2cc12f --- /dev/null +++ b/doc/api.rst @@ -0,0 +1,12 @@ +#!/usr/bin/env python + +from pathlib import Path +from ptychodus.model import ModelCore + +def main() -> int: + settings_file = Path("path/to/settings.ini") + + with ModelCore(settings_file) as model: + input_product_api = model.workflow_api.create_product("new_product_name") + output_product_api = input_product_api.reconstruct_local() + output_product_api.save_product("/path/to/file.h5", file_type="HDF5") diff --git a/doc/dist.rst b/doc/dist.rst new file mode 100644 index 00000000..4757c7f7 --- /dev/null +++ b/doc/dist.rst @@ -0,0 +1,36 @@ +Distribution Instructions +========================= + +Python Package Index (PyPI) +--------------------------- + +From the ptychodus directory, create wheel in ./dist/ + +.. code-block:: shell + + $ python -m build . + +Upload to PyPI + +.. code-block:: shell + + $ twine upload dist/* + +Docker +------ + +Build Docker image + +.. code-block:: shell + + $ podman build -t ptychodus:latest . + + +Run container + +.. code-block:: shell + + $ xhost +local:podman + $ podman run -it --rm -e "DISPLAY=$DISPLAY" -v "$HOME/.Xauthority:/root/.Xauthority:ro" --network host \ + --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 python-ptychodus + $ xhost -local:podman diff --git a/polaris.yaml b/polaris.yaml new file mode 100644 index 00000000..ba18891f --- /dev/null +++ b/polaris.yaml @@ -0,0 +1,40 @@ +engine: + type: HighThroughputEngine + max_workers_per_node: 1 + + # Un-comment to give each worker exclusive access to a single GPU + # available_accelerators: 4 + + strategy: + type: SimpleStrategy + max_idletime: 3600 + + address: + type: address_by_interface + ifname: bond0 + + provider: + type: PBSProProvider + + launcher: + type: MpiExecLauncher + # Ensures 1 manger per node, work on all 64 cores + bind_cmd: --cpu-bind + overrides: --depth=64 --ppn 1 + + account: APSDataAnalysis + queue: preemptable + cpus_per_node: 32 + select_options: ngpus=4 + + # e.g., "#PBS -l filesystems=home:grand:eagle\n#PBS -k doe" + scheduler_options: "#PBS -l filesystems=home:grand:eagle" + + # Node setup: activate necessary conda environment and such + worker_init: "source ~/miniconda3/etc/profile.d/conda.sh; conda activate ptychodus", + + walltime: 01:00:00 + nodes_per_block: 1 + init_blocks: 0 + min_blocks: 0 + max_blocks: 2 diff --git a/pyproject.toml b/pyproject.toml index aae521b1..f271c887 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64", "setuptools_scm>=8"] +requires = ["setuptools>=64", "setuptools_scm[toml]>=8"] build-backend = "setuptools.build_meta" [project] @@ -7,14 +7,17 @@ name = "ptychodus" description = "Ptychodus is a ptychography data analysis application." readme = "README.rst" requires-python = ">=3.10" -license = {file = "LICENSE.txt"} +license = {file = "LICENSE"} dependencies = [ - "h5py", + "h5py>=3", + "hdf5plugin", "matplotlib", "numpy", "psutil", + "requests", "scikit-image", "scipy", + "tables", "tifffile", "watchdog", ] @@ -24,10 +27,17 @@ dynamic = ["version"] ptychodus = "ptychodus.__main__:main" [project.optional-dependencies] -globus = ["gladier", "gladier-tools"] +globus = ["gladier", "gladier-tools>=0.5.4"] gui = ["PyQt5"] ptychonn = ["ptychonn==0.3.*,>=0.3.7"] tike = ["tike==0.25.*,>=0.25.3"] +ptychi = ["ptychi==1.*"] + +[tool.setuptools.package-data] +"ptychodus" = ["py.typed"] + +[tool.setuptools.packages.find] +where = ["src"] [tool.setuptools_scm] @@ -44,6 +54,8 @@ module = [ "hdf5plugin", "lightning.*", "parsl.*", + "ptychi.*", + "ptycho.*", "ptychonn.*", "pvaccess", "pvapy.*", @@ -61,9 +73,11 @@ target-version = "py310" [tool.ruff.format] quote-style = "single" -[tool.setuptools.package-data] -"ptychodus" = ["py.typed"] - -[tool.setuptools.packages.find] -where = ["src"] +[tool.ruff.lint] +select = [ + "N", + "NPY", +] +[tool.pyright] +pythonVersion = "3.10" diff --git a/requirements-dev.txt b/requirements-dev.txt index c42305e2..3113269b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ -h5py +build +h5py>=3 hdf5plugin matplotlib mypy @@ -6,6 +7,7 @@ numpy psutil pyqt pyqt-stubs +pytables pytest python>=3.10 ruff @@ -16,4 +18,3 @@ setuptools_scm>=8 tifffile toml watchdog -wheel diff --git a/src/ptychodus/__init__.py b/src/ptychodus/__init__.py index d30ba50a..06cf3392 100644 --- a/src/ptychodus/__init__.py +++ b/src/ptychodus/__init__.py @@ -6,7 +6,7 @@ pass try: - from .ptychodusAdImageProcessor import PtychodusAdImageProcessor + from .ptychodus_stream_processor import PtychodusAdImageProcessor except ModuleNotFoundError: pass diff --git a/src/ptychodus/__main__.py b/src/ptychodus/__main__.py index 5fb3c484..e6be2bad 100644 --- a/src/ptychodus/__main__.py +++ b/src/ptychodus/__main__.py @@ -2,17 +2,20 @@ from pathlib import Path import argparse +import logging import sys from ptychodus.model import ModelCore import ptychodus +logger = logging.getLogger(__name__) -def versionString() -> str: + +def version_string() -> str: return f'{ptychodus.__name__.title()} ({ptychodus.__version__})' -def verifyAllArgumentsParsed(parser: argparse.ArgumentParser, argv: list[str]) -> None: +def verify_all_arguments_parsed(parser: argparse.ArgumentParser, argv: list[str]) -> None: if argv: parser.error('unrecognized arguments: %s' % ' '.join(argv)) @@ -79,61 +82,62 @@ def main() -> int: '-v', '--version', action='version', - version=versionString(), + version=version_string(), ) - parsedArgs, unparsedArgs = parser.parse_known_args() - settingsFile = Path(parsedArgs.settings.name) if parsedArgs.settings else None + parsed_args, unparsed_args = parser.parse_known_args() + settings_file = Path(parsed_args.settings.name) if parsed_args.settings else None - with ModelCore(settingsFile, isDeveloperModeEnabled=parsedArgs.dev) as model: - if parsedArgs.patterns is not None: - patternsFilePath = Path(parsedArgs.patterns.name) - model.workflowAPI.importProcessedPatterns(patternsFilePath) + with ModelCore(settings_file, is_developer_mode_enabled=parsed_args.dev) as model: + if parsed_args.patterns is not None: + patterns_file_path = Path(parsed_args.patterns.name) + model.workflow_api.import_assembled_patterns(patterns_file_path) - if parsedArgs.batch is not None: - verifyAllArgumentsParsed(parser, unparsedArgs) + if parsed_args.batch is not None: + verify_all_arguments_parsed(parser, unparsed_args) - if parsedArgs.input is None or parsedArgs.output is None: + if parsed_args.input is None or parsed_args.output is None: parser.error('Batch mode requires input and output arguments!') return -1 - action = parsedArgs.batch - inputFilePath = Path(parsedArgs.input.name) - outputFilePath = Path(parsedArgs.output.name) - fluorescenceInputFilePath: Path | None = None - fluorescenceOutputFilePath: Path | None = None + action = parsed_args.batch + input_file_path = Path(parsed_args.input.name) + output_file_path = Path(parsed_args.output.name) + fluorescence_input_file_path: Path | None = None + fluorescence_output_file_path: Path | None = None - if parsedArgs.fluorescence_input is not None: - fluorescenceInputFilePath = Path(parsedArgs.fluorescence_input.name) + if parsed_args.fluorescence_input is not None: + fluorescence_input_file_path = Path(parsed_args.fluorescence_input.name) - if parsedArgs.fluorescence_output is not None: - fluorescenceOutputFilePath = Path(parsedArgs.fluorescence_output.name) + if parsed_args.fluorescence_output is not None: + fluorescence_output_file_path = Path(parsed_args.fluorescence_output.name) - return model.batchModeExecute( + return model.batch_mode_execute( action, - inputFilePath, - outputFilePath, - fluorescenceInputFilePath=fluorescenceInputFilePath, - fluorescenceOutputFilePath=fluorescenceOutputFilePath, + input_file_path, + output_file_path, + fluorescence_input_file_path=fluorescence_input_file_path, + fluorescence_output_file_path=fluorescence_output_file_path, ) try: from PyQt5.QtWidgets import QApplication except ModuleNotFoundError: + logger.warning('PyQt5 not found.') return 0 # QApplication expects the first argument to be the program name - app = QApplication(sys.argv[:1] + unparsedArgs) - verifyAllArgumentsParsed(parser, app.arguments()[1:]) + app = QApplication(sys.argv[:1] + unparsed_args) + verify_all_arguments_parsed(parser, app.arguments()[1:]) from ptychodus.view import ViewCore - view = ViewCore.createInstance(parsedArgs.dev) + view = ViewCore() from ptychodus.controller import ControllerCore - controller = ControllerCore(model, view) - controller.showMainWindow(versionString()) + controller = ControllerCore(model, view, is_developer_mode_enabled=parsed_args.dev) + controller.show_main_window(version_string()) return app.exec() diff --git a/src/ptychodus/api/constants.py b/src/ptychodus/api/constants.py deleted file mode 100644 index a377bc38..00000000 --- a/src/ptychodus/api/constants.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Final - -# Source: https://physics.nist.gov/cuu/Constants/index.html -ELECTRON_VOLT_J: Final[float] = 1.602176634e-19 -LIGHT_SPEED_M_PER_S: Final[float] = 299792458 -PLANCK_CONSTANT_J_PER_HZ: Final[float] = 6.62607015e-34 diff --git a/src/ptychodus/api/fluorescence.py b/src/ptychodus/api/fluorescence.py index b1352e4c..723a85df 100644 --- a/src/ptychodus/api/fluorescence.py +++ b/src/ptychodus/api/fluorescence.py @@ -32,14 +32,14 @@ def enhance(self, dataset: FluorescenceDataset, product: Product) -> Fluorescenc class FluorescenceFileReader(ABC): @abstractmethod - def read(self, filePath: Path) -> FluorescenceDataset: + def read(self, file_path: Path) -> FluorescenceDataset: """reads a fluorescence dataset from file""" pass class FluorescenceFileWriter(ABC): @abstractmethod - def write(self, filePath: Path, dataset: FluorescenceDataset) -> None: + def write(self, file_path: Path, dataset: FluorescenceDataset) -> None: """writes a fluorescence dataset to file""" pass diff --git a/src/ptychodus/api/geometry.py b/src/ptychodus/api/geometry.py index 324c6aab..ec10e5c8 100644 --- a/src/ptychodus/api/geometry.py +++ b/src/ptychodus/api/geometry.py @@ -6,29 +6,56 @@ T = TypeVar('T', int, float, Decimal) +@dataclass(frozen=True) +class AffineTransform: + a00: float + a01: float + a02: float + + a10: float + a11: float + a12: float + + def __call__(self, y: float, x: float) -> tuple[float, float]: + yp = self.a00 * y + self.a01 * x + self.a02 + xp = self.a10 * y + self.a11 * x + self.a12 + return yp, xp + + @dataclass(frozen=True) class PixelGeometry: - widthInMeters: float - heightInMeters: float + width_m: float + height_m: float - def __repr__(self) -> str: - return f'{type(self).__name__}({self.widthInMeters}, {self.heightInMeters})' + @property + def area_m2(self) -> float: + return self.width_m * self.height_m + + @property + def aspect_ratio(self) -> float: + return self.width_m / self.height_m + + def copy(self) -> PixelGeometry: + return PixelGeometry( + width_m=float(self.width_m), + height_m=float(self.height_m), + ) @dataclass(frozen=True) class ImageExtent: - widthInPixels: int - heightInPixels: int + width_px: int + height_px: int @property def size(self) -> int: """returns the number of pixels in the image""" - return self.widthInPixels * self.heightInPixels + return self.width_px * self.height_px @property def shape(self) -> tuple[int, int]: - """returns the image shape (heightInPixels, widthInPixels) tuple""" - return self.heightInPixels, self.widthInPixels + """returns the image shape (height_px, width_px) tuple""" + return self.height_px, self.width_px def __eq__(self, other: object) -> bool: if isinstance(other, ImageExtent): @@ -36,18 +63,12 @@ def __eq__(self, other: object) -> bool: return False - def __repr__(self) -> str: - return f'{type(self).__name__}({self.widthInPixels}, {self.heightInPixels})' - @dataclass(frozen=True) class Point2D: x: float y: float - def __repr__(self) -> str: - return f'{type(self).__name__}({self.x}, {self.y})' - @dataclass(frozen=True) class Line2D: @@ -60,9 +81,6 @@ def lerp(self, alpha: float) -> Point2D: y = beta * self.begin.y + alpha * self.end.y return Point2D(x, y) - def __repr__(self) -> str: - return f'{type(self).__name__}({self.begin}, {self.end})' - @dataclass(frozen=True) class Box2D: @@ -87,9 +105,6 @@ def y_begin(self) -> float: def y_end(self) -> float: return self.y + self.height - def __repr__(self) -> str: - return f'{type(self).__name__}({self.x}, {self.y}, {self.width}, {self.height})' - class Interval(Generic[T]): def __init__(self, lower: T, upper: T) -> None: @@ -97,14 +112,14 @@ def __init__(self, lower: T, upper: T) -> None: self.upper: T = upper @classmethod - def createProper(self, a: T, b: T) -> Interval[T]: + def create_proper(cls, a: T, b: T) -> Interval[T]: if b < a: return Interval[T](b, a) else: return Interval[T](a, b) @property - def isEmpty(self) -> bool: + def is_empty(self) -> bool: return self.upper < self.lower def clamp(self, value: T) -> T: diff --git a/src/ptychodus/api/object.py b/src/ptychodus/api/object.py index 60c2cfc0..617da303 100644 --- a/src/ptychodus/api/object.py +++ b/src/ptychodus/api/object.py @@ -3,237 +3,216 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path -from typing import Any, TypeAlias import numpy -import numpy.typing from .geometry import ImageExtent, PixelGeometry from .scan import ScanPoint +from .typing import ComplexArrayType -ObjectArrayType: TypeAlias = numpy.typing.NDArray[numpy.complexfloating[Any, Any]] + +@dataclass(frozen=True) +class ObjectCenter: + position_x_m: float + position_y_m: float + + def copy(self) -> ObjectCenter: + return ObjectCenter( + position_x_m=float(self.position_x_m), + position_y_m=float(self.position_y_m), + ) @dataclass(frozen=True) class ObjectPoint: index: int - positionXInPixels: float - positionYInPixels: float + position_x_px: float + position_y_px: float @dataclass(frozen=True) class ObjectGeometry: - widthInPixels: int - heightInPixels: int - pixelWidthInMeters: float - pixelHeightInMeters: float - centerXInMeters: float - centerYInMeters: float + width_px: int + height_px: int + pixel_width_m: float + pixel_height_m: float + center_x_m: float + center_y_m: float @property - def widthInMeters(self) -> float: - return self.widthInPixels * self.pixelWidthInMeters + def width_m(self) -> float: + return self.width_px * self.pixel_width_m @property - def heightInMeters(self) -> float: - return self.heightInPixels * self.pixelHeightInMeters + def height_m(self) -> float: + return self.height_px * self.pixel_height_m @property - def minimumXInMeters(self) -> float: - return self.centerXInMeters - self.widthInMeters / 2.0 + def minimum_x_m(self) -> float: + return self.center_x_m - self.width_m / 2.0 @property - def minimumYInMeters(self) -> float: - return self.centerYInMeters - self.heightInMeters / 2.0 + def minimum_y_m(self) -> float: + return self.center_y_m - self.height_m / 2.0 - def getPixelGeometry(self) -> PixelGeometry: + def get_pixel_geometry(self) -> PixelGeometry: return PixelGeometry( - widthInMeters=self.pixelWidthInMeters, - heightInMeters=self.pixelHeightInMeters, + width_m=self.pixel_width_m, + height_m=self.pixel_height_m, + ) + + def get_center(self) -> ObjectCenter: + return ObjectCenter( + position_x_m=self.center_x_m, + position_y_m=self.center_y_m, ) - def mapObjectPointToScanPoint(self, point: ObjectPoint) -> ScanPoint: - rx_px = self.widthInPixels / 2 - ry_px = self.heightInPixels / 2 - dx_m = self.pixelWidthInMeters - dy_m = self.pixelHeightInMeters + def map_object_point_to_scan_point(self, point: ObjectPoint) -> ScanPoint: + rx_px = self.width_px / 2 + ry_px = self.height_px / 2 + dx_m = self.pixel_width_m + dy_m = self.pixel_height_m - x_m = self.centerXInMeters + dx_m * (point.positionXInPixels - rx_px) - y_m = self.centerYInMeters + dy_m * (point.positionYInPixels - ry_px) + x_m = self.center_x_m + dx_m * (point.position_x_px - rx_px) + y_m = self.center_y_m + dy_m * (point.position_y_px - ry_px) return ScanPoint(point.index, x_m, y_m) - def mapScanPointToObjectPoint(self, point: ScanPoint) -> ObjectPoint: - rx_px = self.widthInPixels / 2 - ry_px = self.heightInPixels / 2 - dx_m = self.pixelWidthInMeters - dy_m = self.pixelHeightInMeters + def map_scan_point_to_object_point(self, point: ScanPoint) -> ObjectPoint: + rx_px = self.width_px / 2 + ry_px = self.height_px / 2 + dx_m = self.pixel_width_m + dy_m = self.pixel_height_m - x_px = (point.positionXInMeters - self.centerXInMeters) / dx_m + rx_px - y_px = (point.positionYInMeters - self.centerYInMeters) / dy_m + ry_px + x_px = (point.position_x_m - self.center_x_m) / dx_m + rx_px + y_px = (point.position_y_m - self.center_y_m) / dy_m + ry_px return ObjectPoint(point.index, x_px, y_px) def contains(self, geometry: ObjectGeometry) -> bool: - dx = self.centerXInMeters - geometry.centerXInMeters - dy = self.centerYInMeters - geometry.centerYInMeters - dw = self.widthInMeters - geometry.widthInMeters - dh = self.heightInMeters - geometry.heightInMeters + dx = self.center_x_m - geometry.center_x_m + dy = self.center_y_m - geometry.center_y_m + dw = self.width_m - geometry.width_m + dh = self.height_m - geometry.height_m return abs(dx) <= dw and abs(dy) <= dh class ObjectGeometryProvider(ABC): @abstractmethod - def getObjectGeometry(self) -> ObjectGeometry: + def get_object_geometry(self) -> ObjectGeometry: pass class Object: def __init__( self, - array: ObjectArrayType | None = None, - layerDistanceInMeters: Sequence[float] | None = None, - *, - pixelWidthInMeters: float = 0.0, - pixelHeightInMeters: float = 0.0, - centerXInMeters: float = 0.0, - centerYInMeters: float = 0.0, + array: ComplexArrayType | None, + pixel_geometry: PixelGeometry | None, + center: ObjectCenter | None, + layer_spacing_m: Sequence[float] = [], ) -> None: if array is None: - self._array = numpy.zeros((1, 0, 0), dtype=complex) - else: - if numpy.iscomplexobj(array): - if array.ndim == 2: - self._array = array[numpy.newaxis, :, :] - elif array.ndim == 3: + self._array: ComplexArrayType = numpy.zeros((1, 0, 0), dtype=complex) + elif numpy.iscomplexobj(array): + match array.ndim: + case 2: + self._array = array[numpy.newaxis, ...] + case 3: self._array = array - else: + case _: raise ValueError('Object must be 2- or 3-dimensional ndarray.') - else: - raise TypeError('Object must be a complex-valued ndarray') - - if layerDistanceInMeters is None: - self._layerDistanceInMeters: Sequence[float] = [numpy.inf] else: - self._layerDistanceInMeters = layerDistanceInMeters + raise TypeError('Object must be a complex-valued ndarray') - expectedLayers = self.numberOfLayers - actualLayers = len(self._layerDistanceInMeters) + self._pixel_geometry = pixel_geometry + self._center = center + self._layer_spacing_m = layer_spacing_m - if actualLayers < expectedLayers: - raise ValueError(f'Expected {expectedLayers} layer distances; got {actualLayers}!') + expected_layers = self._array.shape[-3] + actual_layers = len(layer_spacing_m) + 1 - self._pixelWidthInMeters = pixelWidthInMeters - self._pixelHeightInMeters = pixelHeightInMeters - self._centerXInMeters = centerXInMeters - self._centerYInMeters = centerYInMeters + if actual_layers != expected_layers: + raise ValueError(f'Expected {expected_layers} layers; got {actual_layers}!') def copy(self) -> Object: return Object( - array=numpy.array(self._array), - layerDistanceInMeters=list(self._layerDistanceInMeters), - pixelWidthInMeters=float(self._pixelWidthInMeters), - pixelHeightInMeters=float(self._pixelHeightInMeters), - centerXInMeters=float(self._centerXInMeters), - centerYInMeters=float(self._centerYInMeters), + array=self._array.copy(), + pixel_geometry=None if self._pixel_geometry is None else self._pixel_geometry.copy(), + center=None if self._center is None else self._center.copy(), + layer_spacing_m=list(self._layer_spacing_m), ) - @property - def array(self) -> ObjectArrayType: + def get_array(self) -> ComplexArrayType: return self._array @property - def dataType(self) -> numpy.dtype: + def dtype(self) -> numpy.dtype: return self._array.dtype @property - def numberOfLayers(self) -> int: - return self._array.shape[-3] - - @property - def sizeInBytes(self) -> int: + def nbytes(self) -> int: return self._array.nbytes @property - def widthInPixels(self) -> int: + def width_px(self) -> int: return self._array.shape[-1] @property - def heightInPixels(self) -> int: + def height_px(self) -> int: return self._array.shape[-2] @property - def pixelWidthInMeters(self) -> float: - return self._pixelWidthInMeters + def num_layers(self) -> int: + return self._array.shape[-3] - @property - def pixelHeightInMeters(self) -> float: - return self._pixelHeightInMeters + def get_pixel_geometry(self) -> PixelGeometry: + if self._pixel_geometry is None: + raise ValueError('Missing object pixel geometry!') - @property - def centerXInMeters(self) -> float: - return self._centerXInMeters + return self._pixel_geometry - @property - def centerYInMeters(self) -> float: - return self._centerYInMeters + def get_center(self) -> ObjectCenter: + if self._center is None: + raise ValueError('Missing object center!') - def getGeometry(self) -> ObjectGeometry: - return ObjectGeometry( - widthInPixels=self.widthInPixels, - heightInPixels=self.heightInPixels, - pixelWidthInMeters=self._pixelWidthInMeters, - pixelHeightInMeters=self._pixelHeightInMeters, - centerXInMeters=self._centerXInMeters, - centerYInMeters=self._centerYInMeters, - ) + return self._center - def getPixelGeometry(self) -> PixelGeometry: - return PixelGeometry( - widthInMeters=self._pixelWidthInMeters, - heightInMeters=self._pixelHeightInMeters, + def get_geometry(self) -> ObjectGeometry: + pixel_geometry = self.get_pixel_geometry() + center = self.get_center() + + return ObjectGeometry( + width_px=self.width_px, + height_px=self.height_px, + pixel_width_m=pixel_geometry.width_m, + pixel_height_m=pixel_geometry.height_m, + center_x_m=center.position_x_m, + center_y_m=center.position_y_m, ) - def getLayer(self, number: int) -> ObjectArrayType: + def get_layer(self, number: int) -> ComplexArrayType: return self._array[number, :, :] - def getLayersFlattened(self) -> ObjectArrayType: + def get_layers_flattened(self) -> ComplexArrayType: return numpy.prod(self._array, axis=-3) @property - def layerDistanceInMeters(self) -> Sequence[float]: - return self._layerDistanceInMeters - - def getLayerDistanceInMeters(self, number: int) -> float: - return self._layerDistanceInMeters[number] - - def getTotalLayerDistanceInMeters(self) -> float: - return sum(self._layerDistanceInMeters[:-1]) + def layer_spacing_m(self) -> Sequence[float]: + return self._layer_spacing_m - -class ObjectInterpolator(ABC): - @abstractmethod - def getPatch(self, patchCenter: ScanPoint, patchExtent: ImageExtent) -> Object: - """returns an interpolated patch from the object array""" - pass - - -class ObjectPhaseCenteringStrategy(ABC): - @abstractmethod - def __call__(self, array: ObjectArrayType) -> ObjectArrayType: - """returns the phase-centered array""" - pass + def get_total_thickness_m(self) -> float: + return sum(self._layer_spacing_m) class ObjectFileReader(ABC): @abstractmethod - def read(self, filePath: Path) -> Object: + def read(self, file_path: Path) -> Object: """reads an object from file""" pass class ObjectFileWriter(ABC): @abstractmethod - def write(self, filePath: Path, object_: Object) -> None: + def write(self, file_path: Path, object_: Object) -> None: """writes an object to file""" pass diff --git a/src/ptychodus/api/observer.py b/src/ptychodus/api/observer.py index de024cb5..02ac4718 100644 --- a/src/ptychodus/api/observer.py +++ b/src/ptychodus/api/observer.py @@ -15,65 +15,78 @@ class Observer(ABC): @abstractmethod - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: pass class Observable: def __init__(self) -> None: - self._observerList: list[Observer] = list() + self._observer_list: list[Observer] = list() + self._block_notifications = False + self._pending_notify = False - def addObserver(self, observer: Observer) -> None: - if observer not in self._observerList: - self._observerList.append(observer) + def add_observer(self, observer: Observer) -> None: + if observer not in self._observer_list: + self._observer_list.append(observer) - def removeObserver(self, observer: Observer) -> None: + def remove_observer(self, observer: Observer) -> None: try: - self._observerList.remove(observer) + self._observer_list.remove(observer) except ValueError: pass - def notifyObservers(self) -> None: - for observer in self._observerList: - observer.update(self) + def block_notifications(self, block: bool) -> None: + self._block_notifications = block + + if self._pending_notify: + self.notify_observers() + self._pending_notify = False + + def notify_observers(self) -> None: + if self._block_notifications: + self._pending_notify = True + return + + for observer in self._observer_list: + observer._update(self) class SequenceObserver(Generic[T], ABC): @abstractmethod - def handleItemInserted(self, index: int, item: T) -> None: + def handle_item_inserted(self, index: int, item: T) -> None: pass @abstractmethod - def handleItemChanged(self, index: int, item: T) -> None: + def handle_item_changed(self, index: int, item: T) -> None: pass @abstractmethod - def handleItemRemoved(self, index: int, item: T) -> None: + def handle_item_removed(self, index: int, item: T) -> None: pass class ObservableSequence(Sequence[T]): def __init__(self) -> None: - self._observerList: list[SequenceObserver[T]] = list() + self._observer_list: list[SequenceObserver[T]] = list() - def addObserver(self, observer: SequenceObserver[T]) -> None: - if observer not in self._observerList: - self._observerList.append(observer) + def add_observer(self, observer: SequenceObserver[T]) -> None: + if observer not in self._observer_list: + self._observer_list.append(observer) - def removeObserver(self, observer: SequenceObserver[T]) -> None: + def remove_observer(self, observer: SequenceObserver[T]) -> None: try: - self._observerList.remove(observer) + self._observer_list.remove(observer) except ValueError: pass - def notifyObserversItemInserted(self, index: int, item: T) -> None: - for observer in self._observerList: - observer.handleItemInserted(index, item) + def notify_observers_item_inserted(self, index: int, item: T) -> None: + for observer in self._observer_list: + observer.handle_item_inserted(index, item) - def notifyObserversItemChanged(self, index: int, item: T) -> None: - for observer in self._observerList: - observer.handleItemChanged(index, item) + def notify_observers_item_changed(self, index: int, item: T) -> None: + for observer in self._observer_list: + observer.handle_item_changed(index, item) - def notifyObserversItemRemoved(self, index: int, item: T) -> None: - for observer in self._observerList: - observer.handleItemRemoved(index, item) + def notify_observers_item_removed(self, index: int, item: T) -> None: + for observer in self._observer_list: + observer.handle_item_removed(index, item) diff --git a/src/ptychodus/api/parametric.py b/src/ptychodus/api/parametric.py index 6a00d4f2..5f7bbba8 100644 --- a/src/ptychodus/api/parametric.py +++ b/src/ptychodus/api/parametric.py @@ -24,36 +24,36 @@ def __init__(self, parent: Parameter[T] | None = None) -> None: self._parent = parent @abstractmethod - def getValue(self) -> T: + def get_value(self) -> T: pass @abstractmethod - def setValue(self, value: T, *, notify: bool = True) -> None: + def set_value(self, value: T, *, notify: bool = True) -> None: pass @abstractmethod - def getValueAsString(self) -> str: + def get_value_as_string(self) -> str: pass @abstractmethod - def setValueFromString(self, value: str) -> None: + def set_value_from_string(self, value: str) -> None: pass @abstractmethod def copy(self) -> Parameter[T]: pass - def syncValueToParent(self) -> None: + def sync_value_to_parent(self) -> None: if self._parent is None: - logger.warning('syncValueToParent: parent is None!') + logger.warning('sync_value_to_parent: parent is None!') else: - self._parent.setValue(self.getValue()) + self._parent.set_value(self.get_value()) - def syncValueFromParent(self) -> None: + def sync_value_from_parent(self) -> None: if self._parent is None: - logger.warning('syncValueFromParent: parent is None!') + logger.warning('sync_value_from_parent: parent is None!') else: - self.setValue(self._parent.getValue()) + self.set_value(self._parent.get_value()) class ParameterBase(Parameter[T]): @@ -61,17 +61,17 @@ def __init__(self, value: T, parent: Parameter[T] | None) -> None: super().__init__(parent) self._value = value - def getValue(self) -> T: + def get_value(self) -> T: return self._value - def setValue(self, value: T, *, notify: bool = True) -> None: + def set_value(self, value: T, *, notify: bool = True) -> None: if self._value != value: self._value = value if notify: - self.notifyObservers() + self.notify_observers() - def getValueAsString(self) -> str: + def get_value_as_string(self) -> str: return repr(self._value) @@ -79,25 +79,25 @@ class StringParameter(ParameterBase[str]): def __init__(self, value: str, parent: StringParameter | None) -> None: super().__init__(value, parent) - def setValueFromString(self, value: str) -> None: - self.setValue(str(value)) + def set_value_from_string(self, value: str) -> None: + self.set_value(str(value)) - def getValueAsString(self) -> str: + def get_value_as_string(self) -> str: return str(self._value) def copy(self) -> StringParameter: - return StringParameter(self.getValue(), self) + return StringParameter(self.get_value(), self) class PathParameter(ParameterBase[Path]): def __init__(self, value: Path, parent: PathParameter | None) -> None: super().__init__(value, parent) - def setValueFromString(self, value: str) -> None: - self.setValue(Path(value)) + def set_value_from_string(self, value: str) -> None: + self.set_value(Path(value)) - def changePathPrefix(self, find_path_prefix: Path, replacement_path_prefix: Path) -> Path: - value = self.getValue() + def change_path_prefix(self, find_path_prefix: Path, replacement_path_prefix: Path) -> Path: + value = self.get_value() try: relative_path = value.resolve().relative_to(find_path_prefix) @@ -108,25 +108,25 @@ def changePathPrefix(self, find_path_prefix: Path, replacement_path_prefix: Path return value - def getValueAsString(self) -> str: + def get_value_as_string(self) -> str: return str(self._value) def copy(self) -> PathParameter: - return PathParameter(self.getValue(), self) + return PathParameter(self.get_value(), self) class UUIDParameter(ParameterBase[UUID]): def __init__(self, value: UUID, parent: UUIDParameter | None) -> None: super().__init__(value, parent) - def setValueFromString(self, value: str) -> None: - self.setValue(UUID(value)) + def set_value_from_string(self, value: str) -> None: + self.set_value(UUID(value)) - def getValueAsString(self) -> str: + def get_value_as_string(self) -> str: return str(self._value) def copy(self) -> UUIDParameter: - return UUIDParameter(self.getValue(), self) + return UUIDParameter(self.get_value(), self) class BooleanParameter(ParameterBase[bool]): @@ -135,11 +135,11 @@ class BooleanParameter(ParameterBase[bool]): def __init__(self, value: bool, parent: BooleanParameter | None) -> None: super().__init__(value, parent) - def setValueFromString(self, value: str) -> None: - self.setValue(value.lower() in BooleanParameter.TRUE_VALUES) + def set_value_from_string(self, value: str) -> None: + self.set_value(value.lower() in BooleanParameter.TRUE_VALUES) def copy(self) -> BooleanParameter: - return BooleanParameter(self.getValue(), self) + return BooleanParameter(self.get_value(), self) class IntegerParameter(ParameterBase[int]): @@ -155,14 +155,14 @@ def __init__( self._minimum = minimum self._maximum = maximum - def getMinimum(self) -> int | None: + def get_minimum(self) -> int | None: return self._minimum - def getMaximum(self) -> int | None: + def get_maximum(self) -> int | None: return self._maximum - def getValue(self) -> int: - value = super().getValue() + def get_value(self) -> int: + value = super().get_value() if self._minimum is not None: value = max(self._minimum, value) @@ -172,12 +172,12 @@ def getValue(self) -> int: return value - def setValueFromString(self, value: str) -> None: - self.setValue(int(value)) + def set_value_from_string(self, value: str) -> None: + self.set_value(int(value)) def copy(self) -> IntegerParameter: return IntegerParameter( - self.getValue(), self, minimum=self.getMinimum(), maximum=self.getMaximum() + self.get_value(), self, minimum=self.get_minimum(), maximum=self.get_maximum() ) @@ -194,14 +194,14 @@ def __init__( self._minimum = minimum self._maximum = maximum - def getMinimum(self) -> float | None: + def get_minimum(self) -> float | None: return self._minimum - def getMaximum(self) -> float | None: + def get_maximum(self) -> float | None: return self._maximum - def getValue(self) -> float: - value = super().getValue() + def get_value(self) -> float: + value = super().get_value() if self._minimum is not None: value = max(self._minimum, value) @@ -211,17 +211,66 @@ def getValue(self) -> float: return value - def setValueFromString(self, value: str) -> None: - self.setValue(float(value)) + def set_value_from_string(self, value: str) -> None: + self.set_value(float(value)) def copy(self) -> RealParameter: return RealParameter( - self.getValue(), self, minimum=self.getMinimum(), maximum=self.getMaximum() + self.get_value(), self, minimum=self.get_minimum(), maximum=self.get_maximum() ) -class RealArrayParameter(ParameterBase[MutableSequence[float]]): - def __init__(self, value: Sequence[float], parent: RealArrayParameter | None) -> None: +class IntegerSequenceParameter(ParameterBase[MutableSequence[int]]): + def __init__(self, value: Sequence[int], parent: IntegerSequenceParameter | None) -> None: + super().__init__(list(value), parent) + + def __iter__(self) -> Iterator[int]: + return iter(self._value) + + def __getitem__(self, index: int) -> int: + return self._value[index] + + def __setitem__(self, index: int, value: int) -> None: + if self._value[index] != value: + self._value[index] = value + self.notify_observers() + + def __delitem__(self, index: int) -> None: + del self._value[index] + self.notify_observers() + + def insert(self, index: int, value: int) -> None: + self._value.insert(index, value) + self.notify_observers() + + def __len__(self) -> int: + return len(self._value) + + def set_value(self, value: Sequence[int], *, notify: bool = True) -> None: + if self._value != value: + self._value = list(value) + + if notify: + self.notify_observers() + + def get_value_as_string(self) -> str: + return ','.join(repr(value) for value in self) + + def set_value_from_string(self, value: str) -> None: + new_value: list[int] = list() + + for xstr in value.split(','): + if xstr: + new_value.append(int(xstr)) + + self.set_value(new_value) + + def copy(self) -> IntegerSequenceParameter: + return IntegerSequenceParameter(self.get_value(), self) + + +class RealSequenceParameter(ParameterBase[MutableSequence[float]]): + def __init__(self, value: Sequence[float], parent: RealSequenceParameter | None) -> None: super().__init__(list(value), parent) def __iter__(self) -> Iterator[float]: @@ -233,48 +282,49 @@ def __getitem__(self, index: int) -> float: def __setitem__(self, index: int, value: float) -> None: if self._value[index] != value: self._value[index] = value - self.notifyObservers() + self.notify_observers() def __delitem__(self, index: int) -> None: del self._value[index] - self.notifyObservers() + self.notify_observers() def insert(self, index: int, value: float) -> None: self._value.insert(index, value) - self.notifyObservers() + self.notify_observers() def __len__(self) -> int: return len(self._value) - def setValue(self, value: Sequence[float], *, notify: bool = True) -> None: + def set_value(self, value: Sequence[float], *, notify: bool = True) -> None: if self._value != value: self._value = list(value) if notify: - self.notifyObservers() + self.notify_observers() - def getValueAsString(self) -> str: + def get_value_as_string(self) -> str: return ','.join(repr(value) for value in self) - def setValueFromString(self, value: str) -> None: + def set_value_from_string(self, value: str) -> None: tmp: list[float] = list() for xstr in value.split(','): - try: - x = float(xstr) - except ValueError: - x = float('nan') + if xstr: + try: + x = float(xstr) + except ValueError: + x = float('nan') - tmp.append(x) + tmp.append(x) - self.setValue(tmp) + self.set_value(tmp) - def copy(self) -> RealArrayParameter: - return RealArrayParameter(self.getValue(), self) + def copy(self) -> RealSequenceParameter: + return RealSequenceParameter(self.get_value(), self) -class ComplexArrayParameter(ParameterBase[MutableSequence[complex]]): - def __init__(self, value: Sequence[complex], parent: ComplexArrayParameter | None) -> None: +class ComplexSequenceParameter(ParameterBase[MutableSequence[complex]]): + def __init__(self, value: Sequence[complex], parent: ComplexSequenceParameter | None) -> None: super().__init__(list(value), parent) def __iter__(self) -> Iterator[complex]: @@ -286,44 +336,45 @@ def __getitem__(self, index: int) -> complex: def __setitem__(self, index: int, value: complex) -> None: if self._value[index] != value: self._value[index] = value - self.notifyObservers() + self.notify_observers() def __delitem__(self, index: int) -> None: del self._value[index] - self.notifyObservers() + self.notify_observers() def insert(self, index: int, value: complex) -> None: self._value.insert(index, value) - self.notifyObservers() + self.notify_observers() def __len__(self) -> int: return len(self._value) - def setValue(self, value: Sequence[complex], *, notify: bool = True) -> None: + def set_value(self, value: Sequence[complex], *, notify: bool = True) -> None: if self._value != value: self._value = list(value) if notify: - self.notifyObservers() + self.notify_observers() - def getValueAsString(self) -> str: + def get_value_as_string(self) -> str: return ','.join(repr(value) for value in self) - def setValueFromString(self, value: str) -> None: + def set_value_from_string(self, value: str) -> None: tmp: list[complex] = list() for xstr in value.split(','): - try: - x = complex(xstr) - except ValueError: - x = float('nan') * 1j + if xstr: + try: + x = complex(xstr) + except ValueError: + x = float('nan') * 1j - tmp.append(x) + tmp.append(x) - self.setValue(tmp) + self.set_value(tmp) - def copy(self) -> ComplexArrayParameter: - return ComplexArrayParameter(self.getValue(), self) + def copy(self) -> ComplexSequenceParameter: + return ComplexSequenceParameter(self.get_value(), self) class ParameterGroup(Observable, Observer): @@ -335,86 +386,95 @@ def __init__(self) -> None: def parameters(self) -> Mapping[str, Parameter[Any]]: return self._parameters - def _addParameter(self, name: str, parameter: Parameter[Any]) -> None: + def _add_parameter(self, name: str, parameter: Parameter[Any]) -> None: if self._parameters.setdefault(name, parameter) is parameter: - parameter.addObserver(self) + parameter.add_observer(self) else: raise ValueError(f'Parameter "{name}" already exists!') - def createStringParameter(self, name: str, value: str) -> StringParameter: + def create_string_parameter(self, name: str, value: str) -> StringParameter: parameter = StringParameter(value, parent=None) - self._addParameter(name, parameter) + self._add_parameter(name, parameter) return parameter - def createPathParameter(self, name: str, value: Path) -> PathParameter: + def create_path_parameter(self, name: str, value: Path) -> PathParameter: parameter = PathParameter(value, parent=None) - self._addParameter(name, parameter) + self._add_parameter(name, parameter) return parameter - def createUUIDParameter(self, name: str, value: UUID) -> UUIDParameter: + def create_uuid_parameter(self, name: str, value: UUID) -> UUIDParameter: parameter = UUIDParameter(value, parent=None) - self._addParameter(name, parameter) + self._add_parameter(name, parameter) return parameter - def createBooleanParameter(self, name: str, value: bool) -> BooleanParameter: + def create_boolean_parameter(self, name: str, value: bool) -> BooleanParameter: parameter = BooleanParameter(value, parent=None) - self._addParameter(name, parameter) + self._add_parameter(name, parameter) return parameter - def createIntegerParameter( + def create_integer_parameter( self, name: str, value: int, *, minimum: int | None = None, maximum: int | None = None ) -> IntegerParameter: parameter = IntegerParameter(value, parent=None, minimum=minimum, maximum=maximum) - self._addParameter(name, parameter) + self._add_parameter(name, parameter) + return parameter + + def create_integer_sequence_parameter( + self, name: str, value: Sequence[int] + ) -> IntegerSequenceParameter: + parameter = IntegerSequenceParameter(value, parent=None) + self._add_parameter(name, parameter) return parameter - def createRealParameter( + def create_real_parameter( self, name: str, value: float, *, minimum: float | None = None, maximum: float | None = None ) -> RealParameter: parameter = RealParameter(value, parent=None, minimum=minimum, maximum=maximum) - self._addParameter(name, parameter) + self._add_parameter(name, parameter) return parameter - def createRealArrayParameter(self, name: str, value: Sequence[float]) -> RealArrayParameter: - parameter = RealArrayParameter(value, parent=None) - self._addParameter(name, parameter) + def create_real_sequence_parameter( + self, name: str, value: Sequence[float] + ) -> RealSequenceParameter: + parameter = RealSequenceParameter(value, parent=None) + self._add_parameter(name, parameter) return parameter - def createComplexArrayParameter( + def create_complex_sequence_parameter( self, name: str, value: Sequence[complex] - ) -> ComplexArrayParameter: - parameter = ComplexArrayParameter(value, parent=None) - self._addParameter(name, parameter) + ) -> ComplexSequenceParameter: + parameter = ComplexSequenceParameter(value, parent=None) + self._add_parameter(name, parameter) return parameter def groups(self) -> Mapping[str, ParameterGroup]: return self._groups - def _addGroup(self, name: str, group: ParameterGroup, *, observe: bool = False) -> None: + def _add_group(self, name: str, group: ParameterGroup, *, observe: bool = False) -> None: if self._groups.setdefault(name, group) is group: if observe: - group.addObserver(self) + group.add_observer(self) else: raise ValueError(f'Group "{name}" already exists!') - def _removeGroup(self, name: str) -> None: + def _remove_group(self, name: str) -> None: try: group = self._groups.pop(name) except KeyError: pass else: - group.removeObserver(self) + group.remove_observer(self) - def createGroup(self, name: str) -> ParameterGroup: + def create_group(self, name: str) -> ParameterGroup: group = ParameterGroup() - self._addGroup(name, group) + self._add_group(name, group) return group - def getGroup(self, name: str) -> ParameterGroup: + def get_group(self, name: str) -> ParameterGroup: return self._groups[name] - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable in self._parameters.values(): - self.notifyObservers() + self.notify_observers() elif observable in self._groups.values(): - self.notifyObservers() + self.notify_observers() diff --git a/src/ptychodus/api/patterns.py b/src/ptychodus/api/patterns.py index 877f0446..af416854 100644 --- a/src/ptychodus/api/patterns.py +++ b/src/ptychodus/api/patterns.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass -from enum import Enum, auto from pathlib import Path from typing import overload, Any, TypeAlias @@ -10,106 +9,84 @@ import numpy.typing from .geometry import ImageExtent, PixelGeometry -from .observer import Observable from .tree import SimpleTreeNode -BooleanArrayType: TypeAlias = numpy.typing.NDArray[numpy.bool_] -DiffractionPatternArrayType: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]] -DiffractionPatternIndexes: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]] +PatternDataType: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]] +PatternIndexesType: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]] @dataclass(frozen=True) class CropCenter: - positionXInPixels: int - positionYInPixels: int + position_x_px: int + position_y_px: int -class DiffractionPatternState(Enum): - UNKNOWN = auto() - MISSING = auto() - FOUND = auto() - LOADED = auto() - - -class DiffractionPatternArray(Observable): +class DiffractionPatternArray: @abstractmethod - def getLabel(self) -> str: + def get_label(self) -> str: pass @abstractmethod - def getIndex(self) -> int: + def get_indexes(self) -> PatternIndexesType: pass @abstractmethod - def getData(self) -> DiffractionPatternArrayType: + def get_data(self) -> PatternDataType: pass - def getNumberOfPatterns(self) -> int: - return self.getData().shape[0] - - @abstractmethod - def getState(self) -> DiffractionPatternState: - pass + def get_num_patterns(self) -> int: + return self.get_data().shape[0] class SimpleDiffractionPatternArray(DiffractionPatternArray): def __init__( self, label: str, - index: int, - data: DiffractionPatternArrayType, - state: DiffractionPatternState, + indexes: PatternIndexesType, + data: PatternDataType, ) -> None: super().__init__() self._label = label - self._index = index + self._indexes = indexes self._data = data - self._state = state - - @classmethod - def createNullInstance(cls) -> SimpleDiffractionPatternArray: - data = numpy.zeros((1, 1, 1), dtype=numpy.uint16) - state = DiffractionPatternState.MISSING - return cls('Null', 0, data, state) - def getLabel(self) -> str: + def get_label(self) -> str: return self._label - def getIndex(self) -> int: - return self._index + def get_indexes(self) -> PatternIndexesType: + return self._indexes - def getData(self) -> DiffractionPatternArrayType: + def get_data(self) -> PatternDataType: return self._data - def getState(self) -> DiffractionPatternState: - return self._state - @dataclass(frozen=True) class DiffractionMetadata: - numberOfPatternsPerArray: int - numberOfPatternsTotal: int - patternDataType: numpy.dtype[numpy.integer[Any]] - detectorDistanceInMeters: float | None = None - detectorExtent: ImageExtent | None = None - detectorPixelGeometry: PixelGeometry | None = None - detectorBitDepth: int | None = None - cropCenter: CropCenter | None = None - probeEnergyInElectronVolts: float | None = None - filePath: Path | None = None + num_patterns_per_array: int + num_patterns_total: int + pattern_dtype: numpy.dtype[numpy.integer[Any]] + detector_distance_m: float | None = None + detector_extent: ImageExtent | None = None + detector_pixel_geometry: PixelGeometry | None = None + detector_bit_depth: int | None = None + crop_center: CropCenter | None = None + probe_photon_count: int | None = None + probe_energy_eV: float | None = None # noqa: N815 + tomography_angle_deg: float | None = None + file_path: Path | None = None @classmethod - def createNullInstance(cls, filePath: Path | None = None) -> DiffractionMetadata: - return cls(0, 0, numpy.dtype(numpy.ubyte), filePath=filePath) + def create_null(cls, file_path: Path | None = None) -> DiffractionMetadata: + return cls(0, 0, numpy.dtype(numpy.ubyte), file_path=file_path) -class DiffractionDataset(Sequence[DiffractionPatternArray], Observable): +class DiffractionDataset(Sequence[DiffractionPatternArray]): @abstractmethod - def getMetadata(self) -> DiffractionMetadata: + def get_metadata(self) -> DiffractionMetadata: pass @abstractmethod - def getContentsTree(self) -> SimpleTreeNode: + def get_contents_tree(self) -> SimpleTreeNode: pass @@ -117,26 +94,26 @@ class SimpleDiffractionDataset(DiffractionDataset): def __init__( self, metadata: DiffractionMetadata, - contentsTree: SimpleTreeNode, - arrayList: list[DiffractionPatternArray], + contents_tree: SimpleTreeNode, + array_list: list[DiffractionPatternArray], ) -> None: super().__init__() self._metadata = metadata - self._contentsTree = contentsTree - self._arrayList = arrayList + self._contents_tree = contents_tree + self._array_list = array_list @classmethod - def createNullInstance(cls, filePath: Path | None = None) -> SimpleDiffractionDataset: - metadata = DiffractionMetadata.createNullInstance(filePath) - contentsTree = SimpleTreeNode.createRoot(list()) - arrayList: list[DiffractionPatternArray] = list() - return cls(metadata, contentsTree, arrayList) + def create_null(cls, file_path: Path | None = None) -> SimpleDiffractionDataset: + metadata = DiffractionMetadata.create_null(file_path) + contents_tree = SimpleTreeNode.create_root(list()) + array_list: list[DiffractionPatternArray] = list() + return cls(metadata, contents_tree, array_list) - def getMetadata(self) -> DiffractionMetadata: + def get_metadata(self) -> DiffractionMetadata: return self._metadata - def getContentsTree(self) -> SimpleTreeNode: - return self._contentsTree + def get_contents_tree(self) -> SimpleTreeNode: + return self._contents_tree @overload def __getitem__(self, index: int) -> DiffractionPatternArray: ... @@ -147,17 +124,17 @@ def __getitem__(self, index: slice) -> Sequence[DiffractionPatternArray]: ... def __getitem__( self, index: int | slice ) -> DiffractionPatternArray | Sequence[DiffractionPatternArray]: - return self._arrayList[index] + return self._array_list[index] def __len__(self) -> int: - return len(self._arrayList) + return len(self._array_list) class DiffractionFileReader(ABC): """interface for plugins that read diffraction files""" @abstractmethod - def read(self, filePath: Path) -> DiffractionDataset: + def read(self, file_path: Path) -> DiffractionDataset: """reads a diffraction dataset from file""" pass @@ -166,6 +143,6 @@ class DiffractionFileWriter(ABC): """interface for plugins that write diffraction files""" @abstractmethod - def write(self, filePath: Path, dataset: DiffractionDataset) -> None: + def write(self, file_path: Path, dataset: DiffractionDataset) -> None: """writes a diffraction dataset to file""" pass diff --git a/src/ptychodus/api/plugins.py b/src/ptychodus/api/plugins.py index d0e0ab4b..7a64d7df 100644 --- a/src/ptychodus/api/plugins.py +++ b/src/ptychodus/api/plugins.py @@ -1,8 +1,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterable, Iterator from dataclasses import dataclass +from pathlib import Path from types import ModuleType -from typing import Generic, TypeVar, overload +from typing import Generic, TypeVar import importlib import logging import pkgutil @@ -14,12 +15,13 @@ FluorescenceFileWriter, UpscalingStrategy, ) -from .object import ObjectPhaseCenteringStrategy, ObjectFileReader, ObjectFileWriter -from .observer import Observable +from .object import ObjectFileReader, ObjectFileWriter, Object +from .observer import Observable, Observer +from .parametric import StringParameter from .patterns import DiffractionFileReader, DiffractionFileWriter -from .probe import FresnelZonePlate, ProbeFileReader, ProbeFileWriter +from .probe import FresnelZonePlate, ProbeFileReader, ProbeFileWriter, ProbeSequence from .product import ProductFileReader, ProductFileWriter -from .scan import ScanFileReader, ScanFileWriter +from .scan import PositionFileReader, PositionFileWriter, PositionSequence from .workflow import FileBasedWorkflow __all__ = [ @@ -32,94 +34,134 @@ logger = logging.getLogger(__name__) +class ProductPositionFileReader(PositionFileReader): + def __init__(self, reader: ProductFileReader) -> None: + super().__init__() + self._reader = reader + + def read(self, file_path: Path) -> PositionSequence: + product = self._reader.read(file_path) + return product.positions + + +class ProductProbeFileReader(ProbeFileReader): + def __init__(self, reader: ProductFileReader) -> None: + super().__init__() + self._reader = reader + + def read(self, file_path: Path) -> ProbeSequence: + product = self._reader.read(file_path) + return product.probes + + +class ProductObjectFileReader(ObjectFileReader): + def __init__(self, reader: ProductFileReader) -> None: + super().__init__() + self._reader = reader + + def read(self, file_path: Path) -> Object: + product = self._reader.read(file_path) + return product.object_ + + @dataclass(frozen=True) -class PluginEntry(Generic[T]): +class Plugin(Generic[T]): strategy: T - simpleName: str - displayName: str + simple_name: str + display_name: str -class PluginChooser(Sequence[PluginEntry[T]], Observable): +class PluginChooser(Iterable[Plugin[T]], Observable, Observer): def __init__(self) -> None: super().__init__() - self._entryList: list[PluginEntry[T]] = list() - self._currentIndex = 0 + self._registered_plugins: list[Plugin[T]] = list() + self._current_index = 0 + self._parameter: StringParameter | None = None - def getSimpleNameList(self) -> Sequence[str]: - return [entry.simpleName for entry in self._entryList] + def register_plugin(self, strategy: T, *, display_name: str, simple_name: str = '') -> None: + if not simple_name: + simple_name = re.sub(r'\W+', '', display_name) - def getDisplayNameList(self) -> Sequence[str]: - return [entry.displayName for entry in self._entryList] + plugin = Plugin[T](strategy, simple_name, display_name) + self._registered_plugins.append(plugin) + self._registered_plugins.sort(key=lambda x: x.simple_name) + self.notify_observers() - def registerPlugin(self, strategy: T, *, displayName: str, simpleName: str = '') -> None: - if not simpleName: - simpleName = re.sub(r'\W+', '', displayName) + def get_current_plugin(self) -> Plugin[T]: + return self._registered_plugins[self._current_index] - entry = PluginEntry[T](strategy, simpleName, displayName) - self._entryList.append(entry) - self.notifyObservers() - - @property - def currentPlugin(self) -> PluginEntry[T]: - return self._entryList[self._currentIndex] - - def setCurrentPluginByName(self, name: str) -> None: + def set_current_plugin(self, name: str) -> None: namecf = name.casefold() - for index, entry in enumerate(self._entryList): - if namecf == entry.simpleName.casefold() or namecf == entry.displayName.casefold(): - if self._currentIndex != index: - self._currentIndex = index - self.notifyObservers() + for index, plugin in enumerate(self._registered_plugins): + if namecf == plugin.simple_name.casefold() or namecf == plugin.display_name.casefold(): + if self._current_index != index: + self._current_index = index - return + if self._parameter is not None: + self._parameter.set_value(self.get_current_plugin().simple_name) - logger.debug(f'Invalid plugin name "{name}"') + self.notify_observers() - @overload - def __getitem__(self, index: int) -> PluginEntry[T]: ... + return - @overload - def __getitem__(self, index: slice) -> Sequence[PluginEntry[T]]: ... + registered_plugins = ', '.join(f'"{pi.simple_name}"' for pi in self._registered_plugins) + logger.debug(f'Invalid plugin name "{name}". Registered plugins: {registered_plugins}.') - def __getitem__(self, index: int | slice) -> PluginEntry[T] | Sequence[PluginEntry[T]]: - return self._entryList[index] + def synchronize_with_parameter(self, parameter: StringParameter) -> None: + self._parameter = parameter + self.set_current_plugin(parameter.get_value()) + self._parameter.add_observer(self) - def __len__(self) -> int: - return len(self._entryList) + def __iter__(self) -> Iterator[Plugin[T]]: + for plugin in self._registered_plugins: + yield plugin def __bool__(self) -> bool: - return bool(self._entryList) + return bool(self._registered_plugins) - def copy(self) -> PluginChooser[T]: - clone = PluginChooser[T]() - clone._entryList = self._entryList.copy() - clone._currentIndex = self._currentIndex - return clone + def _update(self, observable: Observable) -> None: + if self._parameter is not None and observable is self._parameter: + self.set_current_plugin(self._parameter.get_value()) class PluginRegistry: def __init__(self) -> None: - self.diffractionFileReaders = PluginChooser[DiffractionFileReader]() - self.diffractionFileWriters = PluginChooser[DiffractionFileWriter]() - self.scanFileReaders = PluginChooser[ScanFileReader]() - self.scanFileWriters = PluginChooser[ScanFileWriter]() - self.fresnelZonePlates = PluginChooser[FresnelZonePlate]() - self.probeFileReaders = PluginChooser[ProbeFileReader]() - self.probeFileWriters = PluginChooser[ProbeFileWriter]() - self.objectPhaseCenteringStrategies = PluginChooser[ObjectPhaseCenteringStrategy]() - self.objectFileReaders = PluginChooser[ObjectFileReader]() - self.objectFileWriters = PluginChooser[ObjectFileWriter]() - self.productFileReaders = PluginChooser[ProductFileReader]() - self.productFileWriters = PluginChooser[ProductFileWriter]() - self.fileBasedWorkflows = PluginChooser[FileBasedWorkflow]() - self.fluorescenceFileReaders = PluginChooser[FluorescenceFileReader]() - self.fluorescenceFileWriters = PluginChooser[FluorescenceFileWriter]() - self.upscalingStrategies = PluginChooser[UpscalingStrategy]() - self.deconvolutionStrategies = PluginChooser[DeconvolutionStrategy]() + self.diffraction_file_readers = PluginChooser[DiffractionFileReader]() + self.diffraction_file_writers = PluginChooser[DiffractionFileWriter]() + self.position_file_readers = PluginChooser[PositionFileReader]() + self.position_file_writers = PluginChooser[PositionFileWriter]() + self.fresnel_zone_plates = PluginChooser[FresnelZonePlate]() + self.probe_file_readers = PluginChooser[ProbeFileReader]() + self.probe_file_writers = PluginChooser[ProbeFileWriter]() + self.object_file_readers = PluginChooser[ObjectFileReader]() + self.object_file_writers = PluginChooser[ObjectFileWriter]() + self.product_file_readers = PluginChooser[ProductFileReader]() + self.product_file_writers = PluginChooser[ProductFileWriter]() + self.file_based_workflows = PluginChooser[FileBasedWorkflow]() + self.fluorescence_file_readers = PluginChooser[FluorescenceFileReader]() + self.fluorescence_file_writers = PluginChooser[FluorescenceFileWriter]() + self.upscaling_strategies = PluginChooser[UpscalingStrategy]() + self.deconvolution_strategies = PluginChooser[DeconvolutionStrategy]() + + def register_product_file_reader_with_adapters( + self, strategy: ProductFileReader, *, display_name: str, simple_name: str = '' + ) -> None: + self.position_file_readers.register_plugin( + ProductPositionFileReader(strategy), display_name=display_name, simple_name=simple_name + ) + self.probe_file_readers.register_plugin( + ProductProbeFileReader(strategy), display_name=display_name, simple_name=simple_name + ) + self.object_file_readers.register_plugin( + ProductObjectFileReader(strategy), display_name=display_name, simple_name=simple_name + ) + self.product_file_readers.register_plugin( + strategy, display_name=display_name, simple_name=simple_name + ) @classmethod - def loadPlugins(cls) -> PluginRegistry: + def load_plugins(cls) -> PluginRegistry: registry = cls() import ptychodus.plugins @@ -130,14 +172,14 @@ def loadPlugins(cls) -> PluginRegistry: # returned name an absolute name instead of a relative one. This allows # import_module to work without having to do additional modification to # the name. - for moduleInfo in pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + '.'): + for module_info in pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + '.'): try: - module = importlib.import_module(moduleInfo.name) + module = importlib.import_module(module_info.name) except ModuleNotFoundError as exc: - logger.info(f'Skipping {moduleInfo.name}') + logger.info(f'Skipping {module_info.name}') logger.warning(exc) else: - logger.info(f'Registering {moduleInfo.name}') - module.registerPlugins(registry) + logger.info(f'Registering {module_info.name}') + module.register_plugins(registry) return registry diff --git a/src/ptychodus/api/probe.py b/src/ptychodus/api/probe.py index 6e46cdc3..a2ff8785 100644 --- a/src/ptychodus/api/probe.py +++ b/src/ptychodus/api/probe.py @@ -3,201 +3,280 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path +from typing import overload import numpy -from .geometry import ImageExtent, PixelGeometry -from .propagator import WavefieldArrayType, intensity -from .typing import RealArrayType +from .geometry import PixelGeometry +from .propagator import intensity +from .typing import ComplexArrayType, RealArrayType @dataclass(frozen=True) class FresnelZonePlate: - zonePlateDiameterInMeters: float - outermostZoneWidthInMeters: float - centralBeamstopDiameterInMeters: float - - def getFocalLengthInMeters(self, centralWavelengthInMeters: float) -> float: - return ( - self.zonePlateDiameterInMeters - * self.outermostZoneWidthInMeters - / centralWavelengthInMeters - ) + zone_plate_diameter_m: float + outermost_zone_width_m: float + central_beamstop_diameter_m: float + + def get_focal_length_m(self, central_wavelength_m: float) -> float: + return self.zone_plate_diameter_m * self.outermost_zone_width_m / central_wavelength_m @dataclass(frozen=True) class ProbeGeometry: - widthInPixels: int - heightInPixels: int - pixelWidthInMeters: float - pixelHeightInMeters: float + width_px: int + height_px: int + pixel_width_m: float + pixel_height_m: float @property - def widthInMeters(self) -> float: - return self.widthInPixels * self.pixelWidthInMeters + def width_m(self) -> float: + return self.width_px * self.pixel_width_m @property - def heightInMeters(self) -> float: - return self.heightInPixels * self.pixelHeightInMeters + def height_m(self) -> float: + return self.height_px * self.pixel_height_m - def _asTuple(self) -> tuple[int, int, float, float]: - return ( - self.widthInPixels, - self.heightInPixels, - self.pixelWidthInMeters, - self.pixelHeightInMeters, + def get_pixel_geometry(self) -> PixelGeometry: + return PixelGeometry( + width_m=self.pixel_width_m, + height_m=self.pixel_height_m, ) - def __eq__(self, other: object) -> bool: - if isinstance(other, ProbeGeometry): - return self._asTuple() == other._asTuple() - return False +class ProbeGeometryProvider(ABC): + @property + @abstractmethod + def detector_distance_m(self) -> float: + pass + @property + @abstractmethod + def probe_photon_count(self) -> float: + pass -class ProbeGeometryProvider(ABC): @property @abstractmethod - def detectorDistanceInMeters(self) -> float: + def probe_wavelength_m(self) -> float: pass @property @abstractmethod - def probeWavelengthInMeters(self) -> float: + def probe_power_W(self) -> float: # noqa: N802 pass @property @abstractmethod - def probePowerInWatts(self) -> float: + def num_scan_points(self) -> int: pass @abstractmethod - def getProbeGeometry(self) -> ProbeGeometry: + def get_detector_pixel_geometry(self) -> PixelGeometry: + pass + + @abstractmethod + def get_probe_geometry(self) -> ProbeGeometry: pass class Probe: - @staticmethod - def _calculateModeRelativePower(array: WavefieldArrayType) -> Sequence[float]: + def __init__( + self, + array: ComplexArrayType, + pixel_geometry: PixelGeometry, + ) -> None: + self._array = array + self._pixel_geometry = pixel_geometry + + if array.ndim != 3: + raise ValueError('Probe must be a 3-dimensional ndarray.') + power = numpy.sum(intensity(array), axis=(-2, -1)) powersum = numpy.sum(power) if powersum > 0.0: power /= powersum - return power.tolist() + self._mode_relative_power = power.tolist() + def copy(self) -> Probe: + return Probe( + array=self._array.copy(), + pixel_geometry=self._pixel_geometry.copy(), + ) + + def get_array(self) -> ComplexArrayType: + return self._array + + def get_pixel_geometry(self) -> PixelGeometry: + return self._pixel_geometry + + @property + def dtype(self) -> numpy.dtype: + return self._array.dtype + + @property + def width_px(self) -> int: + return self._array.shape[-1] + + @property + def height_px(self) -> int: + return self._array.shape[-2] + + @property + def num_incoherent_modes(self) -> int: + return self._array.shape[-3] + + def get_incoherent_mode(self, number: int) -> ComplexArrayType: + return self._array[number, :, :] + + def get_incoherent_modes_flattened(self) -> ComplexArrayType: + return self._array.transpose((1, 0, 2)).reshape(self.height_px, -1) + + def get_incoherent_mode_relative_power(self, number: int) -> float: + return self._mode_relative_power[number] + + def get_coherence(self) -> float: + return numpy.sqrt(numpy.sum(numpy.square(self._mode_relative_power))) + + def get_intensity(self) -> RealArrayType: + return numpy.sum(intensity(self._array), axis=-3) + + +class ProbeSequence(Sequence[Probe]): def __init__( self, - array: WavefieldArrayType | None = None, - *, - pixelWidthInMeters: float = 0.0, - pixelHeightInMeters: float = 0.0, + array: ComplexArrayType | None, + opr_weights: RealArrayType | None, + pixel_geometry: PixelGeometry | None, ) -> None: if array is None: - self._array = numpy.zeros((1, 0, 0), dtype=complex) - else: - if numpy.iscomplexobj(array): - if array.ndim == 2: - self._array = array[numpy.newaxis, :, :] - elif array.ndim == 3: + self._array: ComplexArrayType = numpy.zeros((1, 1, 0, 0), dtype=complex) + elif numpy.iscomplexobj(array): + match array.ndim: + case 2: + self._array = array[numpy.newaxis, numpy.newaxis, ...] + case 3: + self._array = array[numpy.newaxis, ...] + case 4: self._array = array + case _: + raise ValueError('Probe must be 2-, 3-, or 4-dimensional ndarray.') + else: + raise TypeError('Probe must be a complex-valued ndarray') + + if opr_weights is None: + self._opr_weights = None + elif numpy.issubdtype(opr_weights.dtype, numpy.floating): + if opr_weights.ndim == 2: + if opr_weights.shape[1] == self._array.shape[0]: + self._opr_weights = opr_weights else: - raise ValueError('Probe must be 2- or 3-dimensional ndarray.') + raise ValueError('opr_weights do not match the number of coherent probe modes') else: - raise TypeError('Probe must be a complex-valued ndarray') + raise ValueError('opr_weights must be 2-dimensional ndarray') + else: + raise TypeError('opr_weights must be a floating-point ndarray') - self._modeRelativePower = Probe._calculateModeRelativePower(self._array) - self._pixelWidthInMeters = pixelWidthInMeters - self._pixelHeightInMeters = pixelHeightInMeters + self._pixel_geometry = pixel_geometry - def copy(self) -> Probe: - return Probe( - array=numpy.array(self._array), - pixelWidthInMeters=float(self._pixelWidthInMeters), - pixelHeightInMeters=float(self._pixelHeightInMeters), + def copy(self) -> ProbeSequence: + return ProbeSequence( + self._array.copy(), + None if self._opr_weights is None else self._opr_weights.copy(), + None if self._pixel_geometry is None else self._pixel_geometry.copy(), ) - @property - def array(self) -> WavefieldArrayType: + def get_array(self) -> ComplexArrayType: return self._array + def get_opr_weights(self) -> RealArrayType: + if self._opr_weights is None: + raise ValueError('Missing opr_weights!') + + return self._opr_weights + + def get_pixel_geometry(self) -> PixelGeometry: + if self._pixel_geometry is None: + raise ValueError('Missing probe pixel geometry!') + + return self._pixel_geometry + @property - def dataType(self) -> numpy.dtype: + def dtype(self) -> numpy.dtype: return self._array.dtype @property - def numberOfModes(self) -> int: - return self._array.shape[-3] + def nbytes(self) -> int: + sz = self._array.nbytes - @property - def sizeInBytes(self) -> int: - return self._array.nbytes + if self._opr_weights is not None: + sz += self._opr_weights.nbytes + + return sz @property - def widthInPixels(self) -> int: - return self._array.shape[-1] + def num_coherent_modes(self) -> int: + return self._array.shape[0] @property - def heightInPixels(self) -> int: - return self._array.shape[-2] + def num_incoherent_modes(self) -> int: + return self._array.shape[1] @property - def pixelWidthInMeters(self) -> float: - return self._pixelWidthInMeters + def height_px(self) -> int: + return self._array.shape[2] @property - def pixelHeightInMeters(self) -> float: - return self._pixelHeightInMeters + def width_px(self) -> int: + return self._array.shape[3] - def getGeometry(self) -> ProbeGeometry: - return ProbeGeometry( - widthInPixels=self.widthInPixels, - heightInPixels=self.heightInPixels, - pixelWidthInMeters=self._pixelWidthInMeters, - pixelHeightInMeters=self._pixelHeightInMeters, - ) + @overload + def __getitem__(self, index: int) -> Probe: ... - def getPixelGeometry(self) -> PixelGeometry: - return PixelGeometry( - widthInMeters=self._pixelWidthInMeters, - heightInMeters=self._pixelHeightInMeters, - ) + @overload + def __getitem__(self, index: slice) -> Sequence[Probe]: ... - def getExtent(self) -> ImageExtent: - return ImageExtent( - widthInPixels=self.widthInPixels, - heightInPixels=self.heightInPixels, - ) + def __getitem__(self, index: int | slice) -> Probe | Sequence[Probe]: + if isinstance(index, slice): + return [self[idx] for idx in range(index.start, index.stop, index.step)] - def getMode(self, number: int) -> WavefieldArrayType: - return self._array[number, :, :] + array = self._array[0, :, :, :].copy() - def getModesFlattened(self) -> WavefieldArrayType: - if self._array.size > 0: - return self._array.transpose((1, 0, 2)).reshape(self._array.shape[-2], -1) - else: - return self._array + if self._opr_weights is not None: + array[0, :, :] = numpy.tensordot( + self._opr_weights[index, :], self._array[:, 0, :, :], axes=1 + ) - def getModeRelativePower(self, number: int) -> float: - return self._modeRelativePower[number] + return Probe(array, self.get_pixel_geometry()) - def getCoherence(self) -> float: - return numpy.sqrt(numpy.sum(numpy.square(self._modeRelativePower))) + def get_probe_no_opr(self) -> Probe: + array = self._array[0, :, :, :].copy() + return Probe(array, self.get_pixel_geometry()) - def getIntensity(self) -> RealArrayType: - return numpy.sum(intensity(self._array), axis=-3) + def get_geometry(self) -> ProbeGeometry: + pixel_geometry = self.get_pixel_geometry() + + return ProbeGeometry( + width_px=self.width_px, + height_px=self.height_px, + pixel_width_m=pixel_geometry.width_m, + pixel_height_m=pixel_geometry.height_m, + ) + + def __len__(self) -> int: + return 1 if self._opr_weights is None else self._opr_weights.shape[0] class ProbeFileReader(ABC): @abstractmethod - def read(self, filePath: Path) -> Probe: + def read(self, file_path: Path) -> ProbeSequence: """reads a probe from file""" pass class ProbeFileWriter(ABC): @abstractmethod - def write(self, filePath: Path, probe: Probe) -> None: + def write(self, file_path: Path, probes: ProbeSequence) -> None: """writes a probe to file""" pass diff --git a/src/ptychodus/api/product.py b/src/ptychodus/api/product.py index 73b660d4..4af41ccd 100644 --- a/src/ptychodus/api/product.py +++ b/src/ptychodus/api/product.py @@ -1,74 +1,83 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import Final from dataclasses import dataclass from pathlib import Path from sys import getsizeof -from .constants import ELECTRON_VOLT_J, LIGHT_SPEED_M_PER_S, PLANCK_CONSTANT_J_PER_HZ from .object import Object -from .probe import Probe -from .scan import Scan +from .probe import ProbeSequence +from .scan import PositionSequence + +# Source: https://physics.nist.gov/cuu/Constants/index.html +ELECTRON_VOLT_J: Final[float] = 1.602176634e-19 +LIGHT_SPEED_M_PER_S: Final[float] = 299792458 +PLANCK_CONSTANT_J_PER_HZ: Final[float] = 6.62607015e-34 @dataclass(frozen=True) class ProductMetadata: name: str comments: str - detectorDistanceInMeters: float - probeEnergyInElectronVolts: float - probePhotonsPerSecond: float - exposureTimeInSeconds: float + detector_distance_m: float + probe_energy_eV: float # noqa: N815 + probe_photon_count: float + exposure_time_s: float + mass_attenuation_m2_kg: float + tomography_angle_deg: float @property - def probeEnergyInJoules(self) -> float: - return self.probeEnergyInElectronVolts * ELECTRON_VOLT_J + def probe_energy_J(self) -> float: # noqa: N802 + return self.probe_energy_eV * ELECTRON_VOLT_J @property - def probeWavelengthInMeters(self) -> float: - hc_Jm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S + def probe_wavelength_m(self) -> float: + hc_Jm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S # noqa: N806 try: - return hc_Jm / self.probeEnergyInJoules + return hc_Jm / self.probe_energy_J except ZeroDivisionError: return 0.0 @property - def sizeInBytes(self) -> int: + def nbytes(self) -> int: sz = getsizeof(self.name) sz += getsizeof(self.comments) - sz += getsizeof(self.detectorDistanceInMeters) - sz += getsizeof(self.probeEnergyInElectronVolts) - sz += getsizeof(self.probePhotonsPerSecond) - sz += getsizeof(self.exposureTimeInSeconds) + sz += getsizeof(self.detector_distance_m) + sz += getsizeof(self.probe_energy_eV) + sz += getsizeof(self.probe_photon_count) + sz += getsizeof(self.exposure_time_s) + sz += getsizeof(self.mass_attenuation_m2_kg) + sz += getsizeof(self.tomography_angle_deg) return sz @dataclass(frozen=True) class Product: metadata: ProductMetadata - scan: Scan - probe: Probe + positions: PositionSequence + probes: ProbeSequence object_: Object costs: Sequence[float] @property - def sizeInBytes(self) -> int: - sz = self.metadata.sizeInBytes - sz += self.scan.sizeInBytes - sz += self.probe.sizeInBytes - sz += self.object_.sizeInBytes + def nbytes(self) -> int: + sz = self.metadata.nbytes + sz += self.positions.nbytes + sz += self.probes.nbytes + sz += self.object_.nbytes return sz class ProductFileReader(ABC): @abstractmethod - def read(self, filePath: Path) -> Product: + def read(self, file_path: Path) -> Product: """reads a product from file""" pass class ProductFileWriter(ABC): @abstractmethod - def write(self, filePath: Path, product: Product) -> None: + def write(self, file_path: Path, product: Product) -> None: """writes a product to file""" pass diff --git a/src/ptychodus/api/propagator.py b/src/ptychodus/api/propagator.py index 8b80477c..1e6b5e49 100644 --- a/src/ptychodus/api/propagator.py +++ b/src/ptychodus/api/propagator.py @@ -1,16 +1,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TypeAlias from scipy.fft import fft2, fftfreq, fftshift, ifft2, ifftshift import numpy from .typing import ComplexArrayType, RealArrayType -WavefieldArrayType: TypeAlias = ComplexArrayType - -def intensity(wavefield: WavefieldArrayType) -> RealArrayType: +def intensity(wavefield: ComplexArrayType) -> RealArrayType: return numpy.real(numpy.multiply(wavefield, numpy.conjugate(wavefield))) @@ -50,21 +47,21 @@ def fresnel_number(self) -> float: return numpy.square(self.dx) / numpy.absolute(self.z) def get_spatial_coordinates(self) -> tuple[RealArrayType, RealArrayType]: - JJ, II = numpy.mgrid[: self.height_px, : self.width_px] - XX = II - self.width_px // 2 - YY = JJ - self.height_px // 2 + JJ, II = numpy.mgrid[: self.height_px, : self.width_px] # noqa: N806 + XX = II - self.width_px // 2 # noqa: N806 + YY = JJ - self.height_px // 2 # noqa: N806 return YY, XX def get_frequency_coordinates(self) -> tuple[RealArrayType, RealArrayType]: fx = fftshift(fftfreq(self.width_px)) fy = fftshift(fftfreq(self.height_px)) - FY, FX = numpy.meshgrid(fy, fx) + FY, FX = numpy.meshgrid(fy, fx, indexing='ij') # noqa: N806 return FY, FX class Propagator(ABC): @abstractmethod - def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType: + def propagate(self, wavefield: ComplexArrayType) -> ComplexArrayType: pass @@ -73,14 +70,14 @@ def __init__(self, parameters: PropagatorParameters) -> None: ar = parameters.pixel_aspect_ratio i2piz = 2j * numpy.pi * parameters.z - FY, FX = parameters.get_frequency_coordinates() - F2 = numpy.square(FX) + numpy.square(ar * FY) + FY, FX = parameters.get_frequency_coordinates() # noqa: N806 + F2 = numpy.square(FX) + numpy.square(ar * FY) # noqa: N806 ratio = F2 / numpy.square(parameters.dx) tf = numpy.exp(i2piz * numpy.sqrt(1 - ratio)) self._transfer_function = numpy.where(ratio < 1, tf, 0) - def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType: + def propagate(self, wavefield: ComplexArrayType) -> ComplexArrayType: return fftshift(ifft2(self._transfer_function * fft2(ifftshift(wavefield)))) @@ -89,13 +86,13 @@ def __init__(self, parameters: PropagatorParameters) -> None: ar = parameters.pixel_aspect_ratio i2piz = 2j * numpy.pi * parameters.z - FY, FX = parameters.get_frequency_coordinates() - F2 = numpy.square(FX) + numpy.square(ar * FY) + FY, FX = parameters.get_frequency_coordinates() # noqa: N806 + F2 = numpy.square(FX) + numpy.square(ar * FY) # noqa: N806 ratio = F2 / numpy.square(parameters.dx) self._transfer_function = numpy.exp(i2piz * (1 - ratio / 2)) - def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType: + def propagate(self, wavefield: ComplexArrayType) -> ComplexArrayType: return fftshift(ifft2(self._transfer_function * fft2(ifftshift(wavefield)))) @@ -103,22 +100,22 @@ class FresnelTransformPropagator(Propagator): def __init__(self, parameters: PropagatorParameters) -> None: ipi = 1j * numpy.pi - Fr = parameters.fresnel_number + Fr = parameters.fresnel_number # noqa: N806 ar = parameters.pixel_aspect_ratio - N = parameters.width_px - M = parameters.height_px - YY, XX = parameters.get_spatial_coordinates() + N = parameters.width_px # noqa: N806 + M = parameters.height_px # noqa: N806 + YY, XX = parameters.get_spatial_coordinates() # noqa: N806 - C0 = Fr / (1j * ar) - C1 = numpy.exp(2j * numpy.pi * parameters.z) - C2 = numpy.exp((numpy.square(XX / N) + numpy.square(ar * YY / M)) * ipi / Fr) + C0 = Fr / (1j * ar) # noqa: N806 + C1 = numpy.exp(2j * numpy.pi * parameters.z) # noqa: N806 + C2 = numpy.exp((numpy.square(XX / N) + numpy.square(ar * YY / M)) * ipi / Fr) # noqa: N806 is_forward = parameters.propagation_distance_m >= 0.0 self._is_forward = is_forward self._A = C2 * C1 * C0 if is_forward else C2 * C1 / C0 self._B = numpy.exp(ipi * Fr * (numpy.square(XX) + numpy.square(YY / ar))) - def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType: + def propagate(self, wavefield: ComplexArrayType) -> ComplexArrayType: if self._is_forward: return self._A * fftshift(fft2(ifftshift(wavefield * self._B))) else: @@ -129,21 +126,21 @@ class FraunhoferPropagator(Propagator): def __init__(self, parameters: PropagatorParameters) -> None: ipi = 1j * numpy.pi - Fr = parameters.fresnel_number + Fr = parameters.fresnel_number # noqa: N806 ar = parameters.pixel_aspect_ratio - N = parameters.width_px - M = parameters.height_px - YY, XX = parameters.get_spatial_coordinates() + N = parameters.width_px # noqa: N806 + M = parameters.height_px # noqa: N806 + YY, XX = parameters.get_spatial_coordinates() # noqa: N806 - C0 = Fr / (1j * ar) - C1 = numpy.exp(2j * numpy.pi * parameters.z) - C2 = numpy.exp((numpy.square(XX / N) + numpy.square(ar * YY / M)) * ipi / Fr) + C0 = Fr / (1j * ar) # noqa: N806 + C1 = numpy.exp(2j * numpy.pi * parameters.z) # noqa: N806 + C2 = numpy.exp((numpy.square(XX / N) + numpy.square(ar * YY / M)) * ipi / Fr) # noqa: N806 is_forward = parameters.propagation_distance_m >= 0.0 self._is_forward = is_forward self._A = C2 * C1 * C0 if is_forward else C2 * C1 / C0 - def propagate(self, wavefield: WavefieldArrayType) -> WavefieldArrayType: + def propagate(self, wavefield: ComplexArrayType) -> ComplexArrayType: if self._is_forward: return self._A * fftshift(fft2(ifftshift(wavefield))) else: diff --git a/src/ptychodus/api/reconstructor.py b/src/ptychodus/api/reconstructor.py index ad5eb67e..4984eb62 100644 --- a/src/ptychodus/api/reconstructor.py +++ b/src/ptychodus/api/reconstructor.py @@ -4,14 +4,16 @@ from dataclasses import dataclass from pathlib import Path +from ptychodus.api.typing import BooleanArrayType + +from .patterns import PatternDataType from .product import Product -from .patterns import BooleanArrayType, DiffractionPatternArrayType @dataclass(frozen=True) class ReconstructInput: - patterns: DiffractionPatternArrayType - goodPixelMask: BooleanArrayType + patterns: PatternDataType + bad_pixels: BooleanArrayType product: Product @@ -32,72 +34,46 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: pass +@dataclass(frozen=True) +class LossValue: + epoch: int + training_loss: float + validation_loss: float + + @dataclass(frozen=True) class TrainOutput: - trainingLoss: Sequence[float] - validationLoss: Sequence[float] + losses: Sequence[LossValue] result: int class TrainableReconstructor(Reconstructor): @abstractmethod - def ingestTrainingData(self, parameters: ReconstructInput) -> None: - pass - - @abstractmethod - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - pass - - @abstractmethod - def getOpenTrainingDataFileFilter(self) -> str: + def get_model_file_filter(self) -> str: pass @abstractmethod - def openTrainingData(self, filePath: Path) -> None: + def open_model(self, file_path: Path) -> None: pass @abstractmethod - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: + def save_model(self, file_path: Path) -> None: pass @abstractmethod - def getSaveTrainingDataFileFilter(self) -> str: + def get_training_data_file_filter(self) -> str: pass @abstractmethod - def saveTrainingData(self, filePath: Path) -> None: + def export_training_data(self, file_path: Path, parameters: ReconstructInput) -> None: pass @abstractmethod - def train(self) -> TrainOutput: + def get_training_data_path(self) -> Path: pass @abstractmethod - def clearTrainingData(self) -> None: - pass - - @abstractmethod - def getOpenModelFileFilterList(self) -> Sequence[str]: - pass - - @abstractmethod - def getOpenModelFileFilter(self) -> str: - pass - - @abstractmethod - def openModel(self, filePath: Path) -> None: - pass - - @abstractmethod - def getSaveModelFileFilterList(self) -> Sequence[str]: - pass - - @abstractmethod - def getSaveModelFileFilter(self) -> str: - pass - - @abstractmethod - def saveModel(self, filePath: Path) -> None: + def train(self, data_path: Path) -> TrainOutput: pass @@ -112,50 +88,26 @@ def name(self) -> str: def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: return ReconstructOutput(parameters.product, 0) - def ingestTrainingData(self, parameters: ReconstructInput) -> None: - pass - - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - return list() - - def getOpenTrainingDataFileFilter(self) -> str: - return str() - - def openTrainingData(self, filePath: Path) -> None: - pass - - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: - return list() - - def getSaveTrainingDataFileFilter(self) -> str: + def get_model_file_filter(self) -> str: return str() - def saveTrainingData(self, filePath: Path) -> None: + def open_model(self, file_path: Path) -> None: pass - def train(self) -> TrainOutput: - return TrainOutput([], [], 0) - - def clearTrainingData(self) -> None: + def save_model(self, file_path: Path) -> None: pass - def getOpenModelFileFilterList(self) -> Sequence[str]: - return list() - - def getOpenModelFileFilter(self) -> str: + def get_training_data_file_filter(self) -> str: return str() - def openModel(self, filePath: Path) -> None: + def export_training_data(self, file_path: Path, parameters: ReconstructInput) -> None: pass - def getSaveModelFileFilterList(self) -> Sequence[str]: - return list() + def get_training_data_path(self) -> Path: + return Path() - def getSaveModelFileFilter(self) -> str: - return str() - - def saveModel(self, filePath: Path) -> None: - pass + def train(self, data_path: Path) -> TrainOutput: + return TrainOutput([], 0) class ReconstructorLibrary(Iterable[Reconstructor]): @@ -163,3 +115,8 @@ class ReconstructorLibrary(Iterable[Reconstructor]): @abstractmethod def name(self) -> str: pass + + @property + @abstractmethod + def logger_name(self) -> str: + pass diff --git a/src/ptychodus/api/scan.py b/src/ptychodus/api/scan.py index db6e9dda..e8e64bf0 100644 --- a/src/ptychodus/api/scan.py +++ b/src/ptychodus/api/scan.py @@ -6,52 +6,63 @@ from typing import overload import sys +import numpy + @dataclass(frozen=True) class ScanPoint: index: int - positionXInMeters: float - positionYInMeters: float + position_x_m: float + position_y_m: float @dataclass(frozen=True) class ScanBoundingBox: - minimumXInMeters: float - maximumXInMeters: float - minimumYInMeters: float - maximumYInMeters: float + minimum_x_m: float + maximum_x_m: float + minimum_y_m: float + maximum_y_m: float @property - def widthInMeters(self) -> float: - return self.maximumXInMeters - self.minimumXInMeters + def width_m(self) -> float: + return self.maximum_x_m - self.minimum_x_m @property - def heightInMeters(self) -> float: - return self.maximumYInMeters - self.minimumYInMeters + def height_m(self) -> float: + return self.maximum_y_m - self.minimum_y_m @property - def centerXInMeters(self) -> float: - return self.minimumXInMeters + self.widthInMeters / 2.0 + def center_x_m(self) -> float: + return self.minimum_x_m + self.width_m / 2.0 @property - def centerYInMeters(self) -> float: - return self.minimumYInMeters + self.heightInMeters / 2.0 + def center_y_m(self) -> float: + return self.minimum_y_m + self.height_m / 2.0 def hull(self, bbox: ScanBoundingBox) -> ScanBoundingBox: return ScanBoundingBox( - minimumXInMeters=min(self.minimumXInMeters, bbox.minimumXInMeters), - maximumXInMeters=max(self.maximumXInMeters, bbox.maximumXInMeters), - minimumYInMeters=min(self.minimumYInMeters, bbox.minimumYInMeters), - maximumYInMeters=max(self.maximumYInMeters, bbox.maximumYInMeters), + minimum_x_m=min(self.minimum_x_m, bbox.minimum_x_m), + maximum_x_m=max(self.maximum_x_m, bbox.maximum_x_m), + minimum_y_m=min(self.minimum_y_m, bbox.minimum_y_m), + maximum_y_m=max(self.maximum_y_m, bbox.maximum_y_m), ) -class Scan(Sequence[ScanPoint]): - def __init__(self, pointSeq: Sequence[ScanPoint] | None = None) -> None: - self._pointSeq: Sequence[ScanPoint] = [] if pointSeq is None else pointSeq +class PositionSequence(Sequence[ScanPoint]): + def __init__(self, point_seq: Sequence[ScanPoint] | None = None) -> None: + coordinates_m: list[float] = [] + + if point_seq is not None: + for point in point_seq: + coordinates_m.append(point.position_y_m) + coordinates_m.append(point.position_x_m) + + self._coordinates_m = numpy.reshape(coordinates_m, (-1, 2)) - def copy(self) -> Scan: - return Scan(list(self._pointSeq)) + def copy(self) -> PositionSequence: + seq = PositionSequence() + seq._coordinates_m = self._coordinates_m.copy() + return seq @overload def __getitem__(self, index: int) -> ScanPoint: ... @@ -60,16 +71,21 @@ def __getitem__(self, index: int) -> ScanPoint: ... def __getitem__(self, index: slice) -> Sequence[ScanPoint]: ... def __getitem__(self, index: int | slice) -> ScanPoint | Sequence[ScanPoint]: - return self._pointSeq[index] + if isinstance(index, slice): + return [self[idx] for idx in range(index.start, index.stop, index.step)] + + return ScanPoint( + index=index, + position_x_m=self._coordinates_m[index, -1], + position_y_m=self._coordinates_m[index, -2], + ) def __len__(self) -> int: - return len(self._pointSeq) + return self._coordinates_m.shape[0] @property - def sizeInBytes(self) -> int: - numBytes = sys.getsizeof(self._pointSeq) - numBytes += sum(sys.getsizeof(point) for point in self._pointSeq) - return numBytes + def nbytes(self) -> int: + return self._coordinates_m.nbytes class ScanPointParseError(Exception): @@ -78,19 +94,19 @@ class ScanPointParseError(Exception): pass -class ScanFileReader(ABC): - """interface for plugins that read scan files""" +class PositionFileReader(ABC): + """interface for plugins that read position files""" @abstractmethod - def read(self, filePath: Path) -> Scan: - """reads a scan dictionary from file""" + def read(self, file_path: Path) -> PositionSequence: + """reads positions from file""" pass -class ScanFileWriter(ABC): - """interface for plugins that write scan files""" +class PositionFileWriter(ABC): + """interface for plugins that write position files""" @abstractmethod - def write(self, filePath: Path, scan: Scan) -> None: - """writes a scan dictionary to file""" + def write(self, file_path: Path, positions: PositionSequence) -> None: + """writes positions to file""" pass diff --git a/src/ptychodus/api/settings.py b/src/ptychodus/api/settings.py index fcfbed9c..01fdb202 100644 --- a/src/ptychodus/api/settings.py +++ b/src/ptychodus/api/settings.py @@ -16,84 +16,93 @@ @dataclass(frozen=True) class PathPrefixChange: - findPathPrefix: Path - replacementPathPrefix: Path + find_path_prefix: Path + replacement_path_prefix: Path class SettingsRegistry(Observable): def __init__(self) -> None: super().__init__() - self._parameterGroup = ParameterGroup() - self._fileFilterList: list[str] = ['Initialization Files (*.ini)'] + self._parameter_group = ParameterGroup() + self._file_filter_list: list[str] = ['Initialization Files (*.ini)'] - def createGroup(self, name: str) -> ParameterGroup: - return self._parameterGroup.createGroup(name) + def create_group(self, name: str) -> ParameterGroup: + return self._parameter_group.create_group(name) def __iter__(self) -> Iterator[str]: - return iter(self._parameterGroup.groups()) + return iter(self._parameter_group.groups()) def __getitem__(self, name: str) -> ParameterGroup: - return self._parameterGroup.getGroup(name) + return self._parameter_group.get_group(name) def __len__(self) -> int: - return len(self._parameterGroup.groups()) + return len(self._parameter_group.groups()) - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileFilterList + def get_open_file_filters(self) -> Sequence[str]: + return self._file_filter_list - def getOpenFileFilter(self) -> str: - return self._fileFilterList[0] + def get_open_file_filter(self) -> str: + return self._file_filter_list[0] - def openSettings(self, filePath: Path) -> None: + def open_settings(self, file_path: Path) -> None: config = configparser.ConfigParser(interpolation=None) - logger.debug(f'Reading settings from "{filePath}"') - config.read(filePath) + logger.debug(f'Reading settings from "{file_path}"') + + try: + config.read(file_path) + except Exception as exc: + logger.exception(exc) + return # TODO generalize to support nested parameter groups - for groupName, group in self._parameterGroup.groups().items(): + for group_name, group in self._parameter_group.groups().items(): try: - groupConfig = config[groupName] + group_config = config[group_name] except KeyError: pass else: - for parameterName, parameter in group.parameters().items(): + for parameter_name, parameter in group.parameters().items(): try: - valueString = groupConfig[parameterName] + value_string = group_config[parameter_name] except KeyError: pass else: - parameter.setValueFromString(valueString) + parameter.set_value_from_string(value_string) - self.notifyObservers() + self.notify_observers() - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileFilterList + def get_save_file_filters(self) -> Sequence[str]: + return self._file_filter_list - def getSaveFileFilter(self) -> str: - return self._fileFilterList[0] + def get_save_file_filter(self) -> str: + return self._file_filter_list[0] - def saveSettings( - self, filePath: Path, changePathPrefix: PathPrefixChange | None = None + def save_settings( + self, file_path: Path, change_path_prefix: PathPrefixChange | None = None ) -> None: config = configparser.ConfigParser(interpolation=None) setattr(config, 'optionxform', lambda option: option) - for groupName, group in self._parameterGroup.groups().items(): - config.add_section(groupName) + for group_name, group in self._parameter_group.groups().items(): + config.add_section(group_name) - for parameterName, parameter in group.parameters().items(): - valueString = parameter.getValueAsString() + for parameter_name, parameter in group.parameters().items(): + value_string = parameter.get_value_as_string() - if changePathPrefix and isinstance(parameter, PathParameter): - modifiedPath = parameter.changePathPrefix( - changePathPrefix.findPathPrefix, - changePathPrefix.replacementPathPrefix, + if change_path_prefix and isinstance(parameter, PathParameter): + modified_path = parameter.change_path_prefix( + change_path_prefix.find_path_prefix, + change_path_prefix.replacement_path_prefix, ) - valueString = str(modifiedPath) + value_string = str(modified_path) - config.set(groupName, parameterName, valueString) + config.set(group_name, parameter_name, value_string) - logger.debug(f'Writing settings to "{filePath}"') + logger.debug(f'Writing settings to "{file_path}"') - with filePath.open(mode='w') as configFile: - config.write(configFile) + try: + with file_path.open(mode='w') as config_file: + config.write(config_file) + except Exception as exc: + logger.exception(exc) + return diff --git a/src/ptychodus/api/tree.py b/src/ptychodus/api/tree.py index 0819869d..d21bb258 100644 --- a/src/ptychodus/api/tree.py +++ b/src/ptychodus/api/tree.py @@ -3,36 +3,36 @@ class SimpleTreeNode: - def __init__(self, parentItem: SimpleTreeNode | None, itemData: Sequence[str]) -> None: - self.parentItem = parentItem - self.itemData = itemData - self.childItems: list[SimpleTreeNode] = list() + def __init__(self, parent_item: SimpleTreeNode | None, item_data: Sequence[str]) -> None: + self.parent_item = parent_item + self.item_data = item_data + self.child_items: list[SimpleTreeNode] = list() @classmethod - def createRoot(cls, itemData: Sequence[str]) -> SimpleTreeNode: - return cls(None, itemData) + def create_root(cls, item_data: Sequence[str]) -> SimpleTreeNode: + return cls(None, item_data) - def createChild(self, itemData: Sequence[str]) -> SimpleTreeNode: - childItem = SimpleTreeNode(self, itemData) - self.childItems.append(childItem) - return childItem + def create_child(self, item_data: Sequence[str]) -> SimpleTreeNode: + child_item = SimpleTreeNode(self, item_data) + self.child_items.append(child_item) + return child_item @property - def isRoot(self) -> bool: - return self.parentItem is None + def is_root(self) -> bool: + return self.parent_item is None @property - def isLeaf(self) -> bool: - return not self.childItems + def is_leaf(self) -> bool: + return not self.child_items def data(self, column: int) -> str | None: try: - return self.itemData[column] + return self.item_data[column] except IndexError: return None def row(self) -> int: - if self.parentItem: - return self.parentItem.childItems.index(self) + if self.parent_item: + return self.parent_item.child_items.index(self) return 0 diff --git a/src/ptychodus/api/typing.py b/src/ptychodus/api/typing.py index 6dee298c..8cc148d1 100644 --- a/src/ptychodus/api/typing.py +++ b/src/ptychodus/api/typing.py @@ -3,10 +3,8 @@ import numpy import numpy.typing +BooleanArrayType: TypeAlias = numpy.typing.NDArray[numpy.bool_] IntegerArrayType: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]] -Float32ArrayType: TypeAlias = numpy.typing.NDArray[numpy.float32] RealArrayType: TypeAlias = numpy.typing.NDArray[numpy.floating[Any]] ComplexArrayType: TypeAlias = numpy.typing.NDArray[numpy.complexfloating[Any, Any]] - -NumberTypes: TypeAlias = numpy.integer[Any] | numpy.floating[Any] | numpy.complexfloating[Any, Any] -NumberArrayType: TypeAlias = numpy.typing.NDArray[NumberTypes] +NumberArrayType: TypeAlias = numpy.typing.NDArray[numpy.number] diff --git a/src/ptychodus/api/units.py b/src/ptychodus/api/units.py new file mode 100644 index 00000000..90854455 --- /dev/null +++ b/src/ptychodus/api/units.py @@ -0,0 +1,3 @@ +from typing import Final + +BYTES_PER_MEGABYTE: Final[int] = 1000 * 1000 diff --git a/src/ptychodus/api/visualization.py b/src/ptychodus/api/visualization.py index ffb8b950..6b0e6f0a 100644 --- a/src/ptychodus/api/visualization.py +++ b/src/ptychodus/api/visualization.py @@ -25,7 +25,7 @@ class PlotAxis: series: Sequence[PlotSeries] @classmethod - def createNull(cls) -> PlotAxis: + def create_null(cls) -> PlotAxis: return cls('', []) def copy(self) -> PlotAxis: @@ -34,27 +34,27 @@ def copy(self) -> PlotAxis: @dataclass(frozen=True) class Plot2D: - axisX: PlotAxis - axisY: PlotAxis + axis_x: PlotAxis + axis_y: PlotAxis @classmethod - def createNull(cls) -> Plot2D: - return cls(PlotAxis.createNull(), PlotAxis.createNull()) + def create_null(cls) -> Plot2D: + return cls(PlotAxis.create_null(), PlotAxis.create_null()) def copy(self) -> Plot2D: - return Plot2D(self.axisX.copy(), self.axisY.copy()) + return Plot2D(self.axis_x.copy(), self.axis_y.copy()) @dataclass(frozen=True) class LineCut: - distanceInMeters: Sequence[float] + distance_m: Sequence[float] value: Sequence[float | complex] @dataclass(frozen=True) class KernelDensityEstimate: - valueLower: float - valueUpper: float + value_lower: float + value_upper: float kde: gaussian_kde @@ -63,10 +63,10 @@ class VisualizationProduct: def __init__( self, - valueLabel: str, + value_label: str, values: NumberArrayType, rgba: RealArrayType, - pixelGeometry: PixelGeometry, + pixel_geometry: PixelGeometry, ) -> None: if values.ndim != 2: raise ValueError(f'Values must be a 2-dimensional ndarray (actual={values.ndim}).') @@ -80,42 +80,42 @@ def __init__( if values.shape[0] != rgba.shape[0] or values.shape[1] != rgba.shape[1]: raise ValueError(f'Shape mismatch (values={values.shape} and rgba={rgba.shape}).') - self._valueLabel = valueLabel + self._value_label = value_label self._values = values self._rgba = rgba - self._pixelWidthInMeters = pixelGeometry.widthInMeters - self._pixelHeightInMeters = pixelGeometry.heightInMeters + self._pixel_width_m = pixel_geometry.width_m + self._pixel_height_m = pixel_geometry.height_m - def getValueLabel(self) -> str: - return self._valueLabel + def get_value_label(self) -> str: + return self._value_label - def getValues(self) -> NumberArrayType: + def get_values(self) -> NumberArrayType: return self._values - def getImageRGBA(self) -> RealArrayType: + def get_image_rgba(self) -> RealArrayType: return self._rgba - def getPixelGeometry(self) -> PixelGeometry: + def get_pixel_geometry(self) -> PixelGeometry: return PixelGeometry( - widthInMeters=self._pixelWidthInMeters, - heightInMeters=self._pixelHeightInMeters, + width_m=self._pixel_width_m, + height_m=self._pixel_height_m, ) @staticmethod - def _intersectBoundingBox(begin: float, end: float, n: int) -> Interval[float]: + def _intersect_bounding_box(begin: float, end: float, n: int) -> Interval[float]: length = end - begin if abs(length) < VisualizationProduct.EPS: return Interval[float](-numpy.inf, numpy.inf) else: - return Interval[float].createProper( + return Interval[float].create_proper( (0 - begin) / length, (n - begin) / length, ) @staticmethod - def _intersectGridLines( - begin: float, end: float, alphaLimits: Interval[float] + def _intersect_grid_lines( + begin: float, end: float, alpha_limits: Interval[float] ) -> Iterator[float]: ibegin = int(begin) iend = int(end) @@ -129,33 +129,33 @@ def _intersectGridLines( for idx in range(ibegin, iend + 1): alpha = (idx - begin) / length - if alpha in alphaLimits: + if alpha in alpha_limits: yield alpha - def _clipToBoundingBox(self, line: Line2D) -> Interval[float]: - alphaX = self._intersectBoundingBox(line.begin.x, line.end.x, self._values.shape[-1]) - alphaY = self._intersectBoundingBox(line.begin.y, line.end.y, self._values.shape[-2]) + def _clip_to_bounding_box(self, line: Line2D) -> Interval[float]: + alpha_x = self._intersect_bounding_box(line.begin.x, line.end.x, self._values.shape[-1]) + alpha_y = self._intersect_bounding_box(line.begin.y, line.end.y, self._values.shape[-2]) - return Interval[float].createProper( - max(0.0, max(alphaX.lower, alphaY.lower)), - min(1.0, min(alphaX.upper, alphaY.upper)), + return Interval[float].create_proper( + max(0.0, max(alpha_x.lower, alpha_y.lower)), + min(1.0, min(alpha_x.upper, alpha_y.upper)), ) - def _intersectGrid(self, line: Line2D) -> Sequence[float]: - alphaLimits = self._clipToBoundingBox(line) - xIntersections = [ - x for x in self._intersectGridLines(line.begin.x, line.end.x, alphaLimits) + def _intersect_grid(self, line: Line2D) -> Sequence[float]: + alpha_limits = self._clip_to_bounding_box(line) + x_intersections = [ + x for x in self._intersect_grid_lines(line.begin.x, line.end.x, alpha_limits) ] - yIntersections = [ - x for x in self._intersectGridLines(line.begin.y, line.end.y, alphaLimits) + y_intersections = [ + x for x in self._intersect_grid_lines(line.begin.y, line.end.y, alpha_limits) ] - alpha = {alphaLimits.lower, alphaLimits.upper} - alpha = alpha.union(xIntersections) - alpha = alpha.union(yIntersections) + alpha = {alpha_limits.lower, alpha_limits.upper} + alpha = alpha.union(x_intersections) + alpha = alpha.union(y_intersections) return sorted(alpha) - def getInfoText(self, x: float, y: float) -> str: + def get_info_text(self, x: float, y: float) -> str: ix = 0 if x < 0.0 else int(x) ix = min(ix, self._values.shape[-1]) iy = 0 if y < 0.0 else int(y) @@ -169,27 +169,27 @@ def getInfoText(self, x: float, y: float) -> str: return f'{x=:.1f} {y=:.1f} {value=:6g}' - def getLineCut(self, line: Line2D) -> LineCut: - intersections = self._intersectGrid(line) + def get_line_cut(self, line: Line2D) -> LineCut: + intersections = self._intersect_grid(line) - dx = (line.end.x - line.begin.x) * self._pixelWidthInMeters - dy = (line.end.y - line.begin.y) * self._pixelHeightInMeters - lineLength = numpy.hypot(dx, dy) + dx = (line.end.x - line.begin.x) * self._pixel_width_m + dy = (line.end.y - line.begin.y) * self._pixel_height_m + line_length = numpy.hypot(dx, dy) distances: list[float] = list() values: list[float] = list() - for alphaL, alphaR in zip(intersections[:-1], intersections[1:]): - alpha = (alphaL + alphaR) / 2.0 + for alpha_l, alpha_r in zip(intersections[:-1], intersections[1:]): + alpha = (alpha_l + alpha_r) / 2.0 point = line.lerp(alpha) value = self._values[int(point.y), int(point.x)] - distances.append(alpha * lineLength) + distances.append(alpha * line_length) values.append(value) return LineCut(distances, values) - def estimateKernelDensity(self, box: Box2D) -> KernelDensityEstimate: + def estimate_kernel_density(self, box: Box2D) -> KernelDensityEstimate: x_range = Interval[int](0, self._values.shape[-1]) x_begin = x_range.clamp(int(box.x_begin)) x_end = x_range.clamp(int(box.x_end) + 1) diff --git a/src/ptychodus/api/workflow.py b/src/ptychodus/api/workflow.py index 54e331b8..39de93b7 100644 --- a/src/ptychodus/api/workflow.py +++ b/src/ptychodus/api/workflow.py @@ -11,93 +11,95 @@ class WorkflowProductAPI(ABC): @abstractmethod - def openScan(self, filePath: Path, *, fileType: str | None = None) -> None: + def open_scan(self, file_path: Path, *, file_type: str | None = None) -> None: pass @abstractmethod - def buildScan( - self, builderName: str | None = None, builderParameters: Mapping[str, Any] = {} + def build_scan( + self, builder_name: str | None = None, builder_parameters: Mapping[str, Any] = {} ) -> None: pass @abstractmethod - def openProbe(self, filePath: Path, *, fileType: str | None = None) -> None: + def open_probe(self, file_path: Path, *, file_type: str | None = None) -> None: pass @abstractmethod - def buildProbe( - self, builderName: str | None = None, builderParameters: Mapping[str, Any] = {} + def build_probe( + self, builder_name: str | None = None, builder_parameters: Mapping[str, Any] = {} ) -> None: pass @abstractmethod - def openObject(self, filePath: Path, *, fileType: str | None = None) -> None: + def open_object(self, file_path: Path, *, file_type: str | None = None) -> None: pass @abstractmethod - def buildObject( - self, builderName: str | None = None, builderParameters: Mapping[str, Any] = {} + def build_object( + self, builder_name: str | None = None, builder_parameters: Mapping[str, Any] = {} ) -> None: pass @abstractmethod - def reconstructLocal(self, outputProductName: str) -> WorkflowProductAPI: + def reconstruct_local(self) -> WorkflowProductAPI: pass @abstractmethod - def reconstructRemote(self) -> None: + def reconstruct_remote(self) -> None: pass @abstractmethod - def saveProduct(self, filePath: Path, *, fileType: str | None = None) -> None: + def save_product(self, file_path: Path, *, file_type: str | None = None) -> None: pass class WorkflowAPI(ABC): @abstractmethod - def openPatterns( + def open_patterns( self, - filePath: Path, + file_path: Path, *, - fileType: str | None = None, - cropCenter: CropCenter | None = None, - cropExtent: ImageExtent | None = None, + file_type: str | None = None, + crop_center: CropCenter | None = None, + crop_extent: ImageExtent | None = None, ) -> None: """opens diffraction patterns from file""" pass @abstractmethod - def importProcessedPatterns(self, filePath: Path) -> None: - """import processed patterns""" + def import_assembled_patterns(self, file_path: Path) -> None: + """import assembled patterns""" pass @abstractmethod - def exportProcessedPatterns(self, filePath: Path) -> None: - """export processed patterns""" + def export_assembled_patterns(self, file_path: Path) -> None: + """export assembled patterns""" pass @abstractmethod - def openProduct(self, filePath: Path, *, fileType: str | None = None) -> WorkflowProductAPI: + def open_product(self, file_path: Path, *, file_type: str | None = None) -> WorkflowProductAPI: """opens product from file""" pass @abstractmethod - def createProduct( + def create_product( self, name: str, *, comments: str = '', - detectorDistanceInMeters: float | None = None, - probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, - exposureTimeInSeconds: float | None = None, + detector_distance_m: float | None = None, + probe_energy_eV: float | None = None, # noqa: N803 + probe_photon_count: float | None = None, + exposure_time_s: float | None = None, + mass_attenuation_m2_kg: float | None = None, + tomography_angle_deg: float | None = None, ) -> WorkflowProductAPI: """creates a new product""" pass @abstractmethod - def saveSettings( - self, filePath: Path, changePathPrefix: PathPrefixChange | None = None + def save_settings( + self, file_path: Path, change_path_prefix: PathPrefixChange | None = None ) -> None: pass @@ -105,16 +107,16 @@ def saveSettings( class FileBasedWorkflow(ABC): @property @abstractmethod - def isWatchRecursive(self) -> bool: + def is_watch_recursive(self) -> bool: """indicates whether the data directory must be watched recursively""" pass @abstractmethod - def getWatchFilePattern(self) -> str: + def get_watch_file_pattern(self) -> str: """UNIX-style filename pattern. For rules see fnmatch from Python standard library.""" pass @abstractmethod - def execute(self, api: WorkflowAPI, filePath: Path) -> None: + def execute(self, api: WorkflowAPI, file_path: Path) -> None: """uses workflow API to execute the workflow""" pass diff --git a/src/ptychodus/controller/agent/__init__.py b/src/ptychodus/controller/agent/__init__.py new file mode 100644 index 00000000..4b4f537b --- /dev/null +++ b/src/ptychodus/controller/agent/__init__.py @@ -0,0 +1,6 @@ +from .core import AgentChatController, AgentController + +__all__ = [ + 'AgentChatController', + 'AgentController', +] diff --git a/src/ptychodus/controller/agent/core.py b/src/ptychodus/controller/agent/core.py new file mode 100644 index 00000000..695de2a6 --- /dev/null +++ b/src/ptychodus/controller/agent/core.py @@ -0,0 +1,163 @@ +from PyQt5.QtCore import QEvent, QModelIndex, QObject, Qt +from PyQt5.QtGui import QKeyEvent +from PyQt5.QtWidgets import ( + QAbstractItemView, + QFormLayout, + QGroupBox, + QInputDialog, + QListView, + QPushButton, + QVBoxLayout, +) + +from ...model.agent import ( + AgentPresenter, + ArgoSettings, + ChatHistory, + ChatMessage, + ChatObserver, +) +from ...view.agent import AgentChatView, AgentInputView, AgentView +from ..parametric import ( + ComboBoxParameterViewController, + DecimalSliderParameterViewController, + LineEditParameterViewController, + SpinBoxParameterViewController, +) +from .item_delegate import ChatBubbleItemDelegate +from .list_model import AgentMessageListModel + +__all__ = ['AgentChatController', 'AgentController'] + + +class AgentInputController(QObject): + def __init__(self, presenter: AgentPresenter, view: AgentInputView) -> None: + super().__init__() + self._presenter = presenter + self._view = view + + view.text_edit.installEventFilter(self) + view.send_button.clicked.connect(self._send_message) + + def _send_message(self) -> None: + text = self._view.text_edit.toPlainText() + self._presenter.send_message(text) + self._view.text_edit.clear() + + def eventFilter(self, a0: QObject, a1: QEvent) -> bool: # noqa: N802 + if a0 == self._view.text_edit and isinstance(a1, QKeyEvent): + is_shift_pressed = bool(a1.modifiers() & Qt.KeyboardModifier.ShiftModifier) + + # require shift+enter for new line, otherwise send on enter + if a1.key() in (Qt.Key_Enter, Qt.Key_Return) and not is_shift_pressed: + self._send_message() + return True + + return super().eventFilter(a0, a1) + + +class AgentChatController(ChatObserver): + def __init__( + self, history: ChatHistory, presenter: AgentPresenter, view: AgentChatView + ) -> None: + super().__init__() + self._history = history + self._presenter = presenter + self._view = view + self._message_list_model = AgentMessageListModel(history) + self._input_controller = AgentInputController(presenter, view.input_view) + + view.message_list_view.setModel(self._message_list_model) + view.message_list_view.setItemDelegate(ChatBubbleItemDelegate()) + view.message_list_view.setResizeMode(QListView.ResizeMode.Adjust) + view.message_list_view.setVerticalScrollMode(QAbstractItemView.ScrollPerPixel) + + history.add_observer(self) + + def handle_new_message(self, message: ChatMessage, index: int) -> None: + parent = QModelIndex() + self._message_list_model.beginInsertRows(parent, index, index) + self._message_list_model.endInsertRows() + + def handle_chat_cleared(self) -> None: + self._message_list_model.beginResetModel() + self._message_list_model.endResetModel() + + +class AgentController: + def __init__(self, settings: ArgoSettings, presenter: AgentPresenter, view: AgentView) -> None: + self._settings = settings + self._presenter = presenter + self._view = view + + self._user_view_controller = LineEditParameterViewController(settings.user) + self._chat_endpoint_url_view_controller = LineEditParameterViewController( + settings.chat_endpoint_url, tool_tip='The chat endpoint URL.' + ) + self._chat_model_view_controller = ComboBoxParameterViewController( + settings.chat_model, + presenter.get_available_chat_models(), + tool_tip='The chat model to use.', + ) + self._temperature_view_controller = DecimalSliderParameterViewController( + settings.temperature, + tool_tip='What sampling temperature to use, between 0 and 2. Higher values mean the model takes more risks.', + ) + self._top_p_view_controller = DecimalSliderParameterViewController( + settings.top_p, + tool_tip='An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.', + ) + self._max_tokens_view_controller = SpinBoxParameterViewController( + settings.max_tokens, + tool_tip='The maximum number of tokens that can be generated in the chat completion.', + ) + self._max_completion_tokens_view_controller = SpinBoxParameterViewController( + settings.max_completion_tokens, + tool_tip='An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.', + ) + self._embeddings_endpoint_url_view_controller = LineEditParameterViewController( + settings.embeddings_endpoint_url, tool_tip='The embeddings endpoint URL.' + ) + self._embeddings_model_view_controller = ComboBoxParameterViewController( + settings.embeddings_model, + presenter.get_available_embeddings_models(), + tool_tip='The embeddings model to use.', + ) + self._embed_button = QPushButton('Embed Text') + self._embed_button.clicked.connect(self._embed_text) + + group_box_layout = QFormLayout() + group_box_layout.addRow('User:', self._user_view_controller.get_widget()) + group_box_layout.addRow( + 'Chat Endpoint URL:', self._chat_endpoint_url_view_controller.get_widget() + ) + group_box_layout.addRow('Chat Model:', self._chat_model_view_controller.get_widget()) + group_box_layout.addRow('Temperature:', self._temperature_view_controller.get_widget()) + group_box_layout.addRow('Top P:', self._top_p_view_controller.get_widget()) + group_box_layout.addRow('Max Tokens:', self._max_tokens_view_controller.get_widget()) + group_box_layout.addRow( + 'Max Completion Tokens:', self._max_completion_tokens_view_controller.get_widget() + ) + group_box_layout.addRow( + 'Embeddings Endpoint URL:', self._embeddings_endpoint_url_view_controller.get_widget() + ) + group_box_layout.addRow( + 'Embeddings Model:', self._embeddings_model_view_controller.get_widget() + ) + group_box_layout.addRow(self._embed_button) + + group_box = QGroupBox('Argo') + group_box.setLayout(group_box_layout) + + layout = QVBoxLayout() + layout.addWidget(group_box) + layout.addStretch() + view.setLayout(layout) + + def _embed_text(self) -> None: + title = 'Embed Text' + label = 'Enter text to embed:' + text, ok_pressed = QInputDialog.getMultiLineText(self._view, title, label, text='') + + if ok_pressed: + self._presenter.embed_text(text.splitlines()) diff --git a/src/ptychodus/controller/agent/item_delegate.py b/src/ptychodus/controller/agent/item_delegate.py new file mode 100644 index 00000000..47a462ed --- /dev/null +++ b/src/ptychodus/controller/agent/item_delegate.py @@ -0,0 +1,121 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Final + +from PyQt5.QtCore import QModelIndex, QPointF, QRectF, QSize, QSizeF, Qt +from PyQt5.QtGui import QBrush, QFontMetrics, QPainter, QPen, QTextDocument, QTextOption +from PyQt5.QtWidgets import QApplication, QStyle, QStyleOptionViewItem, QStyledItemDelegate + + +@dataclass(frozen=True) +class BubbleMetrics: + margin_px: int + border_px: int + padding_px: int + radius_px: int + + @property + def mbp_px(self) -> int: + return self.margin_px + self.border_px + self.padding_px + + @property + def bp_px(self) -> int: + return self.margin_px + self.border_px + + @classmethod + def from_document( + cls, + document: QTextDocument, + *, + margin_em: int = 1, + border_px: int = 1, + padding_em: int = 1, + radius_px: int = 10, + ) -> BubbleMetrics: + font_metrics = QFontMetrics(document.defaultFont()) + one_em_px = font_metrics.horizontalAdvance('m') + return cls( + margin_px=margin_em * one_em_px, + border_px=border_px, + padding_px=padding_em * one_em_px, + radius_px=radius_px, + ) + + +class ChatBubbleItemDelegate(QStyledItemDelegate): + TEXT_FRACTIONAL_WIDTH: Final[float] = 0.8 + + def _create_text_document( + self, option: QStyleOptionViewItem, index: QModelIndex + ) -> QTextDocument: + text = index.model().data(index, Qt.DisplayRole) + + text_option = QTextOption() + text_option.setWrapMode(QTextOption.WordWrap) + text_option.setTextDirection(option.direction) + + doc = QTextDocument() + doc.setDefaultTextOption(text_option) + doc.setHtml(text) + doc.setDefaultFont(option.font) + doc.setDocumentMargin(0) + + text_width = min(self.TEXT_FRACTIONAL_WIDTH * option.rect.width(), doc.idealWidth()) + doc.setTextWidth(text_width) + + return doc + + def paint(self, painter: QPainter, option: QStyleOptionViewItem, index: QModelIndex) -> None: + style = option.widget.style() if option.widget else QApplication.style() + doc = self._create_text_document(option, index) + metrics = BubbleMetrics.from_document(doc) + alignment = Qt.Alignment(index.data(Qt.ItemDataRole.TextAlignmentRole)) + + doc_size = doc.size() + item_size = QSizeF( + doc_size.width() + 2 * metrics.mbp_px, + doc_size.height() + 2 * metrics.mbp_px, + ) + layout_rect = QStyle.alignedRect( + Qt.LayoutDirectionAuto, + alignment, + item_size.toSize(), + style.subElementRect(QStyle.SE_ItemViewItemText, option), + ) + bubble_rect = QRectF( + layout_rect.left() + metrics.margin_px, + layout_rect.top() + metrics.margin_px, + doc_size.width() + 2 * metrics.bp_px, + doc_size.height() + 2 * metrics.bp_px, + ) + text_origin = QPointF( + bubble_rect.left() + metrics.bp_px, + bubble_rect.top() + metrics.bp_px, + ) + text_rect = QRectF( + 0.0, + 0.0, + doc_size.width(), + doc_size.height(), + ) + + bubble_brush = QBrush(index.data(Qt.ItemDataRole.BackgroundRole)) + bubble_pen = QPen(index.data(Qt.ItemDataRole.ForegroundRole)) + bubble_pen.setWidth(metrics.border_px) + + style.drawControl(QStyle.CE_ItemViewItem, option, painter, option.widget) + + painter.save() + painter.setPen(bubble_pen) + painter.setBrush(bubble_brush) + painter.drawRoundedRect(bubble_rect, metrics.radius_px, metrics.radius_px) + painter.translate(text_origin) + doc.drawContents(painter, text_rect) + painter.restore() + + def sizeHint(self, option: QStyleOptionViewItem, index: QModelIndex) -> QSize: # noqa: N802 + doc = self._create_text_document(option, index) + metrics = BubbleMetrics.from_document(doc) + hint = option.rect.size() + hint.setHeight(int(doc.size().height()) + 2 * metrics.mbp_px) + return hint diff --git a/src/ptychodus/controller/agent/list_model.py b/src/ptychodus/controller/agent/list_model.py new file mode 100644 index 00000000..ebd41136 --- /dev/null +++ b/src/ptychodus/controller/agent/list_model.py @@ -0,0 +1,34 @@ +from typing import Any, Final + +from PyQt5.QtCore import QAbstractListModel, QModelIndex, QObject, Qt +from PyQt5.QtGui import QColor + +from ...model.agent import ChatHistory, ChatRole + + +class AgentMessageListModel(QAbstractListModel): + DARK_BLUE: Final[QColor] = QColor('#243689') + LIGHT_BLUE: Final[QColor] = QColor('#0492d2') + DARK_GREEN: Final[QColor] = QColor('#00894d') + LIGHT_GREEN: Final[QColor] = QColor('#78ca2a') + + def __init__(self, history: ChatHistory, parent: QObject | None = None) -> None: + super().__init__(parent) + self._history = history + + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: + if index.isValid(): + message = self._history[index.row()] + + match role: + case Qt.ItemDataRole.DisplayRole: + return message.content + case Qt.ItemDataRole.TextAlignmentRole: + return Qt.AlignRight if message.role == ChatRole.USER else Qt.AlignLeft + case Qt.ItemDataRole.BackgroundRole: + return self.LIGHT_BLUE if message.role == ChatRole.USER else self.LIGHT_GREEN + case Qt.ItemDataRole.ForegroundRole: + return self.DARK_BLUE if message.role == ChatRole.USER else self.DARK_GREEN + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + return len(self._history) diff --git a/src/ptychodus/controller/automation.py b/src/ptychodus/controller/automation.py index 20f439e6..9435345f 100644 --- a/src/ptychodus/controller/automation.py +++ b/src/ptychodus/controller/automation.py @@ -26,71 +26,59 @@ def __init__( self, presenter: AutomationPresenter, view: AutomationProcessingView, - fileDialogFactory: FileDialogFactory, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() self._presenter = presenter self._view = view - self._fileDialogFactory = fileDialogFactory + self._file_dialog_factory = file_dialog_factory - @classmethod - def createInstance( - cls, - presenter: AutomationPresenter, - view: AutomationProcessingView, - fileDialogFactory: FileDialogFactory, - ) -> AutomationProcessingController: - controller = cls(presenter, view, fileDialogFactory) - presenter.addObserver(controller) - - for strategy in presenter.getStrategyList(): - view.strategyComboBox.addItem(strategy) + presenter.add_observer(self) - view.strategyComboBox.textActivated.connect(presenter.setStrategy) - view.directoryLineEdit.editingFinished.connect(controller._syncDirectoryToModel) - view.directoryBrowseButton.clicked.connect(controller._browseDirectory) - view.intervalSpinBox.valueChanged.connect(presenter.setProcessingIntervalInSeconds) + for strategy in presenter.get_strategies(): + view.strategy_combo_box.addItem(strategy) - controller._syncModelToView() + view.strategy_combo_box.textActivated.connect(presenter.set_strategy) + view.directory_line_edit.editingFinished.connect(self._sync_directory_to_model) + view.directory_browse_button.clicked.connect(self._browse_directory) + view.interval_spin_box.valueChanged.connect(presenter.set_processing_interval_s) - return controller + self._sync_model_to_view() - def _syncDirectoryToModel(self) -> None: - dataDirectory = Path(self._view.directoryLineEdit.text()) - self._presenter.setDataDirectory(dataDirectory) + def _sync_directory_to_model(self) -> None: + data_dir = Path(self._view.directory_line_edit.text()) + self._presenter.set_data_directory(data_dir) - def _browseDirectory(self) -> None: - dirPath = self._fileDialogFactory.getExistingDirectoryPath( + def _browse_directory(self) -> None: + dir_path = self._file_dialog_factory.get_existing_directory_path( self._view, 'Choose Data Directory' ) - if dirPath: - self._presenter.setDataDirectory(dirPath) + if dir_path: + self._presenter.set_data_directory(dir_path) - def _syncModelToView(self) -> None: - self._view.strategyComboBox.blockSignals(True) - self._view.strategyComboBox.setCurrentText(self._presenter.getStrategy()) - self._view.strategyComboBox.blockSignals(False) + def _sync_model_to_view(self) -> None: + self._view.strategy_combo_box.blockSignals(True) + self._view.strategy_combo_box.setCurrentText(self._presenter.get_strategy()) + self._view.strategy_combo_box.blockSignals(False) - dataDirectory = self._presenter.getDataDirectory() + data_dir = self._presenter.get_data_directory() - if dataDirectory: - self._view.directoryLineEdit.setText(str(dataDirectory)) + if data_dir: + self._view.directory_line_edit.setText(str(data_dir)) else: - self._view.directoryLineEdit.clear() + self._view.directory_line_edit.clear() - intervalLimitsInSeconds = self._presenter.getProcessingIntervalLimitsInSeconds() + interval_limits_s = self._presenter.get_processing_interval_limits_s() - self._view.intervalSpinBox.blockSignals(True) - self._view.intervalSpinBox.setRange( - intervalLimitsInSeconds.lower, intervalLimitsInSeconds.upper - ) - self._view.intervalSpinBox.setValue(self._presenter.getProcessingIntervalInSeconds()) - self._view.intervalSpinBox.blockSignals(False) + self._view.interval_spin_box.blockSignals(True) + self._view.interval_spin_box.setRange(interval_limits_s.lower, interval_limits_s.upper) + self._view.interval_spin_box.setValue(self._presenter.get_processing_interval_s()) + self._view.interval_spin_box.blockSignals(False) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._presenter: - self._syncModelToView() + self._sync_model_to_view() class AutomationWatchdogController(Observer): @@ -99,35 +87,30 @@ def __init__(self, presenter: AutomationPresenter, view: AutomationWatchdogView) self._presenter = presenter self._view = view - @classmethod - def createInstance( - cls, presenter: AutomationPresenter, view: AutomationWatchdogView - ) -> AutomationWatchdogController: - controller = cls(presenter, view) - presenter.addObserver(controller) - - view.delaySpinBox.valueChanged.connect(presenter.setWatchdogDelayInSeconds) - view.usePollingObserverCheckBox.toggled.connect(presenter.setWatchdogPollingObserverEnabled) + presenter.add_observer(self) - controller._syncModelToView() + view.delay_spin_box.valueChanged.connect(presenter.set_watchdog_delay_s) + view.use_polling_observer_check_box.toggled.connect( + presenter.set_watchdog_polling_observer_enabled + ) - return controller + self._sync_model_to_view() - def _syncModelToView(self) -> None: - delayLimitsInSeconds = self._presenter.getWatchdogDelayLimitsInSeconds() + def _sync_model_to_view(self) -> None: + delay_limits_s = self._presenter.get_watchdog_delay_limits_s() - self._view.delaySpinBox.blockSignals(True) - self._view.delaySpinBox.setRange(delayLimitsInSeconds.lower, delayLimitsInSeconds.upper) - self._view.delaySpinBox.setValue(self._presenter.getWatchdogDelayInSeconds()) - self._view.delaySpinBox.blockSignals(False) + self._view.delay_spin_box.blockSignals(True) + self._view.delay_spin_box.setRange(delay_limits_s.lower, delay_limits_s.upper) + self._view.delay_spin_box.setValue(self._presenter.get_watchdog_delay_s()) + self._view.delay_spin_box.blockSignals(False) - self._view.usePollingObserverCheckBox.setChecked( - self._presenter.isWatchdogPollingObserverEnabled() + self._view.use_polling_observer_check_box.setChecked( + self._presenter.is_watchdog_polling_observer_enabled() ) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._presenter: - self._syncModelToView() + self._sync_model_to_view() class AutomationProcessingListModel(QAbstractListModel): @@ -140,10 +123,10 @@ def __init__( def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: if index.isValid(): if role == Qt.ItemDataRole.DisplayRole: - return self._presenter.getDatasetLabel(index.row()) + return self._presenter.get_dataset_label(index.row()) elif role == Qt.ItemDataRole.FontRole: font = QFont() - state = self._presenter.getDatasetState(index.row()) + state = self._presenter.get_dataset_state(index.row()) if state == AutomationDatasetState.WAITING: font.setItalic(True) @@ -152,8 +135,8 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A return font - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: - return self._presenter.getNumberOfDatasets() + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + return self._presenter.get_num_datasets() class AutomationController(Observer): @@ -161,64 +144,62 @@ def __init__( self, core: AutomationCore, presenter: AutomationPresenter, - processingPresenter: AutomationProcessingPresenter, + processing_presenter: AutomationProcessingPresenter, view: AutomationView, - fileDialogFactory: FileDialogFactory, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() self._core = core self._presenter = presenter - self._processingController = AutomationProcessingController.createInstance( - presenter, view.processingView, fileDialogFactory - ) - self._watchdogController = AutomationWatchdogController.createInstance( - presenter, view.watchdogView + self._processing_controller = AutomationProcessingController( + presenter, view.processing_view, file_dialog_factory ) - self._processingPresenter = processingPresenter - self._listModel = AutomationProcessingListModel(processingPresenter) + self._watchdog_controller = AutomationWatchdogController(presenter, view.watchdog_view) + self._processing_presenter = processing_presenter + self._list_model = AutomationProcessingListModel(processing_presenter) self._view = view - self._executeWaitingTasksTimer = QTimer() - self._automationTimer = QTimer() + self._execute_waiting_tasks_timer = QTimer() + self._automation_timer = QTimer() @classmethod - def createInstance( + def create_instance( cls, core: AutomationCore, presenter: AutomationPresenter, - processingPresenter: AutomationProcessingPresenter, + processing_presenter: AutomationProcessingPresenter, view: AutomationView, - fileDialogFactory: FileDialogFactory, + file_dialog_factory: FileDialogFactory, ) -> AutomationController: - controller = cls(core, presenter, processingPresenter, view, fileDialogFactory) - processingPresenter.addObserver(controller) + controller = cls(core, presenter, processing_presenter, view, file_dialog_factory) + processing_presenter.add_observer(controller) - view.processingListView.setModel(controller._listModel) + view.processing_list_view.setModel(controller._list_model) - view.loadButton.clicked.connect(presenter.loadExistingDatasetsToRepository) - view.watchButton.setCheckable(True) - view.watchButton.toggled.connect(presenter.setWatchdogEnabled) - view.processButton.setCheckable(True) - view.processButton.toggled.connect(processingPresenter.setProcessingEnabled) - view.clearButton.clicked.connect(presenter.clearDatasetRepository) + view.load_button.clicked.connect(presenter.load_existing_datasets_to_repository) + view.watch_button.setCheckable(True) + view.watch_button.toggled.connect(presenter.set_watchdog_enabled) + view.process_button.setCheckable(True) + view.process_button.toggled.connect(processing_presenter.set_processing_enabled) + view.clear_button.clicked.connect(presenter.clear_dataset_repository) - controller._syncModelToView() + controller._sync_model_to_view() - controller._executeWaitingTasksTimer.timeout.connect(core.executeWaitingTasks) - controller._executeWaitingTasksTimer.start(60 * 1000) # TODO customize (in milliseconds) + controller._execute_waiting_tasks_timer.timeout.connect(core.execute_waiting_tasks) + controller._execute_waiting_tasks_timer.start(60 * 1000) # TODO customize (in milliseconds) - controller._automationTimer.timeout.connect(core.refreshDatasetRepository) - controller._automationTimer.start(10 * 1000) # TODO customize (in milliseconds) + controller._automation_timer.timeout.connect(core.refresh_dataset_repository) + controller._automation_timer.start(10 * 1000) # TODO customize (in milliseconds) return controller - def _syncModelToView(self) -> None: - self._view.processButton.setChecked(self._processingPresenter.isProcessingEnabled()) - self._listModel.beginResetModel() - self._listModel.endResetModel() + def _sync_model_to_view(self) -> None: + self._view.process_button.setChecked(self._processing_presenter.is_processing_enabled()) + self._list_model.beginResetModel() + self._list_model.endResetModel() - self._view.watchButton.setChecked(self._presenter.isWatchdogEnabled()) - self._view.processButton.setChecked(self._processingPresenter.isProcessingEnabled()) + self._view.watch_button.setChecked(self._presenter.is_watchdog_enabled()) + self._view.process_button.setChecked(self._processing_presenter.is_processing_enabled()) - def update(self, observable: Observable) -> None: - if observable is self._processingPresenter: - self._syncModelToView() + def _update(self, observable: Observable) -> None: + if observable is self._processing_presenter: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/core.py b/src/ptychodus/controller/core.py index e709cbd8..76b85bea 100644 --- a/src/ptychodus/controller/core.py +++ b/src/ptychodus/controller/core.py @@ -5,6 +5,7 @@ from ..model import ModelCore from ..view import ViewCore +from .agent import AgentChatController, AgentController from .automation import AutomationController from .data import FileDialogFactory from .image import ImageController @@ -13,7 +14,9 @@ from .patterns import PatternsController from .probe import ProbeController from .product import ProductController +from .ptychi import PtyChiViewControllerFactory from .ptychonn import PtychoNNViewControllerFactory +from .ptychopinn import PtychoPINNViewControllerFactory from .reconstructor import ReconstructorController from .scan import ScanController from .settings import SettingsController @@ -22,131 +25,159 @@ class ControllerCore: - def __init__(self, model: ModelCore, view: ViewCore) -> None: + def __init__( + self, model: ModelCore, view: ViewCore, *, is_developer_mode_enabled: bool = False + ) -> None: self.view = view - self._memoryController = MemoryController(model.memoryPresenter, view.memoryProgressBar) - self._fileDialogFactory = FileDialogFactory() - self._ptychonnViewControllerFactory = PtychoNNViewControllerFactory( - model.ptychonnReconstructorLibrary, self._fileDialogFactory - ) - self._tikeViewControllerFactory = TikeViewControllerFactory(model.tikeReconstructorLibrary) - self._settingsController = SettingsController( - model.settingsRegistry, - view.settingsView, - view.settingsTableView, - self._fileDialogFactory, - ) - self._patternsImageController = ImageController.createInstance( - model.patternVisualizationEngine, - view.patternsImageView, + self._memory_controller = MemoryController(model.memory_presenter, view.memory_widget) + self._file_dialog_factory = FileDialogFactory() + self._ptychi_view_controller_factory = PtyChiViewControllerFactory( + model.ptychi_reconstructor_library + ) + self._ptychonn_view_controller_factory = PtychoNNViewControllerFactory( + model.ptychonn_reconstructor_library + ) + self._ptychopinn_view_controller_factory = PtychoPINNViewControllerFactory( + model.ptychopinn_reconstructor_library, self._file_dialog_factory + ) + self._tike_view_controller_factory = TikeViewControllerFactory( + model.tike_reconstructor_library + ) + self._settings_controller = SettingsController( + model.settings_registry, + view.settings_view, + view.settings_table_view, + self._file_dialog_factory, + ) + self._patterns_image_controller = ImageController( + model.pattern_visualization_engine, + view.patterns_image_view, view.statusBar(), - self._fileDialogFactory, - ) - self._patternsController = PatternsController.createInstance( - model.detector, - model.diffractionDatasetInputOutputPresenter, - model.diffractionMetadataPresenter, - model.diffractionDatasetPresenter, - model.patternPresenter, - self._patternsImageController, - view.patternsView, - self._fileDialogFactory, - ) - self._productController = ProductController.createInstance( - model.productRepository, - model.productAPI, - view.productView, - self._fileDialogFactory, - ) - self._scanController = ScanController.createInstance( - model.scanRepository, - model.scanAPI, - view.scanView, - view.scanPlotView, - self._fileDialogFactory, - ) - self._probeImageController = ImageController.createInstance( - model.probeVisualizationEngine, - view.probeImageView, + self._file_dialog_factory, + ) + self._patterns_controller = PatternsController( + model.patterns.detector_settings, + model.patterns.pattern_settings, + model.patterns.pattern_sizer, + model.patterns.patterns_api, + model.patterns.dataset, + model.metadata_presenter, + view.patterns_view, + self._patterns_image_controller, + self._file_dialog_factory, + ) + self._product_controller = ProductController.create_instance( + model.patterns.dataset, + model.product.product_repository, + model.product.product_api, + view.product_view, + self._file_dialog_factory, + ) + self._scan_controller = ScanController( + model.product.scan_repository, + model.product.scan_api, + view.scan_view, + view.scan_plot_view, + self._file_dialog_factory, + is_developer_mode_enabled=is_developer_mode_enabled, + ) + self._probe_image_controller = ImageController( + model.probe_visualization_engine, + view.probe_image_view, view.statusBar(), - self._fileDialogFactory, - ) - self._probeController = ProbeController.createInstance( - model.probeRepository, - model.probeAPI, - self._probeImageController, - model.probePropagator, - model.probePropagatorVisualizationEngine, - model.stxmSimulator, - model.stxmVisualizationEngine, - model.exposureAnalyzer, - model.exposureVisualizationEngine, - model.fluorescenceEnhancer, - model.fluorescenceVisualizationEngine, - view.probeView, - self._fileDialogFactory, - ) - self._objectImageController = ImageController.createInstance( - model.objectVisualizationEngine, - view.objectImageView, + self._file_dialog_factory, + ) + self._probe_controller = ProbeController( + model.product.probe_repository, + model.product.probe_api, + self._probe_image_controller, + model.analysis.probe_propagator, + model.analysis.probe_propagator_visualization_engine, + model.analysis.stxm_simulator, + model.analysis.stxm_visualization_engine, + model.analysis.exposure_analyzer, + model.analysis.exposure_visualization_engine, + model.fluorescence_core.enhancer, + model.fluorescence_core.visualization_engine, + view.probe_view, + self._file_dialog_factory, + is_developer_mode_enabled=is_developer_mode_enabled, + ) + self._object_image_controller = ImageController( + model.object_visualization_engine, + view.object_image_view, view.statusBar(), - self._fileDialogFactory, - ) - self._objectController = ObjectController.createInstance( - model.objectRepository, - model.objectAPI, - self._objectImageController, - model.fourierRingCorrelator, - model.xmcdAnalyzer, - model.xmcdVisualizationEngine, - view.objectView, - self._fileDialogFactory, - ) - self._reconstructorParametersController = ReconstructorController.createInstance( - model.reconstructorPresenter, - model.productRepository, - view.reconstructorParametersView, - view.reconstructorPlotView, - self._fileDialogFactory, - self._productController.tableModel, + self._file_dialog_factory, + ) + self._object_controller = ObjectController( + model.product.object_repository, + model.product.object_api, + self._object_image_controller, + model.analysis.fourier_ring_correlator, + model.analysis.xmcd_analyzer, + model.analysis.xmcd_visualization_engine, + view.object_view, + self._file_dialog_factory, + is_developer_mode_enabled=is_developer_mode_enabled, + ) + self._reconstructor_controller = ReconstructorController( + model.reconstructor.presenter, + model.product.product_repository, + view.reconstructor_view, + view.reconstructor_plot_view, + self._product_controller.table_model, + self._file_dialog_factory, [ - self._ptychonnViewControllerFactory, - self._tikeViewControllerFactory, + self._ptychi_view_controller_factory, + self._ptychopinn_view_controller_factory, + self._ptychonn_view_controller_factory, + self._tike_view_controller_factory, ], ) - self._workflowController = WorkflowController.createInstance( - model.workflowParametersPresenter, - model.workflowAuthorizationPresenter, - model.workflowStatusPresenter, - model.workflowExecutionPresenter, - view.workflowParametersView, - view.workflowTableView, - self._productController.tableModel, - ) - self._automationController = AutomationController.createInstance( - model._automationCore, - model.automationPresenter, - model.automationProcessingPresenter, - view.automationView, - self._fileDialogFactory, + self._workflow_controller = WorkflowController( + model.workflow.parameters_presenter, + model.workflow.authorization_presenter, + model.workflow.status_presenter, + model.workflow.execution_presenter, + view.workflow_parameters_view, + view.workflow_table_view, + self._product_controller.table_model, + ) + self._automation_controller = AutomationController.create_instance( + model.automation, + model.automation.presenter, + model.automation.processing_presenter, + view.automation_view, + self._file_dialog_factory, + ) + self._agent_controller = AgentController( + model.agent.settings, model.agent.presenter, view.agent_view + ) + self._agent_chat_controller = AgentChatController( + model.agent.chat_history, model.agent.presenter, view.agent_chat_view ) - self._refreshDataTimer = QTimer() - self._refreshDataTimer.timeout.connect(model.refreshActiveDataset) - self._refreshDataTimer.start(1000) # TODO make configurable + self._refresh_data_timer = QTimer() + self._refresh_data_timer.timeout.connect(model.refresh_active_dataset) + self._refresh_data_timer.start(1000) # TODO make configurable - view.workflowAction.setVisible(model.areWorkflowsSupported) + view.workflow_action.setVisible(model.workflow.is_supported) + + self._swap_central_widgets(view.patterns_action) + view.patterns_action.setChecked(True) + view.navigation_action_group.triggered.connect( + lambda action: self._swap_central_widgets(action) + ) - self.swapCentralWidgets(view.patternsAction) - view.patternsAction.setChecked(True) - view.navigationActionGroup.triggered.connect(lambda action: self.swapCentralWidgets(action)) + view.agent_action.setVisible(is_developer_mode_enabled) + view.scan_view.button_box.analyze_button.setEnabled(is_developer_mode_enabled) - def showMainWindow(self, windowTitle: str) -> None: - self.view.setWindowTitle(windowTitle) + def show_main_window(self, window_title: str) -> None: + self.view.setWindowTitle(window_title) self.view.show() - def swapCentralWidgets(self, action: QAction) -> None: + def _swap_central_widgets(self, action: QAction) -> None: index = action.data() - self.view.parametersWidget.setCurrentIndex(index) - self.view.contentsWidget.setCurrentIndex(index) + self.view.left_panel.setCurrentIndex(index) + self.view.right_panel.setCurrentIndex(index) diff --git a/src/ptychodus/controller/data.py b/src/ptychodus/controller/data.py index 241d90b0..92539bc6 100644 --- a/src/ptychodus/controller/data.py +++ b/src/ptychodus/controller/data.py @@ -3,96 +3,108 @@ from PyQt5.QtWidgets import QDialog, QFileDialog, QWidget +from ptychodus.api.observer import Observable -class FileDialogFactory: + +class FileDialogFactory(Observable): def __init__(self) -> None: - self._openWorkingDirectory = Path.cwd() - self._saveWorkingDirectory = Path.cwd() + super().__init__() + self._open_working_directory = Path.cwd() + self._save_working_directory = Path.cwd() + + def get_open_working_directory(self) -> Path: + return self._open_working_directory + + def set_open_working_directory(self, directory: Path) -> None: + if not directory.is_dir(): + directory = directory.parent - def getOpenWorkingDirectory(self) -> Path: - return self._openWorkingDirectory + directory = directory.resolve() - def setOpenWorkingDirectory(self, directory: Path) -> None: - self._openWorkingDirectory = directory if directory.is_dir() else directory.parent + if self._open_working_directory != directory: + self._open_working_directory = directory + self.notify_observers() - def getOpenFilePath( + def get_open_file_path( self, parent: QWidget, caption: str, - nameFilters: Sequence[str] | None = None, - mimeTypeFilters: Sequence[str] | None = None, - selectedNameFilter: str | None = None, + name_filters: Sequence[str] | None = None, + mime_type_filters: Sequence[str] | None = None, + selected_name_filter: str | None = None, ) -> tuple[Path | None, str]: - filePath = None + file_path = None - dialog = QFileDialog(parent, caption, str(self.getOpenWorkingDirectory())) + dialog = QFileDialog(parent, caption, str(self.get_open_working_directory())) dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptOpen) dialog.setFileMode(QFileDialog.FileMode.ExistingFile) - if nameFilters is not None: - dialog.setNameFilters(nameFilters) + if name_filters is not None: + dialog.setNameFilters(name_filters) - if mimeTypeFilters is not None: - dialog.setMimeTypeFilters(mimeTypeFilters) + if mime_type_filters is not None: + dialog.setMimeTypeFilters(mime_type_filters) - if selectedNameFilter is not None: - dialog.selectNameFilter(selectedNameFilter) + if selected_name_filter is not None: + dialog.selectNameFilter(selected_name_filter) if dialog.exec() == QDialog.DialogCode.Accepted: # TODO exec -> open - fileNameList = dialog.selectedFiles() - fileName = fileNameList[0] + file_name_list = dialog.selectedFiles() + file_name = file_name_list[0] - if fileName: - filePath = Path(fileName) - self.setOpenWorkingDirectory(filePath.parent) + if file_name: + file_path = Path(file_name) + self.set_open_working_directory(file_path.parent) - return filePath, dialog.selectedNameFilter() + return file_path, dialog.selectedNameFilter() - def getSaveFilePath( + def get_save_file_path( self, parent: QWidget, caption: str, - nameFilters: Sequence[str] | None = None, - mimeTypeFilters: Sequence[str] | None = None, - selectedNameFilter: str | None = None, + name_filters: Sequence[str] | None = None, + mime_type_filters: Sequence[str] | None = None, + selected_name_filter: str | None = None, ) -> tuple[Path | None, str]: - filePath = None + file_path = None - dialog = QFileDialog(parent, caption, str(self._saveWorkingDirectory)) + dialog = QFileDialog(parent, caption, str(self._save_working_directory)) dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptSave) dialog.setFileMode(QFileDialog.FileMode.AnyFile) - if nameFilters is not None: - dialog.setNameFilters(nameFilters) + if name_filters is not None: + dialog.setNameFilters(name_filters) - if mimeTypeFilters is not None: - dialog.setMimeTypeFilters(mimeTypeFilters) + if mime_type_filters is not None: + dialog.setMimeTypeFilters(mime_type_filters) - if selectedNameFilter is not None: - dialog.selectNameFilter(selectedNameFilter) + if selected_name_filter is not None: + dialog.selectNameFilter(selected_name_filter) if dialog.exec() == QDialog.DialogCode.Accepted: # TODO exec -> open - fileNameList = dialog.selectedFiles() - fileName = fileNameList[0] + file_name_list = dialog.selectedFiles() + file_name = file_name_list[0] - if fileName: - filePath = Path(fileName) - self._saveWorkingDirectory = filePath.parent + if file_name: + file_path = Path(file_name) + self._save_working_directory = file_path.parent - return filePath, dialog.selectedNameFilter() + return file_path, dialog.selectedNameFilter() - def getExistingDirectoryPath(self, parent: QWidget, caption: str) -> Path | None: - dirPath = None + def get_existing_directory_path( + self, parent: QWidget, caption: str, initial_directory: Path | None = None + ) -> Path | None: + dir_path = None - dirName = QFileDialog.getExistingDirectory( + dir_name = QFileDialog.getExistingDirectory( parent, caption, - str(self.getOpenWorkingDirectory()), + str(initial_directory or self.get_open_working_directory()), QFileDialog.Option.ShowDirsOnly | QFileDialog.Option.DontResolveSymlinks, ) - if dirName: - dirPath = Path(dirName) - self.setOpenWorkingDirectory(dirPath) + if dir_name: + dir_path = Path(dir_name) + self.set_open_working_directory(dir_path) - return dirPath + return dir_path diff --git a/src/ptychodus/controller/image.py b/src/ptychodus/controller/image.py index 15375e0b..551fbaf4 100644 --- a/src/ptychodus/controller/image.py +++ b/src/ptychodus/controller/image.py @@ -9,7 +9,7 @@ from ptychodus.api.geometry import Interval, PixelGeometry from ptychodus.api.observer import Observable, Observer -from ptychodus.api.visualization import NumberArrayType +from ptychodus.api.typing import NumberArrayType from ..model.visualization import VisualizationEngine from ..view.image import ( @@ -33,39 +33,41 @@ class ImageToolsController: def __init__( self, view: ImageToolsGroupBox, - visualizationController: VisualizationController, - mouseToolButtonGroup: QButtonGroup, + visualization_controller: VisualizationController, + mouse_tool_button_group: QButtonGroup, ) -> None: self._view = view - self._visualizationController = visualizationController - self._mouseToolButtonGroup = mouseToolButtonGroup + self._visualization_controller = visualization_controller + self._mouse_tool_button_group = mouse_tool_button_group @classmethod - def createInstance( - cls, view: ImageToolsGroupBox, visualizationController: VisualizationController + def create_instance( + cls, view: ImageToolsGroupBox, visualization_controller: VisualizationController ) -> ImageToolsController: - view.moveButton.setCheckable(True) - view.moveButton.setChecked(True) - view.rulerButton.setCheckable(True) - view.rectangleButton.setCheckable(True) - view.lineCutButton.setCheckable(True) - - mouseToolButtonGroup = QButtonGroup() - mouseToolButtonGroup.addButton(view.moveButton, ImageMouseTool.MOVE_TOOL.value) - mouseToolButtonGroup.addButton(view.rulerButton, ImageMouseTool.RULER_TOOL.value) - mouseToolButtonGroup.addButton(view.rectangleButton, ImageMouseTool.RECTANGLE_TOOL.value) - mouseToolButtonGroup.addButton(view.lineCutButton, ImageMouseTool.LINE_CUT_TOOL.value) - - controller = cls(view, visualizationController, mouseToolButtonGroup) - view.homeButton.clicked.connect(visualizationController.zoomToFit) - view.saveButton.clicked.connect(visualizationController.saveImage) - mouseToolButtonGroup.idToggled.connect(controller._setMouseTool) + view.move_button.setCheckable(True) + view.move_button.setChecked(True) + view.ruler_button.setCheckable(True) + view.rectangle_button.setCheckable(True) + view.line_cut_button.setCheckable(True) + + mouse_tool_button_group = QButtonGroup() + mouse_tool_button_group.addButton(view.move_button, ImageMouseTool.MOVE_TOOL.value) + mouse_tool_button_group.addButton(view.ruler_button, ImageMouseTool.RULER_TOOL.value) + mouse_tool_button_group.addButton( + view.rectangle_button, ImageMouseTool.RECTANGLE_TOOL.value + ) + mouse_tool_button_group.addButton(view.line_cut_button, ImageMouseTool.LINE_CUT_TOOL.value) + + controller = cls(view, visualization_controller, mouse_tool_button_group) + view.home_button.clicked.connect(visualization_controller.zoom_to_fit) + view.save_button.clicked.connect(visualization_controller.save_image) + mouse_tool_button_group.idToggled.connect(controller._set_mouse_tool) return controller - def _setMouseTool(self, toolId: int, checked: bool) -> None: + def _set_mouse_tool(self, tool_id: int, checked: bool) -> None: if checked: - mouseTool = ImageMouseTool(toolId) - self._visualizationController.setMouseTool(mouseTool) + mouse_tool = ImageMouseTool(tool_id) + self._visualization_controller.set_mouse_tool(mouse_tool) class ImageRendererController(Observer): @@ -73,48 +75,48 @@ def __init__(self, engine: VisualizationEngine, view: ImageRendererGroupBox) -> super().__init__() self._engine = engine self._view = view - self._rendererModel = QStringListModel() - self._transformationModel = QStringListModel() - self._variantModel = QStringListModel() + self._renderer_model = QStringListModel() + self._transformation_model = QStringListModel() + self._variant_model = QStringListModel() @classmethod - def createInstance( + def create_instance( cls, engine: VisualizationEngine, view: ImageRendererGroupBox ) -> ImageRendererController: controller = cls(engine, view) - view.rendererComboBox.setModel(controller._rendererModel) - view.transformationComboBox.setModel(controller._transformationModel) - view.variantComboBox.setModel(controller._variantModel) + view.renderer_combo_box.setModel(controller._renderer_model) + view.transformation_combo_box.setModel(controller._transformation_model) + view.variant_combo_box.setModel(controller._variant_model) - controller._syncModelToView() - engine.addObserver(controller) + controller._sync_model_to_view() + engine.add_observer(controller) - view.rendererComboBox.textActivated.connect(engine.setRenderer) - view.transformationComboBox.textActivated.connect(engine.setTransformation) - view.variantComboBox.textActivated.connect(engine.setVariant) + view.renderer_combo_box.textActivated.connect(engine.set_renderer) + view.transformation_combo_box.textActivated.connect(engine.set_transformation) + view.variant_combo_box.textActivated.connect(engine.set_variant) return controller - def _syncModelToView(self) -> None: - self._view.rendererComboBox.blockSignals(True) - self._rendererModel.setStringList([name for name in self._engine.renderers()]) - self._view.rendererComboBox.setCurrentText(self._engine.getRenderer()) - self._view.rendererComboBox.blockSignals(False) + def _sync_model_to_view(self) -> None: + self._view.renderer_combo_box.blockSignals(True) + self._renderer_model.setStringList([name for name in self._engine.renderers()]) + self._view.renderer_combo_box.setCurrentText(self._engine.get_renderer()) + self._view.renderer_combo_box.blockSignals(False) - self._view.transformationComboBox.blockSignals(True) - self._transformationModel.setStringList([name for name in self._engine.transformations()]) - self._view.transformationComboBox.setCurrentText(self._engine.getTransformation()) - self._view.transformationComboBox.blockSignals(False) + self._view.transformation_combo_box.blockSignals(True) + self._transformation_model.setStringList([name for name in self._engine.transformations()]) + self._view.transformation_combo_box.setCurrentText(self._engine.get_transformation()) + self._view.transformation_combo_box.blockSignals(False) - self._view.variantComboBox.blockSignals(True) - self._variantModel.setStringList([name for name in self._engine.variants()]) - self._view.variantComboBox.setCurrentText(self._engine.getVariant()) - self._view.variantComboBox.blockSignals(False) + self._view.variant_combo_box.blockSignals(True) + self._variant_model.setStringList([name for name in self._engine.variants()]) + self._view.variant_combo_box.setCurrentText(self._engine.get_variant()) + self._view.variant_combo_box.blockSignals(False) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._engine: - self._syncModelToView() + self._sync_model_to_view() class ImageDataRangeController(Observer): @@ -122,90 +124,90 @@ def __init__( self, engine: VisualizationEngine, view: ImageDataRangeGroupBox, - imageWidget: ImageWidget, - displayRangeDialog: ImageDisplayRangeDialog, - visualizationController: VisualizationController, + image_widget: ImageWidget, + display_range_dialog: ImageDisplayRangeDialog, + visualization_controller: VisualizationController, ) -> None: self._engine = engine self._view = view - self._imageWidget = imageWidget - self._displayRangeDialog = displayRangeDialog - self._visualizationController = visualizationController - self._displayRangeIsLocked = True + self._image_widget = image_widget + self._display_range_dialog = display_range_dialog + self._visualization_controller = visualization_controller + self._display_range_is_locked = True @classmethod - def createInstance( + def create_instance( cls, engine: VisualizationEngine, view: ImageDataRangeGroupBox, - imageWidget: ImageWidget, - visualizationController: VisualizationController, + image_widget: ImageWidget, + visualization_controller: VisualizationController, ) -> ImageDataRangeController: - displayRangeDialog = ImageDisplayRangeDialog.createInstance(view) - controller = cls(engine, view, imageWidget, displayRangeDialog, visualizationController) - controller._syncModelToView() - engine.addObserver(controller) + display_range_dialog = ImageDisplayRangeDialog.create_instance(view) + controller = cls(engine, view, image_widget, display_range_dialog, visualization_controller) + controller._sync_model_to_view() + engine.add_observer(controller) - view.minDisplayValueSlider.valueChanged.connect( - lambda value: engine.setMinDisplayValue(float(value)) + view.min_display_value_slider.value_changed.connect( + lambda value: engine.set_min_display_value(float(value)) ) - view.maxDisplayValueSlider.valueChanged.connect( - lambda value: engine.setMaxDisplayValue(float(value)) + view.max_display_value_slider.value_changed.connect( + lambda value: engine.set_max_display_value(float(value)) ) - view.autoButton.clicked.connect(controller._autoDisplayRange) - view.editButton.clicked.connect(displayRangeDialog.open) - displayRangeDialog.finished.connect(controller._finishEditingDisplayRange) + view.auto_button.clicked.connect(controller._auto_display_range) + view.edit_button.clicked.connect(display_range_dialog.open) + display_range_dialog.finished.connect(controller._finish_editing_display_range) - view.colorLegendButton.setCheckable(True) - imageWidget.setColorLegendVisible(view.colorLegendButton.isChecked()) - view.colorLegendButton.toggled.connect(imageWidget.setColorLegendVisible) + view.color_legend_button.setCheckable(True) + image_widget.set_color_legend_visible(view.color_legend_button.isChecked()) + view.color_legend_button.toggled.connect(image_widget.set_color_legend_visible) return controller - def _autoDisplayRange(self) -> None: - self._displayRangeIsLocked = False - self._visualizationController.rerenderImage(autoscaleColorAxis=True) - self._displayRangeIsLocked = True + def _auto_display_range(self) -> None: + self._display_range_is_locked = False + self._visualization_controller.rerender_image(autoscale_color_axis=True) + self._display_range_is_locked = True - def _finishEditingDisplayRange(self, result: int) -> None: + def _finish_editing_display_range(self, result: int) -> None: if result == QDialog.DialogCode.Accepted: - lower = float(self._displayRangeDialog.minValueLineEdit.getValue()) - upper = float(self._displayRangeDialog.maxValueLineEdit.getValue()) + lower = float(self._display_range_dialog.min_value_line_edit.get_value()) + upper = float(self._display_range_dialog.max_value_line_edit.get_value()) - self._displayRangeIsLocked = False - self._engine.setDisplayValueRange(lower, upper) - self._displayRangeIsLocked = True + self._display_range_is_locked = False + self._engine.set_display_value_range(lower, upper) + self._display_range_is_locked = True - def _syncColorLegendToView(self) -> None: + def _sync_color_legend_to_view(self) -> None: values = numpy.linspace( - self._engine.getMinDisplayValue(), self._engine.getMaxDisplayValue(), 1000 + self._engine.get_min_display_value(), self._engine.get_max_display_value(), 1000 ) - self._imageWidget.setColorLegendColors( + self._image_widget.set_color_legend_colors( values, self._engine.colorize(values), - self._engine.isRendererCyclic(), + self._engine.is_renderer_cyclic(), ) - def _syncModelToView(self) -> None: - minValue = Decimal(repr(self._engine.getMinDisplayValue())) - maxValue = Decimal(repr(self._engine.getMaxDisplayValue())) + def _sync_model_to_view(self) -> None: + min_value = Decimal(repr(self._engine.get_min_display_value())) + max_value = Decimal(repr(self._engine.get_max_display_value())) - self._displayRangeDialog.minValueLineEdit.setValue(minValue) - self._displayRangeDialog.maxValueLineEdit.setValue(maxValue) + self._display_range_dialog.min_value_line_edit.set_value(min_value) + self._display_range_dialog.max_value_line_edit.set_value(max_value) - if self._displayRangeIsLocked: - self._view.minDisplayValueSlider.setValue(minValue) - self._view.maxDisplayValueSlider.setValue(maxValue) + if self._display_range_is_locked: + self._view.min_display_value_slider.set_value(min_value) + self._view.max_display_value_slider.set_value(max_value) else: - displayRangeLimits = Interval[Decimal](minValue, maxValue) - self._view.minDisplayValueSlider.setValueAndRange(minValue, displayRangeLimits) - self._view.maxDisplayValueSlider.setValueAndRange(maxValue, displayRangeLimits) + display_range_limits = Interval[Decimal](min_value, max_value) + self._view.min_display_value_slider.set_value_and_range(min_value, display_range_limits) + self._view.max_display_value_slider.set_value_and_range(max_value, display_range_limits) - self._syncColorLegendToView() + self._sync_color_legend_to_view() - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._engine: - self._syncModelToView() + self._sync_model_to_view() class ImageController: @@ -213,37 +215,27 @@ def __init__( self, engine: VisualizationEngine, view: ImageView, - visualizationController: VisualizationController, + status_bar: QStatusBar, + file_dialog_factory: FileDialogFactory, ) -> None: - self._visualizationController = visualizationController - self._toolsController = ImageToolsController.createInstance( - view.imageRibbon.imageToolsGroupBox, visualizationController + self._visualization_controller = VisualizationController.create_instance( + engine, view.image_widget, status_bar, file_dialog_factory ) - self._rendererController = ImageRendererController.createInstance( - engine, view.imageRibbon.colormapGroupBox + self._tools_controller = ImageToolsController.create_instance( + view.image_ribbon.image_tools_group_box, self._visualization_controller ) - self._dataRangeController = ImageDataRangeController.createInstance( - engine, - view.imageRibbon.dataRangeGroupBox, - view.imageWidget, - visualizationController, + self._renderer_controller = ImageRendererController.create_instance( + engine, view.image_ribbon.colormap_group_box ) - - @classmethod - def createInstance( - cls, - engine: VisualizationEngine, - view: ImageView, - statusBar: QStatusBar, - fileDialogFactory: FileDialogFactory, - ) -> ImageController: - visualizationController = VisualizationController.createInstance( - engine, view.imageWidget, statusBar, fileDialogFactory + self._data_range_controller = ImageDataRangeController.create_instance( + engine, + view.image_ribbon.data_range_group_box, + view.image_widget, + self._visualization_controller, ) - return cls(engine, view, visualizationController) - def setArray(self, array: NumberArrayType, pixelGeometry: PixelGeometry) -> None: - self._visualizationController.setArray(array, pixelGeometry) + def set_array(self, array: NumberArrayType, pixel_geometry: PixelGeometry) -> None: + self._visualization_controller.set_array(array, pixel_geometry) - def clearArray(self) -> None: - self._visualizationController.clearArray() + def clear_array(self) -> None: + self._visualization_controller.clear_array() diff --git a/src/ptychodus/controller/memory.py b/src/ptychodus/controller/memory.py index 20dce3f7..60e59895 100644 --- a/src/ptychodus/controller/memory.py +++ b/src/ptychodus/controller/memory.py @@ -1,29 +1,30 @@ -from __future__ import annotations - from PyQt5.QtCore import QTimer -from PyQt5.QtWidgets import QProgressBar +from PyQt5.QtWidgets import QFrame, QLCDNumber, QSizePolicy from ..model.memory import MemoryPresenter class MemoryController: - def __init__(self, presenter: MemoryPresenter, progressBar: QProgressBar) -> None: + def __init__(self, presenter: MemoryPresenter, widget: QLCDNumber) -> None: self._presenter = presenter - self._progressBar = progressBar + self._widget = widget + self._widget.setSegmentStyle(QLCDNumber.SegmentStyle.Flat) + self._widget.setFrameStyle(QFrame.Panel | QFrame.Sunken) + self._widget.setDigitCount(6) + self._widget.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred) self._timer = QTimer() - self._timer.timeout.connect(self._updateProgressBar) + self._timer.timeout.connect(self._update_widget) - self._updateProgressBar() + self._update_widget() self._timer.start(10 * 1000) # TODO customize (in milliseconds) - def _updateProgressBar(self) -> None: - stats = self._presenter.getStatistics() - totalMemMB = int(stats.totalMemoryInBytes / 1e6) - totalMem = f'Total Memory: {totalMemMB} MB' + def _update_widget(self) -> None: + stats = self._presenter.get_statistics() + total_MB = int(stats.total_physical_memory_bytes / 1e6) # noqa: N806 + total_str = f'Total Memory: {total_MB} MB' - availMemMB = int(stats.availableMemoryInBytes / 1e6) - availMem = f'Available Memory: {availMemMB} MB' + avail_MB = int(stats.available_memory_bytes / 1e6) # noqa: N806 + avail_str = f'Available Memory: {avail_MB} MB' - self._progressBar.setRange(0, 100) - self._progressBar.setValue(int(stats.memoryUsagePercent)) - self._progressBar.setToolTip('\n'.join((totalMem, availMem))) + self._widget.display(avail_MB) + self._widget.setToolTip('\n'.join((total_str, avail_str))) diff --git a/src/ptychodus/controller/object/core.py b/src/ptychodus/controller/object/core.py index 268a3da0..63775959 100644 --- a/src/ptychodus/controller/object/core.py +++ b/src/ptychodus/controller/object/core.py @@ -14,9 +14,9 @@ from ...view.widgets import ComboBoxItemDelegate, ExceptionDialog from ..data import FileDialogFactory from ..image import ImageController -from .editorFactory import ObjectEditorViewControllerFactory +from .editor_factory import ObjectEditorViewControllerFactory from .frc import FourierRingCorrelationViewController -from .treeModel import ObjectTreeModel +from .tree_model import ObjectTreeModel from .xmcd import XMCDViewController logger = logging.getLogger(__name__) @@ -27,230 +27,207 @@ def __init__( self, repository: ObjectRepository, api: ObjectAPI, - imageController: ImageController, + image_controller: ImageController, correlator: FourierRingCorrelator, - xmcdAnalyzer: XMCDAnalyzer, - xmcdVisualizationEngine: VisualizationEngine, + xmcd_analyzer: XMCDAnalyzer, + xmcd_visualization_engine: VisualizationEngine, view: RepositoryTreeView, - fileDialogFactory: FileDialogFactory, - treeModel: ObjectTreeModel, + file_dialog_factory: FileDialogFactory, + *, + is_developer_mode_enabled: bool, ) -> None: super().__init__() self._repository = repository self._api = api - self._imageController = imageController + self._image_controller = image_controller self._view = view - self._fileDialogFactory = fileDialogFactory - self._treeModel = treeModel - self._editorFactory = ObjectEditorViewControllerFactory() + self._file_dialog_factory = file_dialog_factory + self._tree_model = ObjectTreeModel(repository, api) + self._editor_factory = ObjectEditorViewControllerFactory() - self._frcViewController = FourierRingCorrelationViewController(correlator, treeModel) - self._xmcdViewController = XMCDViewController( - xmcdAnalyzer, xmcdVisualizationEngine, fileDialogFactory, treeModel + self._frc_view_controller = FourierRingCorrelationViewController( + correlator, self._tree_model ) - - @classmethod - def createInstance( - cls, - repository: ObjectRepository, - api: ObjectAPI, - imageController: ImageController, - correlator: FourierRingCorrelator, - xmcdAnalyzer: XMCDAnalyzer, - xmcdVisualizationEngine: VisualizationEngine, - view: RepositoryTreeView, - fileDialogFactory: FileDialogFactory, - ) -> ObjectController: - # TODO figure out good fix when saving NPY file without suffix (numpy adds suffix) - treeModel = ObjectTreeModel(repository, api) - controller = cls( - repository, - api, - imageController, - correlator, - xmcdAnalyzer, - xmcdVisualizationEngine, - view, - fileDialogFactory, - treeModel, + self._xmcd_view_controller = XMCDViewController( + xmcd_analyzer, xmcd_visualization_engine, file_dialog_factory, self._tree_model ) - repository.addObserver(controller) - builderListModel = QStringListModel() - builderListModel.setStringList([name for name in api.builderNames()]) - builderItemDelegate = ComboBoxItemDelegate(builderListModel, view.treeView) + # TODO figure out good fix when saving NPY file without suffix (numpy adds suffix) + repository.add_observer(self) - view.treeView.setModel(treeModel) - view.treeView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - view.treeView.setItemDelegateForColumn(2, builderItemDelegate) - view.treeView.selectionModel().currentChanged.connect(controller._updateView) - controller._updateView(QModelIndex(), QModelIndex()) + builder_list_model = QStringListModel() + builder_list_model.setStringList([name for name in api.builder_names()]) + builder_item_delegate = ComboBoxItemDelegate(builder_list_model, view.tree_view) - loadFromFileAction = view.buttonBox.loadMenu.addAction('Open File...') - loadFromFileAction.triggered.connect(controller._loadCurrentObjectFromFile) + view.tree_view.setModel(self._tree_model) + view.tree_view.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + view.tree_view.setItemDelegateForColumn(2, builder_item_delegate) + view.tree_view.selectionModel().currentChanged.connect(self._update_view) + self._update_view(QModelIndex(), QModelIndex()) - copyAction = view.buttonBox.loadMenu.addAction('Copy...') - copyAction.triggered.connect(controller._copyToCurrentObject) + load_from_file_action = view.button_box.load_menu.addAction('Open File...') + load_from_file_action.triggered.connect(self._load_current_object_from_file) - saveToFileAction = view.buttonBox.saveMenu.addAction('Save File...') - saveToFileAction.triggered.connect(controller._saveCurrentObjectToFile) + copy_action = view.button_box.load_menu.addAction('Copy...') + copy_action.triggered.connect(self._copy_to_current_object) - syncToSettingsAction = view.buttonBox.saveMenu.addAction('Sync To Settings') - syncToSettingsAction.triggered.connect(controller._syncCurrentObjectToSettings) + save_to_file_action = view.button_box.save_menu.addAction('Save File...') + save_to_file_action.triggered.connect(self._save_current_object_to_file) - view.copierDialog.setWindowTitle('Copy Object') - view.copierDialog.sourceComboBox.setModel(treeModel) - view.copierDialog.destinationComboBox.setModel(treeModel) - view.copierDialog.finished.connect(controller._finishCopyingObject) + sync_to_settings_action = view.button_box.save_menu.addAction('Sync To Settings') + sync_to_settings_action.triggered.connect(self._sync_current_object_to_settings) - view.buttonBox.editButton.clicked.connect(controller._editCurrentObject) + view.copier_dialog.setWindowTitle('Copy Object') + view.copier_dialog.source_combo_box.setModel(self._tree_model) + view.copier_dialog.destination_combo_box.setModel(self._tree_model) + view.copier_dialog.finished.connect(self._finish_copying_object) - frcAction = view.buttonBox.analyzeMenu.addAction('Fourier Ring Correlation...') - frcAction.triggered.connect(controller._analyzeFRC) + view.button_box.edit_button.clicked.connect(self._edit_current_object) - xmcdAction = view.buttonBox.analyzeMenu.addAction('XMCD...') - xmcdAction.triggered.connect(controller._analyzeXMCD) + frc_action = view.button_box.analyze_menu.addAction('Fourier Ring Correlation...') + frc_action.triggered.connect(self._analyze_frc) - return controller + xmcd_action = view.button_box.analyze_menu.addAction('XMCD...') + xmcd_action.triggered.connect(self._analyze_xmcd) - def _getCurrentItemIndex(self) -> int: - modelIndex = self._view.treeView.currentIndex() + def _get_current_item_index(self) -> int: + model_index = self._view.tree_view.currentIndex() - if modelIndex.isValid(): - parent = modelIndex.parent() + if model_index.isValid(): + parent = model_index.parent() while parent.isValid(): - modelIndex = parent - parent = modelIndex.parent() + model_index = parent + parent = model_index.parent() - return modelIndex.row() + return model_index.row() logger.warning('No current index!') return -1 - def _loadCurrentObjectFromFile(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _load_current_object_from_file(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( + file_path, name_filter = self._file_dialog_factory.get_open_file_path( self._view, 'Open Object', - nameFilters=self._api.getOpenFileFilterList(), - selectedNameFilter=self._api.getOpenFileFilter(), + name_filters=[nf for nf in self._api.get_open_file_filters()], + selected_name_filter=self._api.get_open_file_filter(), ) - if filePath: + if file_path: try: - self._api.openObject(itemIndex, filePath, fileType=nameFilter) + self._api.open_object(item_index, file_path, file_type=name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Reader', err) + ExceptionDialog.show_exception('File Reader', err) - def _copyToCurrentObject(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _copy_to_current_object(self) -> None: + item_index = self._get_current_item_index() - if itemIndex >= 0: - self._view.copierDialog.destinationComboBox.setCurrentIndex(itemIndex) - self._view.copierDialog.open() + if item_index >= 0: + self._view.copier_dialog.destination_combo_box.setCurrentIndex(item_index) + self._view.copier_dialog.open() - def _finishCopyingObject(self, result: int) -> None: + def _finish_copying_object(self, result: int) -> None: if result == QDialog.DialogCode.Accepted: - sourceIndex = self._view.copierDialog.sourceComboBox.currentIndex() - destinationIndex = self._view.copierDialog.destinationComboBox.currentIndex() - self._api.copyObject(sourceIndex, destinationIndex) + source_index = self._view.copier_dialog.source_combo_box.currentIndex() + destination_index = self._view.copier_dialog.destination_combo_box.currentIndex() + self._api.copy_object(source_index, destination_index) - def _editCurrentObject(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _edit_current_object(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - itemName = self._repository.getName(itemIndex) - item = self._repository[itemIndex] - dialog = self._editorFactory.createEditorDialog(itemName, item, self._view) + item_name = self._repository.get_name(item_index) + item = self._repository[item_index] + dialog = self._editor_factory.create_editor_dialog(item_name, item, self._view) dialog.open() - def _saveCurrentObjectToFile(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _save_current_object_to_file(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + file_path, name_filter = self._file_dialog_factory.get_save_file_path( self._view, 'Save Object', - nameFilters=self._api.getSaveFileFilterList(), - selectedNameFilter=self._api.getSaveFileFilter(), + name_filters=[nameFilter for nameFilter in self._api.get_save_file_filters()], + selected_name_filter=self._api.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._api.saveObject(itemIndex, filePath, nameFilter) + self._api.save_object(item_index, file_path, name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Writer', err) + ExceptionDialog.show_exception('File Writer', err) - def _syncCurrentObjectToSettings(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _sync_current_object_to_settings(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - item = self._repository[itemIndex] - item.syncToSettings() + item = self._repository[item_index] + item.sync_to_settings() - def _analyzeFRC(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _analyze_frc(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - self._frcViewController.analyze(itemIndex, itemIndex) + self._frc_view_controller.analyze(item_index, item_index) - def _analyzeXMCD(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _analyze_xmcd(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - self._xmcdViewController.analyze(itemIndex, itemIndex) + self._xmcd_view_controller.analyze(item_index, item_index) - def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: + def _update_view(self, current: QModelIndex, previous: QModelIndex) -> None: enabled = current.isValid() - self._view.buttonBox.loadButton.setEnabled(enabled) - self._view.buttonBox.saveButton.setEnabled(enabled) - self._view.buttonBox.editButton.setEnabled(enabled) - self._view.buttonBox.analyzeButton.setEnabled(enabled) + self._view.button_box.load_button.setEnabled(enabled) + self._view.button_box.save_button.setEnabled(enabled) + self._view.button_box.edit_button.setEnabled(enabled) + self._view.button_box.analyze_button.setEnabled(enabled) - itemIndex = self._getCurrentItemIndex() + item_index = self._get_current_item_index() - if itemIndex < 0: - self._imageController.clearArray() + if item_index < 0: + self._image_controller.clear_array() else: try: - item = self._repository[itemIndex] + item = self._repository[item_index] except IndexError: logger.warning('Unable to access item for visualization!') else: - object_ = item.getObject() + object_ = item.get_object() array = ( - object_.getLayer(current.row()) + object_.get_layer(current.row()) if current.parent().isValid() - else object_.getLayersFlattened() + else object_.get_layers_flattened() ) - self._imageController.setArray(array, object_.getPixelGeometry()) + self._image_controller.set_array(array, object_.get_pixel_geometry()) - def handleItemInserted(self, index: int, item: ObjectRepositoryItem) -> None: - self._treeModel.insertItem(index, item) + def handle_item_inserted(self, index: int, item: ObjectRepositoryItem) -> None: + self._tree_model.insert_item(index, item) - def handleItemChanged(self, index: int, item: ObjectRepositoryItem) -> None: - self._treeModel.updateItem(index, item) + def handle_item_changed(self, index: int, item: ObjectRepositoryItem) -> None: + self._tree_model.update_item(index, item) - if index == self._getCurrentItemIndex(): - currentIndex = self._view.treeView.currentIndex() - self._updateView(currentIndex, currentIndex) + if index == self._get_current_item_index(): + current_index = self._view.tree_view.currentIndex() + self._update_view(current_index, current_index) - def handleItemRemoved(self, index: int, item: ObjectRepositoryItem) -> None: - self._treeModel.removeItem(index, item) + def handle_item_removed(self, index: int, item: ObjectRepositoryItem) -> None: + self._tree_model.remove_item(index, item) diff --git a/src/ptychodus/controller/object/editorFactory.py b/src/ptychodus/controller/object/editorFactory.py deleted file mode 100644 index b4a3a169..00000000 --- a/src/ptychodus/controller/object/editorFactory.py +++ /dev/null @@ -1,79 +0,0 @@ -from PyQt5.QtWidgets import QDialog, QMessageBox, QSpinBox, QWidget - -from ptychodus.api.observer import Observable, Observer - -from ...model.product.object import ObjectRepositoryItem, RandomObjectBuilder -from ..parametric import ParameterViewBuilder, ParameterViewController - - -class MultisliceViewController(ParameterViewController, Observer): - def __init__(self, item: ObjectRepositoryItem) -> None: - super().__init__() - self._item = item - self._parameter = item.layerDistanceInMeters - self._widget = QSpinBox() - - self._syncModelToView() - self._widget.valueChanged.connect(self._syncViewToModel) - self._parameter.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncViewToModel(self, numberOfLayers: int) -> None: - self._item.setNumberOfLayers(numberOfLayers) - - def _syncModelToView(self) -> None: - self._widget.blockSignals(True) - self._widget.setRange(1, 99) - self._widget.setValue(self._item.getNumberOfLayers()) - self._widget.blockSignals(False) - - def update(self, observable: Observable) -> None: - if observable is self._parameter: - self._syncModelToView() - - -class ObjectEditorViewControllerFactory: - def createEditorDialog( - self, itemName: str, item: ObjectRepositoryItem, parent: QWidget - ) -> QDialog: - objectBuilder = item.getBuilder() - builderName = objectBuilder.getName() - firstLayerGroup = 'First Layer' - additionalLayersGroup = 'Additional Layers' - title = f'{itemName} [{builderName}]' - - if isinstance(objectBuilder, RandomObjectBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addSpinBox( - objectBuilder.extraPaddingX, 'Extra Padding X:', group=firstLayerGroup - ) - dialogBuilder.addSpinBox( - objectBuilder.extraPaddingY, 'Extra Padding Y:', group=firstLayerGroup - ) - dialogBuilder.addDecimalSlider( - objectBuilder.amplitudeMean, 'Amplitude Mean:', group=firstLayerGroup - ) - dialogBuilder.addDecimalSlider( - objectBuilder.amplitudeDeviation, - 'Amplitude Deviation:', - group=firstLayerGroup, - ) - dialogBuilder.addDecimalSlider( - objectBuilder.phaseDeviation, 'Phase Deviation:', group=firstLayerGroup - ) - dialogBuilder.addViewController( - MultisliceViewController(item), - 'Number of Layers:', - group=additionalLayersGroup, - ) - return dialogBuilder.buildDialog(title, parent) - - return QMessageBox( - QMessageBox.Icon.Information, - title, - f'"{builderName}" has no editable parameters!', - QMessageBox.Ok, - parent, - ) diff --git a/src/ptychodus/controller/object/editor_factory.py b/src/ptychodus/controller/object/editor_factory.py new file mode 100644 index 00000000..c0b72e28 --- /dev/null +++ b/src/ptychodus/controller/object/editor_factory.py @@ -0,0 +1,79 @@ +from PyQt5.QtWidgets import QDialog, QMessageBox, QSpinBox, QWidget + +from ptychodus.api.observer import Observable, Observer + +from ...model.product.object import ObjectRepositoryItem, RandomObjectBuilder +from ..parametric import ParameterViewBuilder, ParameterViewController + + +class MultisliceViewController(ParameterViewController, Observer): + def __init__(self, item: ObjectRepositoryItem) -> None: + super().__init__() + self._item = item + self._parameter = item.layer_spacing_m + self._widget = QSpinBox() + + self._sync_model_to_view() + self._widget.valueChanged.connect(self._sync_view_to_model) + self._parameter.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_view_to_model(self, num_layers: int) -> None: + self._item.set_num_layers(num_layers) + + def _sync_model_to_view(self) -> None: + self._widget.blockSignals(True) + self._widget.setRange(1, 99) + self._widget.setValue(self._item.get_num_layers()) + self._widget.blockSignals(False) + + def _update(self, observable: Observable) -> None: + if observable is self._parameter: + self._sync_model_to_view() + + +class ObjectEditorViewControllerFactory: + def create_editor_dialog( + self, item_name: str, item: ObjectRepositoryItem, parent: QWidget + ) -> QDialog: + object_builder = item.get_builder() + builder_name = object_builder.get_name() + first_layer_group = 'First Layer' + additional_layers_group = 'Additional Layers' + title = f'{item_name} [{builder_name}]' + + if isinstance(object_builder, RandomObjectBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_spin_box( + object_builder.extra_padding_x, 'Extra Padding X:', group=first_layer_group + ) + dialog_builder.add_spin_box( + object_builder.extra_padding_y, 'Extra Padding Y:', group=first_layer_group + ) + dialog_builder.add_decimal_slider( + object_builder.amplitude_mean, 'Amplitude Mean:', group=first_layer_group + ) + dialog_builder.add_decimal_slider( + object_builder.amplitude_deviation, + 'Amplitude Deviation:', + group=first_layer_group, + ) + dialog_builder.add_decimal_slider( + object_builder.phase_deviation, 'Phase Deviation:', group=first_layer_group + ) + dialog_builder.add_view_controller( + MultisliceViewController(item), + 'Number of Layers:', + group=additional_layers_group, + ) + return dialog_builder.build_dialog(title, parent) + + return QMessageBox( + QMessageBox.Icon.Information, + title, + f'"{builder_name}" has no editable parameters!', + QMessageBox.Ok, + parent, + ) diff --git a/src/ptychodus/controller/object/frc.py b/src/ptychodus/controller/object/frc.py index 7b0e5c41..4df283b4 100644 --- a/src/ptychodus/controller/object/frc.py +++ b/src/ptychodus/controller/object/frc.py @@ -3,56 +3,56 @@ from ...model.analysis import FourierRingCorrelator from ...view.object import FourierRingCorrelationDialog -from .treeModel import ObjectTreeModel +from .tree_model import ObjectTreeModel logger = logging.getLogger(__name__) class FourierRingCorrelationViewController: - def __init__(self, correlator: FourierRingCorrelator, treeModel: ObjectTreeModel) -> None: + def __init__(self, correlator: FourierRingCorrelator, tree_model: ObjectTreeModel) -> None: super().__init__() self._correlator = correlator self._dialog = FourierRingCorrelationDialog() self._dialog.setWindowTitle('Fourier Ring Correlation') - self._dialog.product1ComboBox.setModel(treeModel) - self._dialog.product1ComboBox.textActivated.connect(self._redrawPlot) - self._dialog.product2ComboBox.setModel(treeModel) - self._dialog.product2ComboBox.textActivated.connect(self._redrawPlot) - - def analyze(self, itemIndex1: int, itemIndex2: int) -> None: - self._dialog.product1ComboBox.setCurrentIndex(itemIndex1) - self._dialog.product2ComboBox.setCurrentIndex(itemIndex2) - self._redrawPlot() + self._dialog.product1_combo_box.setModel(tree_model) + self._dialog.product1_combo_box.textActivated.connect(self._redraw_plot) + self._dialog.product2_combo_box.setModel(tree_model) + self._dialog.product2_combo_box.textActivated.connect(self._redraw_plot) + + def analyze(self, item_index1: int, item_index2: int) -> None: + self._dialog.product1_combo_box.setCurrentIndex(item_index1) + self._dialog.product2_combo_box.setCurrentIndex(item_index2) + self._redraw_plot() self._dialog.open() - def _redrawPlot(self) -> None: - currentIndex1 = self._dialog.product1ComboBox.currentIndex() - currentIndex2 = self._dialog.product2ComboBox.currentIndex() + def _redraw_plot(self) -> None: + current_index1 = self._dialog.product1_combo_box.currentIndex() + current_index2 = self._dialog.product2_combo_box.currentIndex() - if currentIndex1 < 0 or currentIndex2 < 0: + if current_index1 < 0 or current_index2 < 0: logger.warning('Invalid item index for FRC!') return - frc = self._correlator.correlate(currentIndex1, currentIndex2) - plot2D = frc.getPlot() - axisX = plot2D.axisX - axisY = plot2D.axisY + frc = self._correlator.correlate(current_index1, current_index2) + plot2d = frc.get_plot() + axis_x = plot2d.axis_x + axis_y = plot2d.axis_y ax = self._dialog.axes ax.clear() - ax.set_xlabel(axisX.label) - ax.set_ylabel(axisY.label) + ax.set_xlabel(axis_x.label) + ax.set_ylabel(axis_y.label) ax.grid(True) - if len(axisX.series) == 1: - sx = axisX.series[0] + if len(axis_x.series) == 1: + sx = axis_x.series[0] - for sy in axisY.series: + for sy in axis_y.series: ax.plot(sx.values, sy.values, '.-', label=sy.label, linewidth=1.5) else: logger.warning('Failed to broadcast plot series!') - if len(axisX.series) > 1: + if len(axis_x.series) > 1: ax.legend(loc='upper right') - self._dialog.figureCanvas.draw() + self._dialog.figure_canvas.draw() diff --git a/src/ptychodus/controller/object/treeModel.py b/src/ptychodus/controller/object/tree_model.py similarity index 58% rename from src/ptychodus/controller/object/treeModel.py rename to src/ptychodus/controller/object/tree_model.py index 24cc29c7..3db04109 100644 --- a/src/ptychodus/controller/object/treeModel.py +++ b/src/ptychodus/controller/object/tree_model.py @@ -3,6 +3,8 @@ from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject +from ptychodus.api.units import BYTES_PER_MEGABYTE + from ...model.product import ObjectAPI, ObjectRepository from ...model.product.object import ObjectRepositoryItem @@ -12,12 +14,12 @@ def __init__(self, parent: ObjectTreeNode | None = None) -> None: self.parent = parent self.children: list[ObjectTreeNode] = list() - def insertNode(self, index: int = -1) -> ObjectTreeNode: + def insert_node(self, index: int = -1) -> ObjectTreeNode: node = ObjectTreeNode(self) self.children.insert(index, node) return node - def removeNode(self, index: int = -1) -> ObjectTreeNode: + def remove_node(self, index: int = -1) -> ObjectTreeNode: return self.children.pop(index) def row(self) -> int: @@ -34,7 +36,7 @@ def __init__( super().__init__(parent) self._repository = repository self._api = api - self._treeRoot = ObjectTreeNode() + self._tree_root = ObjectTreeNode() self._header = [ 'Name', 'Distance [m]', @@ -46,54 +48,54 @@ def __init__( ] for index, item in enumerate(repository): - self.insertItem(index, item) + self.insert_item(index, item) @staticmethod - def _appendLayers(node: ObjectTreeNode, item: ObjectRepositoryItem) -> None: - object_ = item.getObject() + def _append_layers(node: ObjectTreeNode, item: ObjectRepositoryItem) -> None: + object_ = item.get_object() - for layer in range(object_.numberOfLayers): - node.insertNode() + for layer in range(object_.num_layers): + node.insert_node() - def insertItem(self, index: int, item: ObjectRepositoryItem) -> None: + def insert_item(self, index: int, item: ObjectRepositoryItem) -> None: self.beginInsertRows(QModelIndex(), index, index) - ObjectTreeModel._appendLayers(self._treeRoot.insertNode(index), item) + ObjectTreeModel._append_layers(self._tree_root.insert_node(index), item) self.endInsertRows() - def updateItem(self, index: int, item: ObjectRepositoryItem) -> None: - topLeft = self.index(index, 0) - bottomRight = self.index(index, len(self._header)) - self.dataChanged.emit(topLeft, bottomRight) + def update_item(self, index: int, item: ObjectRepositoryItem) -> None: + top_left = self.index(index, 0) + bottom_right = self.index(index, len(self._header)) + self.dataChanged.emit(top_left, bottom_right) - node = self._treeRoot.children[index] - numLayersOld = len(node.children) - numLayersNew = item.getObject().numberOfLayers + node = self._tree_root.children[index] + num_layers_old = len(node.children) + num_layers_new = item.get_object().num_layers - if numLayersOld < numLayersNew: - self.beginInsertRows(topLeft, numLayersOld, numLayersNew) + if num_layers_old < num_layers_new: + self.beginInsertRows(top_left, num_layers_old, num_layers_new) - while len(node.children) < numLayersNew: - node.insertNode() + while len(node.children) < num_layers_new: + node.insert_node() self.endInsertRows() - elif numLayersOld > numLayersNew: - self.beginRemoveRows(topLeft, numLayersNew, numLayersOld) + elif num_layers_old > num_layers_new: + self.beginRemoveRows(top_left, num_layers_new, num_layers_old) - while len(node.children) > numLayersNew: - node.removeNode() + while len(node.children) > num_layers_new: + node.remove_node() self.endRemoveRows() - childTopLeft = self.index(0, 0, topLeft) - childBottomRight = self.index(numLayersNew, len(self._header), topLeft) - self.dataChanged.emit(childTopLeft, childBottomRight) + child_top_left = self.index(0, 0, top_left) + child_bottom_right = self.index(num_layers_new, len(self._header), top_left) + self.dataChanged.emit(child_top_left, child_bottom_right) - def removeItem(self, index: int, item: ObjectRepositoryItem) -> None: + def remove_item(self, index: int, item: ObjectRepositoryItem) -> None: self.beginRemoveRows(QModelIndex(), index, index) - self._treeRoot.removeNode(index) + self._tree_root.remove_node(index) self.endRemoveRows() - def headerData( + def headerData( # noqa: N802 self, section: int, orientation: Qt.Orientation, @@ -103,21 +105,21 @@ def headerData( return self._header[section] @overload - def parent(self, index: QModelIndex) -> QModelIndex: ... + def parent(self, child: QModelIndex) -> QModelIndex: ... @overload def parent(self) -> QObject: ... - def parent(self, index: QModelIndex | None = None) -> QModelIndex | QObject: - if index is None: + def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: + if child is None: return super().parent() - elif index.isValid(): - node = index.internalPointer() - parentNode = node.parent + elif child.isValid(): + node = child.internalPointer() + parent_node = node.parent return ( QModelIndex() - if parentNode is self._treeRoot - else self.createIndex(parentNode.row(), 0, parentNode) + if parent_node is self._tree_root + else self.createIndex(parent_node.row(), 0, parent_node) ) return QModelIndex() @@ -125,10 +127,10 @@ def parent(self, index: QModelIndex | None = None) -> QModelIndex | QObject: def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: if self.hasIndex(row, column, parent): if parent.isValid(): - parentNode = parent.internalPointer() - node = parentNode.children[row] + parent_node = parent.internalPointer() + node = parent_node.children[row] else: - node = self._treeRoot.children[row] + node = self._tree_root.children[row] return self.createIndex(row, column, node) @@ -148,28 +150,28 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A return f'Layer {index.row() + 1}' elif index.column() == 1: try: - return item.layerDistanceInMeters[index.row()] + return item.layer_spacing_m[index.row()] except IndexError: - return float('NaN') + return float('inf') else: item = self._repository[index.row()] - object_ = item.getObject() + object_ = item.get_object() if role == Qt.ItemDataRole.DisplayRole or role == Qt.ItemDataRole.EditRole: if index.column() == 0: - return self._repository.getName(index.row()) + return self._repository.get_name(index.row()) elif index.column() == 1: - return object_.getTotalLayerDistanceInMeters() + return object_.get_total_thickness_m() elif index.column() == 2: - return item.getBuilder().getName() + return item.get_builder().get_name() elif index.column() == 3: - return str(object_.dataType) + return str(object_.dtype) elif index.column() == 4: - return object_.widthInPixels + return object_.width_px elif index.column() == 5: - return object_.heightInPixels + return object_.height_px elif index.column() == 6: - return f'{object_.sizeInBytes / (1024 * 1024):.2f}' + return f'{object_.nbytes / BYTES_PER_MEGABYTE:.2f}' def flags(self, index: QModelIndex) -> Qt.ItemFlags: value = super().flags(index) @@ -181,7 +183,7 @@ def flags(self, index: QModelIndex) -> Qt.ItemFlags: if index.column() == 1: item = self._repository[parent.row()] - if index.row() + 1 < item.getNumberOfLayers(): + if index.row() + 1 < item.get_num_layers(): value |= Qt.ItemFlag.ItemIsEditable else: if index.column() in (0, 2): @@ -189,7 +191,7 @@ def flags(self, index: QModelIndex) -> Qt.ItemFlags: return value - def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: + def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: # noqa: N802 if index.isValid() and role == Qt.ItemDataRole.EditRole: parent = index.parent() @@ -198,28 +200,28 @@ def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.Ed if index.column() == 1: try: - distanceInM = float(value) + distance_m = float(value) except ValueError: return False - item.layerDistanceInMeters[index.row()] = distanceInM + item.layer_spacing_m[index.row()] = distance_m return False else: if index.column() == 0: - self._repository.setName(index.row(), str(value)) + self._repository.set_name(index.row(), str(value)) return True elif index.column() == 2: - self._api.buildObject(index.row(), str(value)) + self._api.build_object(index.row(), str(value)) return True return False - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 if parent.column() > 0: return 0 - node = parent.internalPointer() if parent.isValid() else self._treeRoot + node = parent.internalPointer() if parent.isValid() else self._tree_root return len(node.children) - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._header) diff --git a/src/ptychodus/controller/object/xmcd.py b/src/ptychodus/controller/object/xmcd.py index d8317972..b8f9df0b 100644 --- a/src/ptychodus/controller/object/xmcd.py +++ b/src/ptychodus/controller/object/xmcd.py @@ -1,6 +1,8 @@ import logging -from ...model.analysis import XMCDAnalyzer, XMCDResult +from ptychodus.api.observer import Observable, Observer + +from ...model.analysis import XMCDAnalyzer, XMCDData from ...model.visualization import VisualizationEngine from ...view.object import XMCDDialog from ...view.widgets import ExceptionDialog @@ -9,96 +11,102 @@ VisualizationParametersController, VisualizationWidgetController, ) -from .treeModel import ObjectTreeModel +from .tree_model import ObjectTreeModel logger = logging.getLogger(__name__) -class XMCDViewController: +class XMCDViewController(Observer): def __init__( self, analyzer: XMCDAnalyzer, engine: VisualizationEngine, - fileDialogFactory: FileDialogFactory, - treeModel: ObjectTreeModel, + file_dialog_factory: FileDialogFactory, + tree_model: ObjectTreeModel, ) -> None: super().__init__() self._analyzer = analyzer self._engine = engine - self._fileDialogFactory = fileDialogFactory + self._file_dialog_factory = file_dialog_factory self._dialog = XMCDDialog() - self._dialog.setWindowTitle('XMCD Analysis') - self._dialog.parametersView.lcircComboBox.setModel(treeModel) - self._dialog.parametersView.lcircComboBox.currentIndexChanged.connect(self._analyze) - self._dialog.parametersView.rcircComboBox.setModel(treeModel) - self._dialog.parametersView.rcircComboBox.currentIndexChanged.connect(self._analyze) - self._dialog.parametersView.saveButton.clicked.connect(self._saveResult) + self._dialog.setWindowTitle('X-ray Magnetic Circular Dichroism (XMCD)') + self._dialog.parameters_view.lcirc_combo_box.setModel(tree_model) + self._dialog.parameters_view.lcirc_combo_box.currentIndexChanged.connect( + analyzer.set_lcirc_product + ) + self._dialog.parameters_view.rcirc_combo_box.setModel(tree_model) + self._dialog.parameters_view.rcirc_combo_box.currentIndexChanged.connect( + analyzer.set_rcirc_product + ) + self._dialog.parameters_view.save_button.clicked.connect(self._save_data) - self._differenceVisualizationWidgetController = VisualizationWidgetController( + self._difference_visualization_widget_controller = VisualizationWidgetController( engine, - self._dialog.differenceWidget, - self._dialog.statusBar, - fileDialogFactory, + self._dialog.difference_widget, + self._dialog.status_bar, + file_dialog_factory, ) - self._sumVisualizationWidgetController = VisualizationWidgetController( - engine, self._dialog.sumWidget, self._dialog.statusBar, fileDialogFactory + self._sum_visualization_widget_controller = VisualizationWidgetController( + engine, self._dialog.sum_widget, self._dialog.status_bar, file_dialog_factory ) - self._ratioVisualizationWidgetController = VisualizationWidgetController( - engine, self._dialog.ratioWidget, self._dialog.statusBar, fileDialogFactory + self._ratio_visualization_widget_controller = VisualizationWidgetController( + engine, self._dialog.ratio_widget, self._dialog.status_bar, file_dialog_factory ) - self._visualizationParametersController = VisualizationParametersController.createInstance( - engine, self._dialog.parametersView.visualizationParametersView + self._visualization_parameters_controller = ( + VisualizationParametersController.create_instance( + engine, self._dialog.parameters_view.visualization_parameters_view + ) ) - self._result: XMCDResult | None = None - - def _analyze(self) -> None: - lcircItemIndex = self._dialog.parametersView.lcircComboBox.currentIndex() - rcircItemIndex = self._dialog.parametersView.rcircComboBox.currentIndex() - - if lcircItemIndex < 0 or rcircItemIndex < 0: - return - try: - result = self._analyzer.analyze(lcircItemIndex, rcircItemIndex) - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('XMCD Analysis', err) - return - - self._result = result - self._differenceVisualizationWidgetController.setArray( - result.polar_difference[0, :, :], result.pixel_geometry - ) - self._sumVisualizationWidgetController.setArray( - result.polar_sum[0, :, :], result.pixel_geometry - ) - # TODO support multi-layer objects - self._ratioVisualizationWidgetController.setArray( - result.polar_ratio[0, :, :], result.pixel_geometry - ) + analyzer.add_observer(self) - def analyze(self, lcircItemIndex: int, rcircItemIndex: int) -> None: - self._dialog.parametersView.lcircComboBox.setCurrentIndex(lcircItemIndex) - self._dialog.parametersView.rcircComboBox.setCurrentIndex(rcircItemIndex) - self._analyze() + def analyze(self, lcirc_product_index: int, rcirc_product_index: int) -> None: + self._analyzer.set_lcirc_product(lcirc_product_index) + self._analyzer.set_rcirc_product(rcirc_product_index) + self._analyzer.analyze() self._dialog.open() - def _saveResult(self) -> None: - if self._result is None: - logger.debug('No result to save!') - return - - title = 'Save Result' - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + def _save_data(self) -> None: + title = 'Save XMCD Data' + file_path, _ = self._file_dialog_factory.get_save_file_path( self._dialog, title, - nameFilters=self._analyzer.getSaveFileFilterList(), - selectedNameFilter=self._analyzer.getSaveFileFilter(), + name_filters=self._analyzer.get_save_file_filters(), + selected_name_filter=self._analyzer.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._analyzer.saveResult(filePath, self._result) + self._analyzer.save_data(file_path) except Exception as err: logger.exception(err) - ExceptionDialog.showException(title, err) + ExceptionDialog.show_exception(title, err) + + def _sync_model_to_view(self) -> None: + lcirc_product_index = self._analyzer.get_lcirc_product() + self._dialog.parameters_view.lcirc_combo_box.setCurrentIndex(lcirc_product_index) + + rcirc_product_index = self._analyzer.get_rcirc_product() + self._dialog.parameters_view.rcirc_combo_box.setCurrentIndex(rcirc_product_index) + + try: + data = self._analyzer.get_data() + except ValueError: + self._difference_visualization_widget_controller.clear_array() + self._sum_visualization_widget_controller.clear_array() + self._ratio_visualization_widget_controller.clear_array() + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception('Update Views', err) + else: + self._difference_visualization_widget_controller.set_array( + data.polar_difference, data.pixel_geometry + ) + self._sum_visualization_widget_controller.set_array(data.polar_sum, data.pixel_geometry) + self._ratio_visualization_widget_controller.set_array( + data.polar_ratio, data.pixel_geometry + ) + + def _update(self, observable: Observable) -> None: + if observable is self._analyzer: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/parametric.py b/src/ptychodus/controller/parametric.py index 9cdd9686..e7dba4f5 100644 --- a/src/ptychodus/controller/parametric.py +++ b/src/ptychodus/controller/parametric.py @@ -1,6 +1,8 @@ +from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from decimal import Decimal +from pathlib import Path from typing import Final import logging @@ -14,7 +16,9 @@ QDialogButtonBox, QFormLayout, QGroupBox, + QHBoxLayout, QLineEdit, + QPushButton, QSpinBox, QVBoxLayout, QWidget, @@ -25,25 +29,48 @@ from ptychodus.api.parametric import ( BooleanParameter, IntegerParameter, + PathParameter, RealParameter, StringParameter, ) from ..view.widgets import AngleWidget, DecimalLineEdit, DecimalSlider, LengthWidget +from .data import FileDialogFactory logger = logging.getLogger(__name__) -__all__ = [ - 'ParameterViewBuilder', -] - class ParameterViewController(ABC): @abstractmethod - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: pass +class CheckableGroupBoxParameterViewController(ParameterViewController, Observer): + def __init__(self, parameter: BooleanParameter, title: str, *, tool_tip: str = '') -> None: + super().__init__() + self._parameter = parameter + self._widget = QGroupBox(title) + self._widget.setCheckable(True) + + if tool_tip: + self._widget.setToolTip(tool_tip) + + self.__sync_model_to_view() + self._widget.toggled.connect(parameter.set_value) + self._parameter.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def __sync_model_to_view(self) -> None: + self._widget.setChecked(self._parameter.get_value()) + + def _update(self, observable: Observable) -> None: + if observable is self._parameter: + self.__sync_model_to_view() + + class CheckBoxParameterViewController(ParameterViewController, Observer): def __init__(self, parameter: BooleanParameter, text: str, *, tool_tip: str = '') -> None: super().__init__() @@ -53,19 +80,19 @@ def __init__(self, parameter: BooleanParameter, text: str, *, tool_tip: str = '' if tool_tip: self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.toggled.connect(parameter.setValue) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.toggled.connect(parameter.set_value) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncModelToView(self) -> None: - self._widget.setChecked(self._parameter.getValue()) + def __sync_model_to_view(self) -> None: + self._widget.setChecked(self._parameter.get_value()) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class SpinBoxParameterViewController(ParameterViewController, Observer): @@ -79,19 +106,19 @@ def __init__(self, parameter: IntegerParameter, *, tool_tip: str = '') -> None: if tool_tip: self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.valueChanged.connect(parameter.setValue) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.valueChanged.connect(parameter.set_value) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncModelToView(self) -> None: - minimum = self._parameter.getMinimum() - maximum = self._parameter.getMaximum() + def __sync_model_to_view(self) -> None: + minimum = self._parameter.get_minimum() + maximum = self._parameter.get_maximum() if minimum is None: - logger.error('Minimum not provided!') + raise ValueError('Minimum not provided!') else: self._widget.blockSignals(True) @@ -100,17 +127,17 @@ def _syncModelToView(self) -> None: else: self._widget.setRange(minimum, maximum) - self._widget.setValue(self._parameter.getValue()) + self._widget.setValue(self._parameter.get_value()) self._widget.blockSignals(False) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class ComboBoxParameterViewController(ParameterViewController, Observer): def __init__( - self, parameter: StringParameter, items: Sequence[str], *, tool_tip: str = '' + self, parameter: StringParameter, items: Iterable[str], *, tool_tip: str = '' ) -> None: super().__init__() self._parameter = parameter @@ -122,19 +149,19 @@ def __init__( if tool_tip: self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.textActivated.connect(parameter.setValue) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.textActivated.connect(parameter.set_value) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncModelToView(self) -> None: - self._widget.setCurrentText(self._parameter.getValue()) + def __sync_model_to_view(self) -> None: + self._widget.setCurrentText(self._parameter.get_value()) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class LineEditParameterViewController(ParameterViewController, Observer): @@ -151,22 +178,173 @@ def __init__( if tool_tip: self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.editingFinished.connect(self._syncViewToModel) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.editingFinished.connect(self.__sync_view_to_model) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncViewToModel(self) -> None: - self._parameter.setValue(self._widget.text()) + def __sync_view_to_model(self) -> None: + self._parameter.set_value(self._widget.text()) - def _syncModelToView(self) -> None: - self._widget.setText(self._parameter.getValue()) + def __sync_model_to_view(self) -> None: + self._widget.setText(self._parameter.get_value()) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() + + +class PathParameterViewController(ParameterViewController, Observer): + def __init__( + self, + parameter: PathParameter, + file_dialog_factory: FileDialogFactory, + *, + caption: str, + name_filters: Sequence[str] | None, + mime_type_filters: Sequence[str] | None, + selected_name_filter: str | None, + tool_tip: str, + ) -> None: + super().__init__() + self._parameter = parameter + self._file_dialog_factory = file_dialog_factory + self._caption = caption + self._name_filters = name_filters + self._mime_type_filters = mime_type_filters + self._selected_name_filter = selected_name_filter + self._line_edit = QLineEdit() + self._browse_button = QPushButton('Browse') + self._widget = QWidget() + + if tool_tip: + self._line_edit.setToolTip(tool_tip) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._line_edit) + layout.addWidget(self._browse_button) + self._widget.setLayout(layout) + + self.__sync_model_to_view() + parameter.add_observer(self) + self._line_edit.editingFinished.connect(self.__sync_path_to_model) + + @classmethod + def create_file_opener( + cls, + parameter: PathParameter, + file_dialog_factory: FileDialogFactory, + *, + caption: str = 'Open File', + name_filters: Sequence[str] | None = None, + mime_type_filters: Sequence[str] | None = None, + selected_name_filter: str | None = None, + tool_tip: str = '', + ) -> PathParameterViewController: + view_controller = cls( + parameter, + file_dialog_factory, + caption=caption, + name_filters=name_filters, + mime_type_filters=mime_type_filters, + selected_name_filter=selected_name_filter, + tool_tip=tool_tip, + ) + view_controller._browse_button.clicked.connect(view_controller._choose_file_to_open) + return view_controller + + @classmethod + def create_file_saver( + cls, + parameter: PathParameter, + file_dialog_factory: FileDialogFactory, + *, + caption: str = 'Save File', + name_filters: Sequence[str] | None = None, + mime_type_filters: Sequence[str] | None = None, + selected_name_filter: str | None = None, + tool_tip: str = '', + ) -> PathParameterViewController: + view_controller = cls( + parameter, + file_dialog_factory, + caption=caption, + name_filters=name_filters, + mime_type_filters=mime_type_filters, + selected_name_filter=selected_name_filter, + tool_tip=tool_tip, + ) + view_controller._browse_button.clicked.connect(view_controller._choose_file_to_save) + return view_controller + + @classmethod + def create_directory_chooser( + cls, + parameter: PathParameter, + file_dialog_factory: FileDialogFactory, + *, + caption: str = 'Choose Directory', + tool_tip: str = '', + ) -> PathParameterViewController: + view_controller = cls( + parameter, + file_dialog_factory, + caption=caption, + name_filters=None, + mime_type_filters=None, + selected_name_filter=None, + tool_tip=tool_tip, + ) + view_controller._browse_button.clicked.connect(view_controller._choose_directory) + return view_controller + + def get_widget(self) -> QWidget: + return self._widget + + def __sync_path_to_model(self) -> None: + path = Path(self._line_edit.text()) + self._parameter.set_value(path) + + def _choose_file_to_open(self) -> None: + path, _ = self._file_dialog_factory.get_open_file_path( + self._widget, + self._caption, + self._name_filters, + self._mime_type_filters, + self._selected_name_filter, + ) + + if path: + self._parameter.set_value(path) + + def _choose_file_to_save(self) -> None: + path, _ = self._file_dialog_factory.get_save_file_path( + self._widget, + self._caption, + self._name_filters, + self._mime_type_filters, + self._selected_name_filter, + ) + + if path: + self._parameter.set_value(path) + + def _choose_directory(self) -> None: + path = self._file_dialog_factory.get_existing_directory_path(self._widget, self._caption) + + if path: + self._parameter.set_value(path) + + def __sync_model_to_view(self) -> None: + path = self._parameter.get_value() + self._line_edit.setText(str(path)) + + def _update(self, observable: Observable) -> None: + if observable is self._parameter: + self.__sync_model_to_view() class IntegerLineEditParameterViewController(ParameterViewController, Observer): @@ -179,8 +357,8 @@ def __init__(self, parameter: IntegerParameter, *, tool_tip: str = '') -> None: self._widget.setToolTip(tool_tip) validator = QIntValidator() - bottom = parameter.getMinimum() - top = parameter.getMaximum() + bottom = parameter.get_minimum() + top = parameter.get_maximum() if bottom is not None: validator.setBottom(bottom) @@ -190,14 +368,14 @@ def __init__(self, parameter: IntegerParameter, *, tool_tip: str = '') -> None: self._widget.setValidator(validator) - self._syncModelToView() - self._widget.editingFinished.connect(self._syncViewToModel) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.editingFinished.connect(self.__sync_view_to_model) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncViewToModel(self) -> None: + def __sync_view_to_model(self) -> None: text = self._widget.text() try: @@ -205,14 +383,14 @@ def _syncViewToModel(self) -> None: except ValueError: logger.warning(f'Failed to convert "{text}" to int!') else: - self._parameter.setValue(value) + self._parameter.set_value(value) - def _syncModelToView(self) -> None: - self._widget.setText(str(self._parameter.getValue())) + def __sync_model_to_view(self) -> None: + self._widget.setText(str(self._parameter.get_value())) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class DecimalLineEditParameterViewController(ParameterViewController, Observer): @@ -221,150 +399,159 @@ def __init__( ) -> None: super().__init__() self._parameter = parameter - self._widget = DecimalLineEdit.createInstance(isSigned=is_signed) + self._widget = DecimalLineEdit.create_instance(is_signed=is_signed) if tool_tip: self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.valueChanged.connect(self._syncViewToModel) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.value_changed.connect(self.__sync_view_to_model) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncViewToModel(self, value: Decimal) -> None: - self._parameter.setValue(float(value)) + def __sync_view_to_model(self, value: Decimal) -> None: + self._parameter.set_value(float(value)) - def _syncModelToView(self) -> None: - self._widget.setValue(Decimal(repr(self._parameter.getValue()))) + def __sync_model_to_view(self) -> None: + self._widget.set_value(Decimal(str(self._parameter.get_value()))) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class DecimalSliderParameterViewController(ParameterViewController, Observer): def __init__(self, parameter: RealParameter, *, tool_tip: str = '') -> None: super().__init__() self._parameter = parameter - self._widget = DecimalSlider.createInstance(Qt.Orientation.Horizontal) + self._widget = DecimalSlider.create_instance(Qt.Orientation.Horizontal) if tool_tip: self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.valueChanged.connect(self._syncViewToModel) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.value_changed.connect(self.__sync_view_to_model) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncViewToModel(self, value: Decimal) -> None: - self._parameter.setValue(float(value)) + def __sync_view_to_model(self, value: Decimal) -> None: + self._parameter.set_value(float(value)) - def _syncModelToView(self) -> None: - minimum = self._parameter.getMinimum() - maximum = self._parameter.getMaximum() + def __sync_model_to_view(self) -> None: + minimum = self._parameter.get_minimum() + maximum = self._parameter.get_maximum() if minimum is None or maximum is None: - logger.error('Range not provided!') + raise ValueError('Range not provided!') else: - value = Decimal(repr(self._parameter.getValue())) - range_ = Interval[Decimal](Decimal(repr(minimum)), Decimal(repr(maximum))) - self._widget.setValueAndRange(value, range_) + value = Decimal(str(self._parameter.get_value())) + range_ = Interval[Decimal](Decimal(str(minimum)), Decimal(str(maximum))) + self._widget.set_value_and_range(value, range_) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class LengthWidgetParameterViewController(ParameterViewController, Observer): - def __init__(self, parameter: RealParameter, *, is_signed: bool = False) -> None: + def __init__( + self, parameter: RealParameter, *, is_signed: bool = False, tool_tip: str = '' + ) -> None: super().__init__() self._parameter = parameter - self._widget = LengthWidget.createInstance(isSigned=is_signed) + self._widget = LengthWidget.create_instance(is_signed=is_signed) + + if tool_tip: + self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.lengthChanged.connect(self._syncViewToModel) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.length_changed.connect(self.__sync_view_to_model) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncViewToModel(self, value: Decimal) -> None: - self._parameter.setValue(float(value)) + def __sync_view_to_model(self, value: Decimal) -> None: + self._parameter.set_value(float(value)) - def _syncModelToView(self) -> None: - self._widget.setLengthInMeters(Decimal(repr(self._parameter.getValue()))) + def __sync_model_to_view(self) -> None: + self._widget.set_length_m(Decimal(str(self._parameter.get_value()))) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class AngleWidgetParameterViewController(ParameterViewController, Observer): - def __init__(self, parameter: RealParameter) -> None: + def __init__(self, parameter: RealParameter, tool_tip: str = '') -> None: super().__init__() self._parameter = parameter - self._widget = AngleWidget.createInstance() + self._widget = AngleWidget.create_instance() + + if tool_tip: + self._widget.setToolTip(tool_tip) - self._syncModelToView() - self._widget.angleChanged.connect(self._syncViewToModel) - parameter.addObserver(self) + self.__sync_model_to_view() + self._widget.angle_changed.connect(self.__sync_view_to_model) + parameter.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._widget - def _syncViewToModel(self, value: Decimal) -> None: - self._parameter.setValue(float(value)) + def __sync_view_to_model(self, value: Decimal) -> None: + self._parameter.set_value(float(value)) - def _syncModelToView(self) -> None: - self._widget.setAngleInTurns(Decimal(repr(self._parameter.getValue()))) + def __sync_model_to_view(self) -> None: + self._widget.set_angle_in_turns(Decimal(str(self._parameter.get_value()))) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._parameter: - self._syncModelToView() + self.__sync_model_to_view() class ParameterWidget(QWidget): def __init__( - self, viewControllers: Sequence[ParameterViewController], parent: QWidget | None = None + self, view_controllers: Sequence[ParameterViewController], parent: QWidget | None = None ) -> None: super().__init__(parent) - self._viewControllers = viewControllers + self._view_controllers = view_controllers class ParameterDialog(QDialog): def __init__( self, - viewControllers: Sequence[ParameterViewController], - buttonBox: QDialogButtonBox, + view_controllers: Sequence[ParameterViewController], + button_box: QDialogButtonBox, parent: QWidget | None, ) -> None: super().__init__(parent) - self._viewControllers = viewControllers - self._buttonBox = buttonBox + self._view_controllers = view_controllers + self._button_box = button_box - buttonBox.addButton(QDialogButtonBox.StandardButton.Ok) - buttonBox.clicked.connect(self._handleButtonBoxClicked) + button_box.addButton(QDialogButtonBox.StandardButton.Ok) + button_box.clicked.connect(self._handle_button_box_clicked) - def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: + def _handle_button_box_clicked(self, button: QAbstractButton) -> None: # TODO remove observers from viewControllers - if self._buttonBox.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: + if self._button_box.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: self.accept() else: self.reject() class ParameterViewBuilder: - def __init__(self) -> None: - self._viewControllersTop: list[ParameterViewController] = list() - self._viewControllers: dict[tuple[str, str], ParameterViewController] = dict() - self._viewControllersBottom: list[ParameterViewController] = list() + def __init__(self, file_dialog_factory: FileDialogFactory | None = None) -> None: + self._file_dialog_factory = file_dialog_factory + self._view_controllers_top: list[ParameterViewController] = list() + self._view_controllers: dict[tuple[str, str], ParameterViewController] = dict() + self._view_controllers_bottom: list[ParameterViewController] = list() - def addCheckBox( + def add_check_box( self, parameter: BooleanParameter, label: str, @@ -372,10 +559,87 @@ def addCheckBox( tool_tip: str = '', group: str = '', ) -> None: - viewController = CheckBoxParameterViewController(parameter, '') - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + view_controller = CheckBoxParameterViewController(parameter, '', tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) - def addSpinBox( + def add_combo_box( + self, + parameter: StringParameter, + items: Iterable[str], + label: str, + *, + tool_tip: str = '', + group: str = '', + ) -> None: + view_controller = ComboBoxParameterViewController(parameter, items, tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) + + def add_file_opener( + self, + parameter: PathParameter, + label: str, + *, + caption: str = 'Open File', + name_filters: Sequence[str] | None = None, + mime_type_filters: Sequence[str] | None = None, + selected_name_filter: str | None = None, + tool_tip: str = '', + group: str = '', + ) -> None: + if self._file_dialog_factory is None: + raise ValueError('Cannot add file chooser without FileDialogFactory!') + else: + view_controller = PathParameterViewController.create_file_opener( + parameter, + self._file_dialog_factory, + caption=caption, + name_filters=name_filters, + mime_type_filters=mime_type_filters, + selected_name_filter=selected_name_filter, + tool_tip=tool_tip, + ) + self.add_view_controller(view_controller, label, group=group) + + def add_file_saver( + self, + parameter: PathParameter, + label: str, + *, + caption: str = 'Save File', + name_filters: Sequence[str] | None = None, + mime_type_filters: Sequence[str] | None = None, + selected_name_filter: str | None = None, + tool_tip: str = '', + group: str = '', + ) -> None: + if self._file_dialog_factory is None: + raise ValueError('Cannot add file chooser without FileDialogFactory!') + else: + view_controller = PathParameterViewController.create_file_saver( + parameter, + self._file_dialog_factory, + caption=caption, + name_filters=name_filters, + mime_type_filters=mime_type_filters, + selected_name_filter=selected_name_filter, + tool_tip=tool_tip, + ) + self.add_view_controller(view_controller, label, group=group) + + def add_directory_chooser( + self, parameter: PathParameter, label: str, *, tool_tip: str = '', group: str = '' + ) -> None: + if self._file_dialog_factory is None: + raise ValueError('Cannot add directory chooser without FileDialogFactory!') + else: + view_controller = PathParameterViewController.create_directory_chooser( + parameter, + self._file_dialog_factory, + tool_tip=tool_tip, + ) + self.add_view_controller(view_controller, label, group=group) + + def add_spin_box( self, parameter: IntegerParameter, label: str, @@ -383,10 +647,21 @@ def addSpinBox( tool_tip: str = '', group: str = '', ) -> None: - viewController = SpinBoxParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + view_controller = SpinBoxParameterViewController(parameter, tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) - def addDecimalLineEdit( + def add_integer_line_edit( + self, + parameter: IntegerParameter, + label: str, + *, + tool_tip: str = '', + group: str = '', + ) -> None: + view_controller = IntegerLineEditParameterViewController(parameter, tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) + + def add_decimal_line_edit( self, parameter: RealParameter, label: str, @@ -394,10 +669,10 @@ def addDecimalLineEdit( tool_tip: str = '', group: str = '', ) -> None: - viewController = DecimalLineEditParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + view_controller = DecimalLineEditParameterViewController(parameter, tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) - def addDecimalSlider( + def add_decimal_slider( self, parameter: RealParameter, label: str, @@ -405,10 +680,10 @@ def addDecimalSlider( tool_tip: str = '', group: str = '', ) -> None: - viewController = DecimalSliderParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + view_controller = DecimalSliderParameterViewController(parameter, tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) - def addLengthWidget( + def add_length_widget( self, parameter: RealParameter, label: str, @@ -416,10 +691,10 @@ def addLengthWidget( tool_tip: str = '', group: str = '', ) -> None: - viewController = LengthWidgetParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + view_controller = LengthWidgetParameterViewController(parameter, tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) - def addAngleWidget( + def add_angle_widget( self, parameter: RealParameter, label: str, @@ -427,85 +702,84 @@ def addAngleWidget( tool_tip: str = '', group: str = '', ) -> None: - viewController = AngleWidgetParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + view_controller = AngleWidgetParameterViewController(parameter, tool_tip=tool_tip) + self.add_view_controller(view_controller, label, group=group) - def addViewControllerToTop(self, viewController: ParameterViewController) -> None: - self._viewControllersTop.append(viewController) + def add_view_controller_to_top(self, view_controller: ParameterViewController) -> None: + self._view_controllers_top.append(view_controller) - def addViewController( + def add_view_controller( self, - viewController: ParameterViewController, + view_controller: ParameterViewController, label: str, *, - tool_tip: str = '', group: str = '', ) -> None: - self._viewControllers[group, label] = viewController + self._view_controllers[group, label] = view_controller - def addViewControllerToBottom(self, viewController: ParameterViewController) -> None: - self._viewControllersBottom.append(viewController) + def add_view_controller_to_bottom(self, view_controller: ParameterViewController) -> None: + self._view_controllers_bottom.append(view_controller) - def _buildLayout(self, *, add_stretch: bool) -> QVBoxLayout: - groupDict: dict[str, QFormLayout] = dict() + def _build_layout(self, *, add_stretch: bool) -> QVBoxLayout: + group_dict: dict[str, QFormLayout] = dict() - for (groupName, widgetLabel), vc in self._viewControllers.items(): + for (group_name, widget_label), vc in self._view_controllers.items(): try: - formLayout = groupDict[groupName] + form_layout = group_dict[group_name] except KeyError: - formLayout = QFormLayout() - groupDict[groupName] = formLayout + form_layout = QFormLayout() + group_dict[group_name] = form_layout - formLayout.addRow(widgetLabel, vc.getWidget()) + form_layout.addRow(widget_label, vc.get_widget()) layout = QVBoxLayout() - for viewController in self._viewControllersTop: - layout.addWidget(viewController.getWidget()) + for view_controller in self._view_controllers_top: + layout.addWidget(view_controller.get_widget()) - for groupName, groupLayout in groupDict.items(): - if groupName: - groupBox = QGroupBox(groupName) - groupBox.setLayout(groupLayout) - layout.addWidget(groupBox) - elif groupLayout.count() > 0: - layout.addLayout(groupLayout) + for group_name, group_layout in group_dict.items(): + if group_name: + group_box = QGroupBox(group_name) + group_box.setLayout(group_layout) + layout.addWidget(group_box) + elif group_layout.count() > 0: + layout.addLayout(group_layout) - for viewController in self._viewControllersBottom: - layout.addWidget(viewController.getWidget()) + for view_controller in self._view_controllers_bottom: + layout.addWidget(view_controller.get_widget()) if add_stretch: layout.addStretch() return layout - def _flushViewControllers(self) -> Sequence[ParameterViewController]: - viewControllers: list[ParameterViewController] = list() - viewControllers.extend(self._viewControllersTop) - viewControllers.extend(self._viewControllers.values()) - viewControllers.extend(self._viewControllersBottom) + def _flush_view_controllers(self) -> Sequence[ParameterViewController]: + view_controllers: list[ParameterViewController] = list() + view_controllers.extend(self._view_controllers_top) + view_controllers.extend(self._view_controllers.values()) + view_controllers.extend(self._view_controllers_bottom) - self._viewControllersTop.clear() - self._viewControllers.clear() - self._viewControllersBottom.clear() + self._view_controllers_top.clear() + self._view_controllers.clear() + self._view_controllers_bottom.clear() - return viewControllers + return view_controllers - def buildWidget(self) -> QWidget: - layout = self._buildLayout(add_stretch=True) + def build_widget(self) -> QWidget: + layout = self._build_layout(add_stretch=True) - widget = ParameterWidget(self._flushViewControllers()) + widget = ParameterWidget(self._flush_view_controllers()) widget.setLayout(layout) return widget - def buildDialog(self, windowTitle: str, parent: QWidget | None) -> QDialog: - buttonBox = QDialogButtonBox() - layout = self._buildLayout(add_stretch=False) - layout.addWidget(buttonBox) + def build_dialog(self, window_title: str, parent: QWidget | None) -> QDialog: + button_box = QDialogButtonBox() + layout = self._build_layout(add_stretch=True) + layout.addWidget(button_box) - dialog = ParameterDialog(self._flushViewControllers(), buttonBox, parent) + dialog = ParameterDialog(self._flush_view_controllers(), button_box, parent) dialog.setLayout(layout) - dialog.setWindowTitle(windowTitle) + dialog.setWindowTitle(window_title) return dialog diff --git a/src/ptychodus/controller/patterns/core.py b/src/ptychodus/controller/patterns/core.py index e110112d..5c9a3ad0 100644 --- a/src/ptychodus/controller/patterns/core.py +++ b/src/ptychodus/controller/patterns/core.py @@ -1,148 +1,148 @@ -from __future__ import annotations import logging -from PyQt5.QtCore import QModelIndex -from PyQt5.QtWidgets import QAbstractItemView, QMessageBox -from ptychodus.api.observer import Observable, Observer +from PyQt5.QtCore import QModelIndex +from PyQt5.QtWidgets import QAbstractItemView, QFormLayout, QMessageBox +from ...model.metadata import MetadataPresenter from ...model.patterns import ( - Detector, - DiffractionDatasetInputOutputPresenter, - DiffractionDatasetPresenter, - DiffractionMetadataPresenter, - DiffractionPatternPresenter, + AssembledDiffractionDataset, + DetectorSettings, + DiffractionDatasetObserver, + PatternSettings, + PatternSizer, + PatternsAPI, ) -from ...view.patterns import PatternsView -from ...view.widgets import ExceptionDialog +from ...view.patterns import DetectorView, PatternsView +from ...view.widgets import ExceptionDialog, ProgressBarItemDelegate from ..data import FileDialogFactory from ..image import ImageController -from .detector import DetectorController +from ..parametric import LengthWidgetParameterViewController, SpinBoxParameterViewController +from .dataset import DatasetTreeModel from .info import PatternsInfoViewController -from .treeModel import DatasetTreeModel, DatasetTreeNode from .wizard import OpenDatasetWizardController logger = logging.getLogger(__name__) -class PatternsController(Observer): +class DetectorController: + def __init__(self, settings: DetectorSettings, view: DetectorView) -> None: + self._width_px_view_controller = SpinBoxParameterViewController(settings.width_px) + self._height_px_view_controller = SpinBoxParameterViewController(settings.height_px) + self._pixel_width_view_controller = LengthWidgetParameterViewController( + settings.pixel_width_m + ) + self._pixel_height_view_controller = LengthWidgetParameterViewController( + settings.pixel_height_m + ) + self._bit_depth_view_controller = SpinBoxParameterViewController(settings.bit_depth) + + layout = QFormLayout() + layout.addRow('Detector Width [px]:', self._width_px_view_controller.get_widget()) + layout.addRow('Detector Height [px]:', self._height_px_view_controller.get_widget()) + layout.addRow('Pixel Width:', self._pixel_width_view_controller.get_widget()) + layout.addRow('Pixel Height:', self._pixel_height_view_controller.get_widget()) + layout.addRow('Bit Depth:', self._bit_depth_view_controller.get_widget()) + view.setLayout(layout) + + +class PatternsController(DiffractionDatasetObserver): def __init__( self, - detector: Detector, - ioPresenter: DiffractionDatasetInputOutputPresenter, - metadataPresenter: DiffractionMetadataPresenter, - datasetPresenter: DiffractionDatasetPresenter, - patternPresenter: DiffractionPatternPresenter, - imageController: ImageController, + detector_settings: DetectorSettings, + pattern_settings: PatternSettings, + pattern_sizer: PatternSizer, + patterns_api: PatternsAPI, + dataset: AssembledDiffractionDataset, + metadata_presenter: MetadataPresenter, view: PatternsView, - fileDialogFactory: FileDialogFactory, + image_controller: ImageController, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() - self._detector = detector - self._datasetPresenter = datasetPresenter - self._ioPresenter = ioPresenter - self._imageController = imageController + self._pattern_sizer = pattern_sizer + self._patterns_api = patterns_api + self._dataset = dataset self._view = view - self._fileDialogFactory = fileDialogFactory - self._detectorController = DetectorController(detector, view.detectorView) - self._wizardController = OpenDatasetWizardController.createInstance( - ioPresenter, - metadataPresenter, - datasetPresenter, - patternPresenter, - view.openDatasetWizard, - fileDialogFactory, + self._image_controller = image_controller + self._file_dialog_factory = file_dialog_factory + self._detector_controller = DetectorController(detector_settings, view.detector_view) + self._wizard_controller = OpenDatasetWizardController( + pattern_settings, + pattern_sizer, + patterns_api, + metadata_presenter, + file_dialog_factory, ) - self._treeModel = DatasetTreeModel() - - @classmethod - def createInstance( - cls, - detector: Detector, - ioPresenter: DiffractionDatasetInputOutputPresenter, - metadataPresenter: DiffractionMetadataPresenter, - datasetPresenter: DiffractionDatasetPresenter, - patternPresenter: DiffractionPatternPresenter, - imageController: ImageController, - view: PatternsView, - fileDialogFactory: FileDialogFactory, - ) -> PatternsController: - controller = cls( - detector, - ioPresenter, - metadataPresenter, - datasetPresenter, - patternPresenter, - imageController, - view, - fileDialogFactory, - ) - - view.treeView.setModel(controller._treeModel) - view.treeView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - view.treeView.selectionModel().currentChanged.connect(controller._updateView) - controller._updateView(QModelIndex(), QModelIndex()) + self._tree_model = DatasetTreeModel() - view.buttonBox.openButton.clicked.connect(controller._wizardController.openDataset) - view.buttonBox.saveButton.clicked.connect(controller._saveDataset) - view.buttonBox.infoButton.clicked.connect(controller._openPatternsInfo) - view.buttonBox.closeButton.clicked.connect(controller._closeDataset) - view.buttonBox.closeButton.setEnabled(False) # TODO - datasetPresenter.addObserver(controller) + view.tree_view.setModel(self._tree_model) + view.tree_view.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + counts_item_delegate = ProgressBarItemDelegate(view.tree_view) + view.tree_view.setItemDelegateForColumn(1, counts_item_delegate) + view.tree_view.selectionModel().currentChanged.connect(self._update_view) + self._update_view(QModelIndex(), QModelIndex()) - controller._syncModelToView() + view.button_box.open_button.clicked.connect(self._wizard_controller.open_dataset) + view.button_box.save_button.clicked.connect(self._save_dataset) + view.button_box.info_button.clicked.connect(self._open_patterns_info) + view.button_box.close_button.clicked.connect(self._close_dataset) + dataset.add_observer(self) - return controller + self._sync_model_to_view() - def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: + def _update_view(self, current: QModelIndex, previous: QModelIndex) -> None: if current.isValid(): node = current.internalPointer() - pixelGeometry = self._detector.getPixelGeometry() - self._imageController.setArray(node.data, pixelGeometry) + data = node.get_data() + pixel_geometry = self._pattern_sizer.get_processed_pixel_geometry() + self._image_controller.set_array(data, pixel_geometry) else: - self._imageController.clearArray() + self._image_controller.clear_array() - def _saveDataset(self) -> None: - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + def _save_dataset(self) -> None: + file_writer_chooser = self._patterns_api.get_file_writer_chooser() + file_path, name_filter = self._file_dialog_factory.get_save_file_path( self._view, 'Save Diffraction File', - nameFilters=self._ioPresenter.getSaveFileFilterList(), - selectedNameFilter=self._ioPresenter.getSaveFileFilter(), + name_filters=[plugin.display_name for plugin in file_writer_chooser], + selected_name_filter=file_writer_chooser.get_current_plugin().display_name, ) - if filePath: + if file_path: try: - self._ioPresenter.saveDiffractionFile(filePath, nameFilter) + self._patterns_api.save_patterns(file_path, name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Writer', err) + ExceptionDialog.show_exception('File Writer', err) - def _openPatternsInfo(self) -> None: - PatternsInfoViewController.showInfo(self._datasetPresenter, self._view) + def _open_patterns_info(self) -> None: + PatternsInfoViewController.show_info(self._dataset, self._view) - def _closeDataset(self) -> None: + def _close_dataset(self) -> None: button = QMessageBox.question( self._view, 'Confirm Close', 'This will free the diffraction data from memory. Do you want to continue?', ) - if button != QMessageBox.StandardButton.Yes: - return + if button == QMessageBox.StandardButton.Yes: + self._patterns_api.close_patterns() - logger.error('Close not implemented!') # TODO + def _sync_model_to_view(self) -> None: + self._tree_model.clear() - def _syncModelToView(self) -> None: - rootNode = DatasetTreeNode.createRoot() + for index, array in enumerate(self._dataset): + self._tree_model.insert_array(index, array) # type: ignore - for arrayPresenter in self._datasetPresenter: - rootNode.createChild(arrayPresenter) + info_text = self._dataset.get_info_text() + self._view.info_label.setText(info_text) - self._treeModel.setRootNode(rootNode) + def handle_array_inserted(self, index: int) -> None: + self._tree_model.insert_array(index, self._dataset[index]) - infoText = self._datasetPresenter.getInfoText() - self._view.infoLabel.setText(infoText) + def handle_array_changed(self, index: int) -> None: + self._tree_model.refresh_array(index) - def update(self, observable: Observable) -> None: - if observable is self._datasetPresenter: - self._syncModelToView() + def handle_dataset_reloaded(self) -> None: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/patterns/dataset.py b/src/ptychodus/controller/patterns/dataset.py new file mode 100644 index 00000000..31db1e63 --- /dev/null +++ b/src/ptychodus/controller/patterns/dataset.py @@ -0,0 +1,184 @@ +from __future__ import annotations +from typing import Any, overload + +from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject + +from ptychodus.api.patterns import PatternDataType +from ptychodus.api.units import BYTES_PER_MEGABYTE + +from ptychodus.model.patterns import AssembledDiffractionPatternArray + +__all__ = ['DatasetTreeModel'] + + +class DatasetTreeNode: + def __init__( + self, + parent_node: DatasetTreeNode | None, + array: AssembledDiffractionPatternArray, + frame_index: int, + ) -> None: + self.parent_node = parent_node + self._array = array + self._frame_index = frame_index + self.child_nodes: list[DatasetTreeNode] = list() + + @classmethod + def create_root(cls) -> DatasetTreeNode: + return cls(None, AssembledDiffractionPatternArray.create_null(), -1) + + def insert_child(self, pos: int, array: AssembledDiffractionPatternArray) -> DatasetTreeNode: + child = DatasetTreeNode(self, array, -1) + + for frame_index in range(array.get_num_patterns()): + grandchild = DatasetTreeNode(child, array, frame_index) + child.child_nodes.append(grandchild) + + self.child_nodes.insert(pos, child) + return child + + def get_label(self) -> str: + return self._array.get_label() if self._frame_index < 0 else f'Frame {self._frame_index}' + + def get_data(self) -> PatternDataType: + return ( + self._array.get_average_pattern() + if self._frame_index < 0 + else self._array.get_pattern(self._frame_index) + ) + + def get_counts(self) -> int: + return ( + int(self._array.get_mean_pattern_counts()) + if self._frame_index < 0 + else int(self._array.get_pattern_counts(self._frame_index)) + ) + + def get_nframes(self) -> int: + return len(self.child_nodes) if self._frame_index < 0 else 1 + + def get_nbytes(self) -> int: + return ( + self._array.get_data().nbytes + if self._frame_index < 0 + else self._array.get_pattern(self._frame_index).nbytes + ) + + def get_row(self) -> int: + return 0 if self.parent_node is None else self.parent_node.child_nodes.index(self) + + +class DatasetTreeModel(QAbstractItemModel): + def __init__(self, parent: QObject | None = None) -> None: + super().__init__(parent) + self._nodes = DatasetTreeNode.create_root() + self._max_counts = 1 + self._header = ['Label', 'Counts', 'Frames', 'Size [MB]'] + + def clear(self) -> None: + self.beginResetModel() + self._nodes = DatasetTreeNode.create_root() + self._max_counts = 1 + self.endResetModel() + + def insert_array(self, row: int, array: AssembledDiffractionPatternArray) -> None: + max_counts = array.get_max_pattern_counts() + + if self._max_counts < max_counts: + self._max_counts = max_counts + num_rows = self.rowCount() + + top_left = self.index(0, 1) + bottom_right = self.index(num_rows - 1, 1) + self.dataChanged.emit(top_left, bottom_right) + + for row2 in range(num_rows): + parent_index = self.index(row2, 0) + num_rows2 = self.rowCount(parent_index) + + child_top_left = self.index(0, 1, parent_index) + child_bottom_right = self.index(num_rows2 - 1, 1, parent_index) + self.dataChanged.emit(child_top_left, child_bottom_right) + + self.beginInsertRows(QModelIndex(), row, row) + child_node = self._nodes.insert_child(row, array) + self.endInsertRows() + + index = self.index(row, 0) + self.beginInsertRows(index, 0, len(child_node.child_nodes)) + self.endInsertRows() + + def refresh_array(self, row: int) -> None: + top_left = self.index(row, 0) + bottom_right = self.index(row, self.columnCount() - 1) + self.dataChanged.emit(top_left, bottom_right) + + num_rows = self.rowCount(top_left) + num_cols = self.columnCount(top_left) + + child_top_left = self.index(0, 0, top_left) + child_bottom_right = self.index(num_rows - 1, num_cols - 1, top_left) + self.dataChanged.emit(child_top_left, child_bottom_right) + + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + + @overload + def parent(self) -> QObject: ... + + def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: + if child is None: + return super().parent() + + if child.isValid(): + child_node = child.internalPointer() + parent_node = child_node.parent_node + + if parent_node is not self._nodes: + return self.createIndex(parent_node.get_row(), 0, parent_node) + + return QModelIndex() + + def headerData( # noqa: N802 + self, + section: int, + orientation: Qt.Orientation, + role: int = Qt.ItemDataRole.DisplayRole, + ) -> Any: + if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: + return self._header[section] + + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: + if index.isValid(): + node = index.internalPointer() + + if role == Qt.ItemDataRole.DisplayRole: + match index.column(): + case 0: + return node.get_label() + case 1: + return str(node.get_counts()) + case 2: + return node.get_nframes() + case 3: + return f'{node.get_nbytes() / BYTES_PER_MEGABYTE:.2f}' + elif role == Qt.ItemDataRole.UserRole: + if index.column() == 1: + return (100 * node.get_counts()) // self._max_counts + + def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: + if self.hasIndex(row, column, parent): + parent_node = parent.internalPointer() if parent.isValid() else self._nodes + child_node = parent_node.child_nodes[row] + + if child_node: + return self.createIndex(row, column, child_node) + + return QModelIndex() + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + node = parent.internalPointer() if parent.isValid() else self._nodes + return len(node.child_nodes) + + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + return len(self._header) diff --git a/src/ptychodus/controller/patterns/detector.py b/src/ptychodus/controller/patterns/detector.py deleted file mode 100644 index d63b1971..00000000 --- a/src/ptychodus/controller/patterns/detector.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -from PyQt5.QtWidgets import QFormLayout - - -from ...model.patterns import Detector -from ...view.patterns import DetectorView -from ..parametric import LengthWidgetParameterViewController, SpinBoxParameterViewController - - -class DetectorController: - def __init__(self, detector: Detector, view: DetectorView) -> None: - self._widthInPixelsViewController = SpinBoxParameterViewController(detector.widthInPixels) - self._heightInPixelsViewController = SpinBoxParameterViewController(detector.heightInPixels) - self._pixelWidthViewController = LengthWidgetParameterViewController( - detector.pixelWidthInMeters - ) - self._pixelHeightViewController = LengthWidgetParameterViewController( - detector.pixelHeightInMeters - ) - self._bitDepthViewController = SpinBoxParameterViewController(detector.bitDepth) - - layout = QFormLayout() - layout.addRow('Detector Width [px]:', self._widthInPixelsViewController.getWidget()) - layout.addRow('Detector Height [px]:', self._heightInPixelsViewController.getWidget()) - layout.addRow('Pixel Width:', self._pixelWidthViewController.getWidget()) - layout.addRow('Pixel Height:', self._pixelHeightViewController.getWidget()) - layout.addRow('Bit Depth:', self._bitDepthViewController.getWidget()) - view.setLayout(layout) diff --git a/src/ptychodus/controller/patterns/info.py b/src/ptychodus/controller/patterns/info.py index 77aa5157..d022e9a5 100644 --- a/src/ptychodus/controller/patterns/info.py +++ b/src/ptychodus/controller/patterns/info.py @@ -1,38 +1,123 @@ +from typing import Any, overload + from PyQt5.QtWidgets import QWidget +from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject -from ptychodus.api.observer import Observable, Observer +from ptychodus.api.tree import SimpleTreeNode -from ...model.patterns import DiffractionDatasetPresenter +from ...model.patterns import AssembledDiffractionDataset, DiffractionDatasetObserver from ...view.patterns import PatternsInfoDialog -from .tree import SimpleTreeModel -class PatternsInfoViewController(Observer): - def __init__(self, presenter: DiffractionDatasetPresenter, treeModel: SimpleTreeModel) -> None: +class SimpleTreeModel(QAbstractItemModel): + def __init__(self, root_node: SimpleTreeNode, parent: QObject | None = None) -> None: + super().__init__(parent) + self._root_node = root_node + + def set_root_node(self, root_node: SimpleTreeNode) -> None: + self.beginResetModel() + self._root_node = root_node + self.endResetModel() + + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + + @overload + def parent(self) -> QObject: ... + + def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: + if child is None: + return super().parent() + else: + value = QModelIndex() + + if child.isValid(): + child_item = child.internalPointer() + parent_item = child_item.parent_item + + if parent_item is self._root_node: + value = QModelIndex() + else: + value = self.createIndex(parent_item.row(), 0, parent_item) + + return value + + def headerData( # noqa: N802 + self, + section: int, + orientation: Qt.Orientation, + role: int = Qt.ItemDataRole.DisplayRole, + ) -> Any: + if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: + return self._root_node.data(section) + + def flags(self, index: QModelIndex) -> Qt.ItemFlags: + return super().flags(index) + + def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: + value = QModelIndex() + + if self.hasIndex(row, column, parent): + parent_item = parent.internalPointer() if parent.isValid() else self._root_node + child_item = parent_item.child_items[row] + + if child_item: + value = self.createIndex(row, column, child_item) + + return value + + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: + if index.isValid() and role == Qt.ItemDataRole.DisplayRole: + node = index.internalPointer() + return node.data(index.column()) + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + if parent.column() > 0: + return 0 + + node = self._root_node + + if parent.isValid(): + node = parent.internalPointer() + + return len(node.child_items) + + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + node = self._root_node + + if parent.isValid(): + node = parent.internalPointer() + + return len(node.item_data) + + +class PatternsInfoViewController(DiffractionDatasetObserver): + def __init__(self, dataset: AssembledDiffractionDataset, tree_model: SimpleTreeModel) -> None: super().__init__() - self._presenter = presenter - self._treeModel = treeModel + self._dataset = dataset + self._tree_model = tree_model @classmethod - def showInfo(cls, presenter: DiffractionDatasetPresenter, parent: QWidget) -> None: - treeModel = SimpleTreeModel(presenter.getContentsTree()) - controller = cls(presenter, treeModel) - presenter.addObserver(controller) + def show_info(cls, dataset: AssembledDiffractionDataset, parent: QWidget) -> None: + tree_model = SimpleTreeModel(dataset.get_contents_tree()) + controller = cls(dataset, tree_model) + dataset.add_observer(controller) - dialog = PatternsInfoDialog.createInstance(parent) + dialog = PatternsInfoDialog(parent) dialog.setWindowTitle('Patterns Info') - dialog.treeView.setModel(treeModel) - dialog.finished.connect(controller._finish) + dialog.tree_view.setModel(tree_model) - controller._syncModelToView() + controller._sync_model_to_view() dialog.open() - def _finish(self, result: int) -> None: - self._presenter.removeObserver(self) + def _sync_model_to_view(self) -> None: + self._tree_model.set_root_node(self._dataset.get_contents_tree()) + + def handle_array_inserted(self, index: int) -> None: + pass - def _syncModelToView(self) -> None: - self._treeModel.setRootNode(self._presenter.getContentsTree()) + def handle_array_changed(self, index: int) -> None: + pass - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() + def handle_dataset_reloaded(self) -> None: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/patterns/tree.py b/src/ptychodus/controller/patterns/tree.py deleted file mode 100644 index b5e70838..00000000 --- a/src/ptychodus/controller/patterns/tree.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Any, overload - -from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject - -from ptychodus.api.tree import SimpleTreeNode - - -class SimpleTreeModel(QAbstractItemModel): - def __init__(self, rootNode: SimpleTreeNode, parent: QObject | None = None) -> None: - super().__init__(parent) - self._rootNode = rootNode - - def setRootNode(self, rootNode: SimpleTreeNode) -> None: - self.beginResetModel() - self._rootNode = rootNode - self.endResetModel() - - @overload - def parent(self, child: QModelIndex) -> QModelIndex: ... - - @overload - def parent(self) -> QObject: ... - - def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: - if child is None: - return super().parent() - else: - value = QModelIndex() - - if child.isValid(): - childItem = child.internalPointer() - parentItem = childItem.parentItem - - if parentItem is self._rootNode: - value = QModelIndex() - else: - value = self.createIndex(parentItem.row(), 0, parentItem) - - return value - - def headerData( - self, - section: int, - orientation: Qt.Orientation, - role: int = Qt.ItemDataRole.DisplayRole, - ) -> Any: - if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: - return self._rootNode.data(section) - - def flags(self, index: QModelIndex) -> Qt.ItemFlags: - return super().flags(index) - - def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: - value = QModelIndex() - - if self.hasIndex(row, column, parent): - parentItem = parent.internalPointer() if parent.isValid() else self._rootNode - childItem = parentItem.childItems[row] - - if childItem: - value = self.createIndex(row, column, childItem) - - return value - - def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: - if index.isValid() and role == Qt.ItemDataRole.DisplayRole: - node = index.internalPointer() - return node.data(index.column()) - - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: - if parent.column() > 0: - return 0 - - node = self._rootNode - - if parent.isValid(): - node = parent.internalPointer() - - return len(node.childItems) - - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: - node = self._rootNode - - if parent.isValid(): - node = parent.internalPointer() - - return len(node.itemData) diff --git a/src/ptychodus/controller/patterns/treeModel.py b/src/ptychodus/controller/patterns/treeModel.py deleted file mode 100644 index f4fe525a..00000000 --- a/src/ptychodus/controller/patterns/treeModel.py +++ /dev/null @@ -1,177 +0,0 @@ -from __future__ import annotations -from typing import Any, overload - -from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject -from PyQt5.QtGui import QFont - -from ptychodus.api.patterns import DiffractionPatternArrayType, DiffractionPatternState - -from ...model.patterns import DiffractionPatternArrayPresenter - - -class DatasetTreeNode: - def __init__( - self, - parentItem: DatasetTreeNode | None, - presenter: DiffractionPatternArrayPresenter, - frameIndex: int, - ) -> None: - self.parentItem = parentItem - self._presenter = presenter - self._frameIndex = frameIndex - self.childItems: list[DatasetTreeNode] = list() - - @classmethod - def createRoot(cls) -> DatasetTreeNode: - return cls(None, DiffractionPatternArrayPresenter.createNull(), -1) - - def createChild(self, presenter: DiffractionPatternArrayPresenter) -> DatasetTreeNode: - childItem = DatasetTreeNode(self, presenter, -1) - - if presenter.data is not None: - for frameIndex in range(presenter.data.shape[0]): - grandChildItem = DatasetTreeNode(childItem, presenter, frameIndex) - childItem.childItems.append(grandChildItem) - - self.childItems.append(childItem) - return childItem - - @property - def label(self) -> str: - if self._frameIndex < 0: - return self._presenter.label - - return f'Frame {self._frameIndex}' - - @property - def state(self) -> DiffractionPatternState: - return self._presenter.state - - @property - def data(self) -> DiffractionPatternArrayType | None: - if self._presenter.data is None: - return None - elif self._frameIndex < 0: - return self._presenter.data.mean(axis=0) - - return self._presenter.data[self._frameIndex] - - @property - def numberOfFrames(self) -> int: - if self._frameIndex < 0: - return len(self.childItems) - - return 1 - - @property - def sizeInBytes(self) -> int: - if self._presenter.data is None: - return 0 - elif self._frameIndex < 0: - return self._presenter.data.nbytes - - return self._presenter.data[self._frameIndex].nbytes - - def row(self) -> int: - if self.parentItem: - return self.parentItem.childItems.index(self) - - return 0 - - -class DatasetTreeModel(QAbstractItemModel): - def __init__(self, parent: QObject | None = None) -> None: - super().__init__(parent) - self._rootNode = DatasetTreeNode.createRoot() - self._header = ['Label', 'Frames', 'Size [MB]'] - - def setRootNode(self, rootNode: DatasetTreeNode) -> None: - self.beginResetModel() - self._rootNode = rootNode - self.endResetModel() - - @overload - def parent(self, child: QModelIndex) -> QModelIndex: ... - - @overload - def parent(self) -> QObject: ... - - def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: - if child is None: - return super().parent() - else: - value = QModelIndex() - - if child.isValid(): - childItem = child.internalPointer() - parentItem = childItem.parentItem - - if parentItem is self._rootNode: - value = QModelIndex() - else: - value = self.createIndex(parentItem.row(), 0, parentItem) - - return value - - def headerData( - self, - section: int, - orientation: Qt.Orientation, - role: int = Qt.ItemDataRole.DisplayRole, - ) -> Any: - if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: - return self._header[section] - - def flags(self, index: QModelIndex) -> Qt.ItemFlags: - value = super().flags(index) - - if index.isValid(): - node = index.internalPointer() - - if node.state != DiffractionPatternState.LOADED: - value &= ~Qt.ItemFlag.ItemIsSelectable - value &= ~Qt.ItemFlag.ItemIsEnabled - - return value - - def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: - if index.isValid(): - node = index.internalPointer() - - if role == Qt.ItemDataRole.DisplayRole: - if index.column() == 0: - return node.label - elif index.column() == 1: - return node.numberOfFrames - elif index.column() == 2: - return f'{node.sizeInBytes / (1024 * 1024):.2f}' - elif role == Qt.ItemDataRole.FontRole: - font = QFont() - font.setItalic(node.state == DiffractionPatternState.FOUND) - return font - - def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: - value = QModelIndex() - - if self.hasIndex(row, column, parent): - parentItem = parent.internalPointer() if parent.isValid() else self._rootNode - childItem = parentItem.childItems[row] - - if childItem: - value = self.createIndex(row, column, childItem) - - return value - - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: - if parent.column() > 0: - return 0 - - node = self._rootNode - - if parent.isValid(): - node = parent.internalPointer() - - return len(node.childItems) - - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: - return len(self._header) diff --git a/src/ptychodus/controller/patterns/wizard.py b/src/ptychodus/controller/patterns/wizard.py deleted file mode 100644 index 14635cfa..00000000 --- a/src/ptychodus/controller/patterns/wizard.py +++ /dev/null @@ -1,547 +0,0 @@ -from __future__ import annotations -from pathlib import Path -import logging -import re - -from PyQt5.QtCore import Qt, QDir, QFileInfo, QModelIndex, QSortFilterProxyModel -from PyQt5.QtWidgets import QAbstractItemView, QFileSystemModel, QWizard - -from ptychodus.api.observer import Observable, Observer - -from ...model.patterns import ( - DiffractionDatasetInputOutputPresenter, - DiffractionDatasetPresenter, - DiffractionMetadataPresenter, - DiffractionPatternPresenter, -) -from ...view.patterns import ( - OpenDatasetWizard, - OpenDatasetWizardFilesPage, - OpenDatasetWizardMetadataPage, - OpenDatasetWizardPatternsPage, - OpenDatasetWizardPatternCropView, - OpenDatasetWizardPatternLoadView, - OpenDatasetWizardPatternMemoryMapView, - OpenDatasetWizardPatternTransformView, -) -from ..data import FileDialogFactory - -logger = logging.getLogger(__name__) - -__all__ = [ - 'OpenDatasetWizardController', -] - - -class OpenDatasetWizardFilesController(Observer): - def __init__( - self, - presenter: DiffractionDatasetInputOutputPresenter, - page: OpenDatasetWizardFilesPage, - fileDialogFactory: FileDialogFactory, - fileSystemModel: QFileSystemModel, - fileSystemProxyModel: QSortFilterProxyModel, - ) -> None: - super().__init__() - self._presenter = presenter - self._page = page - self._fileDialogFactory = fileDialogFactory - self._fileSystemModel = fileSystemModel - self._fileSystemProxyModel = fileSystemProxyModel - - @classmethod - def createInstance( - cls, - presenter: DiffractionDatasetInputOutputPresenter, - page: OpenDatasetWizardFilesPage, - fileDialogFactory: FileDialogFactory, - ) -> OpenDatasetWizardFilesController: - fileSystemModel = QFileSystemModel() - fileSystemProxyModel = QSortFilterProxyModel() - fileSystemModel.setFilter(QDir.Filter.AllEntries | QDir.Filter.AllDirs) - fileSystemModel.setNameFilterDisables(False) - fileSystemProxyModel.setSourceModel(fileSystemModel) - - controller = cls(presenter, page, fileDialogFactory, fileSystemModel, fileSystemProxyModel) - presenter.addObserver(controller) - - page.directoryComboBox.addItem(str(fileDialogFactory.getOpenWorkingDirectory())) - page.directoryComboBox.addItem(str(Path.home())) - page.directoryComboBox.setEditable(True) - page.directoryComboBox.textActivated.connect(controller._handleDirectoryComboBoxActivated) - - page.fileSystemTableView.setModel(controller._fileSystemProxyModel) - page.fileSystemTableView.setSortingEnabled(True) - page.fileSystemTableView.sortByColumn(0, Qt.SortOrder.AscendingOrder) - page.fileSystemTableView.verticalHeader().hide() - page.fileSystemTableView.setSelectionBehavior( - QAbstractItemView.SelectionBehavior.SelectRows - ) - page.fileSystemTableView.doubleClicked.connect( - controller._handleFileSystemTableDoubleClicked - ) - page.fileSystemTableView.selectionModel().currentChanged.connect( - controller._checkIfComplete - ) - - for fileFilter in presenter.getOpenFileFilterList(): - page.fileTypeComboBox.addItem(fileFilter) - - page.fileTypeComboBox.textActivated.connect(controller._setNameFiltersInFileSystemModel) - - controller._setRootPath(fileDialogFactory.getOpenWorkingDirectory()) - controller._syncModelToView() - - return controller - - def _setRootPath(self, rootPath: Path) -> None: - index = self._fileSystemModel.setRootPath(str(rootPath)) - proxyIndex = self._fileSystemProxyModel.mapFromSource(index) - self._page.fileSystemTableView.setRootIndex(proxyIndex) - self._page.directoryComboBox.setCurrentText(str(rootPath)) - self._fileDialogFactory.setOpenWorkingDirectory(rootPath) - - def _handleDirectoryComboBoxActivated(self, text: str) -> None: - fileInfo = QFileInfo(text) - - if fileInfo.isDir(): - self._setRootPath(Path(fileInfo.canonicalFilePath())) - - def _handleFileSystemTableDoubleClicked(self, proxyIndex: QModelIndex) -> None: - index = self._fileSystemProxyModel.mapToSource(proxyIndex) - fileInfo = self._fileSystemModel.fileInfo(index) - - if fileInfo.isDir(): - self._setRootPath(Path(fileInfo.canonicalFilePath())) - - def openDataset(self) -> None: - proxyIndex = self._page.fileSystemTableView.currentIndex() - index = self._fileSystemProxyModel.mapToSource(proxyIndex) - filePath = Path(self._fileSystemModel.filePath(index)) - self._fileDialogFactory.setOpenWorkingDirectory(filePath.parent) - - fileFilter = self._page.fileTypeComboBox.currentText() - self._presenter.openDiffractionFile(filePath, fileFilter) - - def _checkIfComplete(self, current: QModelIndex, previous: QModelIndex) -> None: - index = self._fileSystemProxyModel.mapToSource(current) - fileInfo = self._fileSystemModel.fileInfo(index) - self._page._setComplete(fileInfo.isFile()) - - def _setNameFiltersInFileSystemModel(self, currentText: str) -> None: - z = re.search(r'\((.+)\)', currentText) - - if z: - nameFilters = z.group(1).split() - logger.debug(f'Dataset File Name Filters: {nameFilters}') - self._fileSystemModel.setNameFilters(nameFilters) - - def _syncModelToView(self) -> None: - self._page.fileTypeComboBox.setCurrentText(self._presenter.getOpenFileFilter()) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() - - -class OpenDatasetWizardMetadataController(Observer): - def __init__( - self, - presenter: DiffractionMetadataPresenter, - page: OpenDatasetWizardMetadataPage, - ) -> None: - super().__init__() - self._presenter = presenter - self._page = page - - @classmethod - def createInstance( - cls, - presenter: DiffractionMetadataPresenter, - page: OpenDatasetWizardMetadataPage, - ) -> OpenDatasetWizardMetadataController: - controller = cls(presenter, page) - presenter.addObserver(controller) - controller._syncModelToView() - page._setComplete(True) - return controller - - def importMetadata(self) -> None: - if self._page.detectorPixelCountCheckBox.isChecked(): - self._presenter.syncDetectorPixelCount() - - if self._page.detectorPixelSizeCheckBox.isChecked(): - self._presenter.syncDetectorPixelSize() - - if self._page.detectorBitDepthCheckBox.isChecked(): - self._presenter.syncDetectorBitDepth() - - if self._page.detectorDistanceCheckBox.isChecked(): - self._presenter.syncDetectorDistance() - - self._presenter.syncPatternCrop( - syncCenter=self._page.patternCropCenterCheckBox.isChecked(), - syncExtent=self._page.patternCropExtentCheckBox.isChecked(), - ) - - if self._page.probeEnergyCheckBox.isChecked(): - self._presenter.syncProbeEnergy() - - def _syncModelToView(self) -> None: - canSyncDetectorPixelCount = self._presenter.canSyncDetectorPixelCount() - self._page.detectorPixelCountCheckBox.setVisible(canSyncDetectorPixelCount) - self._page.detectorPixelCountCheckBox.setChecked(canSyncDetectorPixelCount) - - canSyncDetectorPixelSize = self._presenter.canSyncDetectorPixelSize() - self._page.detectorPixelSizeCheckBox.setVisible(canSyncDetectorPixelSize) - self._page.detectorPixelSizeCheckBox.setChecked(canSyncDetectorPixelSize) - - canSyncDetectorBitDepth = self._presenter.canSyncDetectorBitDepth() - self._page.detectorBitDepthCheckBox.setVisible(canSyncDetectorBitDepth) - self._page.detectorBitDepthCheckBox.setChecked(canSyncDetectorBitDepth) - - canSyncDetectorDistance = self._presenter.canSyncDetectorDistance() - self._page.detectorDistanceCheckBox.setVisible(canSyncDetectorDistance) - self._page.detectorDistanceCheckBox.setChecked(canSyncDetectorDistance) - - canSyncPatternCropCenter = self._presenter.canSyncPatternCropCenter() - self._page.patternCropCenterCheckBox.setVisible(canSyncPatternCropCenter) - self._page.patternCropCenterCheckBox.setChecked(canSyncPatternCropCenter) - - canSyncPatternCropExtent = self._presenter.canSyncPatternCropExtent() - self._page.patternCropExtentCheckBox.setVisible(canSyncPatternCropExtent) - self._page.patternCropExtentCheckBox.setChecked(canSyncPatternCropExtent) - - canSyncProbeEnergy = self._presenter.canSyncProbeEnergy() - self._page.probeEnergyCheckBox.setVisible(canSyncProbeEnergy) - self._page.probeEnergyCheckBox.setChecked(canSyncProbeEnergy) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() - - -class PatternLoadController(Observer): - def __init__( - self, - presenter: DiffractionDatasetPresenter, - view: OpenDatasetWizardPatternLoadView, - ) -> None: - super().__init__() - self._presenter = presenter - self._view = view - - @classmethod - def createInstance( - cls, - presenter: DiffractionDatasetPresenter, - view: OpenDatasetWizardPatternLoadView, - ) -> PatternLoadController: - controller = cls(presenter, view) - presenter.addObserver(controller) - view.numberOfThreadsSpinBox.valueChanged.connect(presenter.setNumberOfDataThreads) - controller._syncModelToView() - return controller - - def _syncModelToView(self) -> None: - self._view.numberOfThreadsSpinBox.blockSignals(True) - self._view.numberOfThreadsSpinBox.setRange( - self._presenter.getNumberOfDataThreadsLimits().lower, - self._presenter.getNumberOfDataThreadsLimits().upper, - ) - self._view.numberOfThreadsSpinBox.setValue(self._presenter.getNumberOfDataThreads()) - self._view.numberOfThreadsSpinBox.blockSignals(False) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() - - -class PatternMemoryMapController(Observer): - def __init__( - self, - presenter: DiffractionDatasetPresenter, - view: OpenDatasetWizardPatternMemoryMapView, - fileDialogFactory: FileDialogFactory, - ) -> None: - super().__init__() - self._presenter = presenter - self._view = view - self._fileDialogFactory = fileDialogFactory - - @classmethod - def createInstance( - cls, - presenter: DiffractionDatasetPresenter, - view: OpenDatasetWizardPatternMemoryMapView, - fileDialogFactory: FileDialogFactory, - ) -> PatternMemoryMapController: - controller = cls(presenter, view, fileDialogFactory) - presenter.addObserver(controller) - - view.setCheckable(True) - controller._syncModelToView() - view.toggled.connect(presenter.setMemmapEnabled) - view.scratchDirectoryLineEdit.editingFinished.connect( - controller._syncScratchDirectoryToModel - ) - view.scratchDirectoryBrowseButton.clicked.connect(controller._browseScratchDirectory) - - return controller - - def _syncScratchDirectoryToModel(self) -> None: - scratchDirectory = Path(self._view.scratchDirectoryLineEdit.text()) - self._presenter.setScratchDirectory(scratchDirectory) - - def _browseScratchDirectory(self) -> None: - dirPath = self._fileDialogFactory.getExistingDirectoryPath( - self._view, 'Choose Scratch ScratchDirectory' - ) - - if dirPath: - self._presenter.setScratchDirectory(dirPath) - - def _syncModelToView(self) -> None: - self._view.setChecked(self._presenter.isMemmapEnabled()) - scratchDirectory = self._presenter.getScratchDirectory() - - if scratchDirectory: - self._view.scratchDirectoryLineEdit.setText(str(scratchDirectory)) - else: - self._view.scratchDirectoryLineEdit.clear() - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() - - -class PatternCropController(Observer): - def __init__( - self, - presenter: DiffractionPatternPresenter, - view: OpenDatasetWizardPatternCropView, - ) -> None: - super().__init__() - self._presenter = presenter - self._view = view - - @classmethod - def createInstance( - cls, - presenter: DiffractionPatternPresenter, - view: OpenDatasetWizardPatternCropView, - ) -> PatternCropController: - controller = cls(presenter, view) - presenter.addObserver(controller) - - view.setCheckable(True) - view.toggled.connect(presenter.setCropEnabled) - - view.centerXSpinBox.valueChanged.connect(presenter.setCropCenterXInPixels) - view.centerYSpinBox.valueChanged.connect(presenter.setCropCenterYInPixels) - view.extentXSpinBox.valueChanged.connect(presenter.setCropWidthInPixels) - view.extentYSpinBox.valueChanged.connect(presenter.setCropHeightInPixels) - - controller._syncModelToView() - - return controller - - def _syncModelToView(self) -> None: - self._view.setChecked(self._presenter.isCropEnabled()) - - self._view.centerXSpinBox.blockSignals(True) - self._view.centerXSpinBox.setRange( - self._presenter.getCropCenterXLimitsInPixels().lower, - self._presenter.getCropCenterXLimitsInPixels().upper, - ) - self._view.centerXSpinBox.setValue(self._presenter.getCropCenterXInPixels()) - self._view.centerXSpinBox.blockSignals(False) - - self._view.centerYSpinBox.blockSignals(True) - self._view.centerYSpinBox.setRange( - self._presenter.getCropCenterYLimitsInPixels().lower, - self._presenter.getCropCenterYLimitsInPixels().upper, - ) - self._view.centerYSpinBox.setValue(self._presenter.getCropCenterYInPixels()) - self._view.centerYSpinBox.blockSignals(False) - - self._view.extentXSpinBox.blockSignals(True) - self._view.extentXSpinBox.setRange( - self._presenter.getCropWidthLimitsInPixels().lower, - self._presenter.getCropWidthLimitsInPixels().upper, - ) - self._view.extentXSpinBox.setValue(self._presenter.getCropWidthInPixels()) - self._view.extentXSpinBox.blockSignals(False) - - self._view.extentYSpinBox.blockSignals(True) - self._view.extentYSpinBox.setRange( - self._presenter.getCropHeightLimitsInPixels().lower, - self._presenter.getCropHeightLimitsInPixels().upper, - ) - self._view.extentYSpinBox.setValue(self._presenter.getCropHeightInPixels()) - self._view.extentYSpinBox.blockSignals(False) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() - - -class PatternTransformController(Observer): - def __init__( - self, - presenter: DiffractionPatternPresenter, - view: OpenDatasetWizardPatternTransformView, - ) -> None: - super().__init__() - self._presenter = presenter - self._view = view - - @classmethod - def createInstance( - cls, - presenter: DiffractionPatternPresenter, - view: OpenDatasetWizardPatternTransformView, - ) -> PatternTransformController: - controller = cls(presenter, view) - presenter.addObserver(controller) - - view.valueLowerBoundCheckBox.toggled.connect(presenter.setValueLowerBoundEnabled) - view.valueLowerBoundSpinBox.valueChanged.connect(presenter.setValueLowerBound) - view.valueUpperBoundCheckBox.toggled.connect(presenter.setValueUpperBoundEnabled) - view.valueUpperBoundSpinBox.valueChanged.connect(presenter.setValueUpperBound) - view.flipXCheckBox.toggled.connect(presenter.setFlipXEnabled) - view.flipYCheckBox.toggled.connect(presenter.setFlipYEnabled) - - controller._syncModelToView() - return controller - - def _syncModelToView(self) -> None: - self._view.valueLowerBoundCheckBox.setChecked(self._presenter.isValueLowerBoundEnabled()) - - self._view.valueLowerBoundSpinBox.blockSignals(True) - self._view.valueLowerBoundSpinBox.setRange( - self._presenter.getValueLowerBoundLimits().lower, - self._presenter.getValueLowerBoundLimits().upper, - ) - self._view.valueLowerBoundSpinBox.setValue(self._presenter.getValueLowerBound()) - self._view.valueLowerBoundSpinBox.blockSignals(False) - - self._view.valueUpperBoundCheckBox.setChecked(self._presenter.isValueUpperBoundEnabled()) - - self._view.valueUpperBoundSpinBox.blockSignals(True) - self._view.valueUpperBoundSpinBox.setRange( - self._presenter.getValueUpperBoundLimits().lower, - self._presenter.getValueUpperBoundLimits().upper, - ) - self._view.valueUpperBoundSpinBox.setValue(self._presenter.getValueUpperBound()) - self._view.valueUpperBoundSpinBox.blockSignals(False) - - self._view.flipXCheckBox.setChecked(self._presenter.isFlipXEnabled()) - self._view.flipYCheckBox.setChecked(self._presenter.isFlipYEnabled()) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() - - -class OpenDatasetWizardPatternsController: - def __init__( - self, - datasetPresenter: DiffractionDatasetPresenter, - patternPresenter: DiffractionPatternPresenter, - page: OpenDatasetWizardPatternsPage, - fileDialogFactory: FileDialogFactory, - ) -> None: - self._datasetPresenter = datasetPresenter - self._patternPresenter = patternPresenter - self._page = page - self._loadController = PatternLoadController.createInstance(datasetPresenter, page.loadView) - self._memoryMapController = PatternMemoryMapController.createInstance( - datasetPresenter, page.memoryMapView, fileDialogFactory - ) - self._cropController = PatternCropController.createInstance(patternPresenter, page.cropView) - self._transformController = PatternTransformController.createInstance( - patternPresenter, page.transformView - ) - - @classmethod - def createInstance( - cls, - ioPresenter: DiffractionDatasetInputOutputPresenter, - datasetPresenter: DiffractionDatasetPresenter, - patternPresenter: DiffractionPatternPresenter, - page: OpenDatasetWizardPatternsPage, - fileDialogFactory: FileDialogFactory, - ) -> OpenDatasetWizardPatternsController: - controller = cls(datasetPresenter, patternPresenter, page, fileDialogFactory) - page._setComplete(True) - return controller - - -class OpenDatasetWizardController: - def __init__( - self, - ioPresenter: DiffractionDatasetInputOutputPresenter, - metadataPresenter: DiffractionMetadataPresenter, - datasetPresenter: DiffractionDatasetPresenter, - patternPresenter: DiffractionPatternPresenter, - wizard: OpenDatasetWizard, - fileDialogFactory: FileDialogFactory, - ) -> None: - self._ioPresenter = ioPresenter - self._wizard = wizard - self._filesController = OpenDatasetWizardFilesController.createInstance( - ioPresenter, wizard.filesPage, fileDialogFactory - ) - self._metadataController = OpenDatasetWizardMetadataController.createInstance( - metadataPresenter, wizard.metadataPage - ) - self._patternsController = OpenDatasetWizardPatternsController.createInstance( - ioPresenter, - datasetPresenter, - patternPresenter, - wizard.patternsPage, - fileDialogFactory, - ) - - @classmethod - def createInstance( - cls, - ioPresenter: DiffractionDatasetInputOutputPresenter, - metadataPresenter: DiffractionMetadataPresenter, - datasetPresenter: DiffractionDatasetPresenter, - patternPresenter: DiffractionPatternPresenter, - wizard: OpenDatasetWizard, - fileDialogFactory: FileDialogFactory, - ) -> OpenDatasetWizardController: - controller = cls( - ioPresenter, - metadataPresenter, - datasetPresenter, - patternPresenter, - wizard, - fileDialogFactory, - ) - wizard.button(QWizard.WizardButton.NextButton).clicked.connect( - controller._executeNextButtonAction - ) - wizard.button(QWizard.WizardButton.FinishButton).clicked.connect( - controller._executeFinishButtonAction - ) - return controller - - def _executeNextButtonAction(self) -> None: - page = self._wizard.currentPage() - - if page is self._wizard.metadataPage: - self._filesController.openDataset() - elif page is self._wizard.patternsPage: - self._metadataController.importMetadata() - - def _executeFinishButtonAction(self) -> None: - self._ioPresenter.startAssemblingDiffractionPatterns() - - def openDataset(self) -> None: - self._ioPresenter.stopAssemblingDiffractionPatterns(finishAssembling=False) - self._wizard.restart() - self._wizard.show() diff --git a/src/ptychodus/controller/patterns/wizard/__init__.py b/src/ptychodus/controller/patterns/wizard/__init__.py new file mode 100644 index 00000000..20b2fcf8 --- /dev/null +++ b/src/ptychodus/controller/patterns/wizard/__init__.py @@ -0,0 +1,5 @@ +from .core import OpenDatasetWizardController + +__all__ = [ + 'OpenDatasetWizardController', +] diff --git a/src/ptychodus/controller/patterns/wizard/core.py b/src/ptychodus/controller/patterns/wizard/core.py new file mode 100644 index 00000000..1de3a971 --- /dev/null +++ b/src/ptychodus/controller/patterns/wizard/core.py @@ -0,0 +1,62 @@ +import logging + +from PyQt5.QtWidgets import QWizard + +from ....model.metadata import MetadataPresenter +from ....model.patterns import PatternSettings, PatternSizer, PatternsAPI + +from ...data import FileDialogFactory +from .files import OpenDatasetWizardFilesViewController +from .metadata import OpenDatasetWizardMetadataViewController +from .patterns import OpenDatasetWizardPatternsViewController + +logger = logging.getLogger(__name__) + + +class OpenDatasetWizardController: + def __init__( + self, + settings: PatternSettings, + sizer: PatternSizer, + api: PatternsAPI, + metadata_presenter: MetadataPresenter, + file_dialog_factory: FileDialogFactory, + ) -> None: + self._api = api + self._file_view_controller = OpenDatasetWizardFilesViewController( + settings, api, file_dialog_factory + ) + self._metadata_view_controller = OpenDatasetWizardMetadataViewController(metadata_presenter) + self._patterns_view_controller = OpenDatasetWizardPatternsViewController( + settings, sizer, file_dialog_factory + ) + + self._wizard = QWizard() + self._wizard.setWindowTitle('Open Dataset') + self._wizard.addPage(self._file_view_controller.get_widget()) + self._wizard.addPage(self._metadata_view_controller.get_widget()) + self._wizard.addPage(self._patterns_view_controller.get_widget()) + + self._wizard.button(QWizard.WizardButton.NextButton).clicked.connect( + self._execute_next_button_action + ) + self._wizard.button(QWizard.WizardButton.FinishButton).clicked.connect( + self._execute_finish_button_action + ) + + def _execute_next_button_action(self) -> None: + page = self._wizard.currentPage() + + if page is self._metadata_view_controller.get_widget(): + self._file_view_controller.open_dataset() + elif page is self._patterns_view_controller.get_widget(): + self._metadata_view_controller.import_metadata() + + def _execute_finish_button_action(self) -> None: + self._api.start_assembling_diffraction_patterns() + + def open_dataset(self) -> None: + self._api.finish_assembling_diffraction_patterns(block=False) + self._wizard.restart() + self._file_view_controller.restart() + self._wizard.show() diff --git a/src/ptychodus/controller/patterns/wizard/files.py b/src/ptychodus/controller/patterns/wizard/files.py new file mode 100644 index 00000000..ef432b7d --- /dev/null +++ b/src/ptychodus/controller/patterns/wizard/files.py @@ -0,0 +1,322 @@ +from collections.abc import Sequence +from pathlib import Path +import logging +import re + +from PyQt5.QtCore import Qt, QModelIndex, QSortFilterProxyModel +from PyQt5.QtWidgets import ( + QAbstractItemView, + QButtonGroup, + QComboBox, + QFileSystemModel, + QFormLayout, + QHBoxLayout, + QHeaderView, + QLineEdit, + QPushButton, + QTableView, + QWidget, + QWizardPage, +) + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import PathParameter + +from ....model.patterns import PatternsAPI, PatternSettings +from ....view.patterns import OpenDatasetWizardPage +from ....view.widgets import ExceptionDialog +from ...data import FileDialogFactory + +logger = logging.getLogger(__name__) + + +class OpenDatasetWizardBreadcrumbsViewController(Observer): + def __init__(self, file_dialog_factory: FileDialogFactory) -> None: + super().__init__() + self._file_dialog_factory = file_dialog_factory + self._widget = QWidget() + self._path_list: list[Path] = [] + self._button_group = QButtonGroup() + self._button_group.idClicked.connect(self._handle_id_clicked) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + self._widget.setLayout(layout) + + self._sync_model_to_view() + file_dialog_factory.add_observer(self) + + def _handle_id_clicked(self, button_id: int) -> None: + path = self._path_list[button_id] + self._file_dialog_factory.set_open_working_directory(path) + + def _sync_model_to_view(self) -> None: + path = self._file_dialog_factory.get_open_working_directory().resolve() + + for button_id, existing_path in enumerate(self._path_list): + if path == existing_path: + button = self._button_group.button(button_id) + button.setChecked(True) + return + + layout = self._widget.layout() + + if layout is not None: + while layout.count(): + item = layout.takeAt(0) + widget = item.widget() + + if widget is None: + continue + elif isinstance(widget, QPushButton): + self._button_group.removeButton(widget) + widget.deleteLater() + + self._path_list.clear() + button_list: list[QPushButton] = [] + + while True: + if path.name: + button = QPushButton(path.name) + button.setCheckable(True) + button_list.append(button) + self._path_list.append(path) + path = path.parent + else: + button = QPushButton(path.anchor) + button.setCheckable(True) + button_list.append(button) + self._path_list.append(Path(path.anchor)) + break + + for button_id, button in reversed(list(enumerate(button_list))): + self._button_group.addButton(button, button_id) + layout.addWidget(button) + + if isinstance(layout, QHBoxLayout): + layout.addStretch() + + button_list[0].setChecked(True) + self._widget.setLayout(layout) + self._widget.update() + + def get_widget(self) -> QWidget: + return self._widget + + def _update(self, observable: Observable) -> None: + if observable is self._file_dialog_factory: + self._sync_model_to_view() + + +class OpenDatasetWizardLocationViewController(Observer): + def __init__(self, file_path: PathParameter, file_dialog_factory: FileDialogFactory) -> None: + super().__init__() + self._file_path = file_path + self._file_dialog_factory = file_dialog_factory + + self._widget = QLineEdit() + self._widget.editingFinished.connect(self._handle_editing_finished) + + self._sync_model_to_view() + file_path.add_observer(self) + file_dialog_factory.add_observer(self) + + def _handle_editing_finished(self) -> None: + text = self._widget.text() + path = Path(text) + + if not path.is_absolute(): + path = self._file_dialog_factory.get_open_working_directory() / text + + path = path.resolve() + + self._file_dialog_factory.set_open_working_directory(path) + self._file_path.set_value(path) + + def _sync_model_to_view(self) -> None: + file_path = self._file_path.get_value() + + if file_path.is_file(): + self._widget.setText(file_path.name) + else: + self._widget.clear() + + def get_widget(self) -> QWidget: + return self._widget + + def _update(self, observable: Observable) -> None: + if observable in (self._file_path, self._file_dialog_factory): + self._sync_model_to_view() + + +class OpenDatasetWizardFilePathViewController(Observer): + def __init__(self, file_path: PathParameter, file_dialog_factory: FileDialogFactory) -> None: + super().__init__() + self._file_path = file_path + self._file_dialog_factory = file_dialog_factory + + self._model = QFileSystemModel() + self._model.setNameFilterDisables(False) + + self._proxy_model = QSortFilterProxyModel() + self._proxy_model.setSourceModel(self._model) + + self._widget = QTableView() + self._widget.setModel(self._proxy_model) + self._widget.setSortingEnabled(True) + self._widget.sortByColumn(0, Qt.SortOrder.AscendingOrder) + self._widget.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) + self._widget.verticalHeader().hide() + self._widget.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + self._widget.doubleClicked.connect(self._handle_table_double_clicked) + self._widget.selectionModel().currentChanged.connect(self._handle_current_changed) + + self._sync_model_to_view() + file_path.add_observer(self) + file_dialog_factory.add_observer(self) + + def set_name_filters(self, name_filters: Sequence[str]) -> None: + logger.debug(f'Name Filters: {name_filters}') + self._model.setNameFilters(name_filters) + + def _handle_table_double_clicked(self, proxy_index: QModelIndex) -> None: + index = self._proxy_model.mapToSource(proxy_index) + file_info = self._model.fileInfo(index) + + if file_info.isDir(): + directory = Path(file_info.canonicalFilePath()) + self._file_dialog_factory.set_open_working_directory(directory) + + def _handle_current_changed(self, current: QModelIndex, previous: QModelIndex) -> None: + index = self._proxy_model.mapToSource(current) + path = Path(self._model.filePath(index)) + self._file_path.set_value(path) + + def _sync_model_to_view(self) -> None: + file_path = self._file_path.get_value() + root_path = self._file_dialog_factory.get_open_working_directory() + + index = self._model.setRootPath(str(root_path)) + proxy_index = self._proxy_model.mapFromSource(index) + self._widget.setRootIndex(proxy_index) + + if file_path.is_relative_to(root_path): + index = self._model.index(str(file_path)) + proxy_index = self._proxy_model.mapFromSource(index) + self._widget.setCurrentIndex(proxy_index) + + def get_widget(self) -> QWidget: + return self._widget + + def _update(self, observable: Observable) -> None: + if observable is self._file_path: + self._sync_model_to_view() + elif observable is self._file_dialog_factory: + self._sync_model_to_view() + + +class OpenDatasetWizardFileTypeViewController(Observable, Observer): + def __init__(self, api: PatternsAPI) -> None: + super().__init__() + self._file_reader_chooser = api.get_file_reader_chooser() + self._file_reader_chooser.add_observer(self) + self._combo_box = QComboBox() + + for plugin in self._file_reader_chooser: + self._combo_box.addItem(plugin.display_name) + + self._sync_model_to_view() + self._combo_box.textActivated.connect(self._handle_text_activated) + + def get_name_filters(self) -> Sequence[str]: + text = self._combo_box.currentText() + z = re.search(r'\((.+)\)', text) + return z.group(1).split() if z else [] + + def _handle_text_activated(self, text: str) -> None: + self._file_reader_chooser.set_current_plugin(text) + + def _sync_model_to_view(self) -> None: + self._combo_box.setCurrentText(self._file_reader_chooser.get_current_plugin().display_name) + + def get_widget(self) -> QWidget: + return self._combo_box + + def _update(self, observable: Observable) -> None: + if observable is self._file_reader_chooser: + self._sync_model_to_view() + self.notify_observers() + + +class OpenDatasetWizardFilesViewController(Observer): + def __init__( + self, settings: PatternSettings, api: PatternsAPI, file_dialog_factory: FileDialogFactory + ) -> None: + super().__init__() + self._settings = settings + self._api = api + self._file_dialog_factory = file_dialog_factory + + self._breadcrumbs_view_controller = OpenDatasetWizardBreadcrumbsViewController( + file_dialog_factory + ) + self._location_view_controller = OpenDatasetWizardLocationViewController( + settings.file_path, file_dialog_factory + ) + self._file_path_view_controller = OpenDatasetWizardFilePathViewController( + settings.file_path, file_dialog_factory + ) + self._file_type_view_controller = OpenDatasetWizardFileTypeViewController(api) + self._file_type_view_controller.add_observer(self) + + layout = QFormLayout() + layout.addRow(self._breadcrumbs_view_controller.get_widget()) + layout.addRow('Location:', self._location_view_controller.get_widget()) + layout.addRow(self._file_path_view_controller.get_widget()) + layout.addRow('File Type:', self._file_type_view_controller.get_widget()) + + self._page = OpenDatasetWizardPage() + self._page.setTitle('Choose Dataset File(s)') + self._page.setLayout(layout) + + self._sync_model_to_view() + settings.file_path.add_observer(self) + + def open_dataset(self) -> None: + file_reader_chooser = self._api.get_file_reader_chooser() + file_type = file_reader_chooser.get_current_plugin().simple_name + file_path = self._settings.file_path.get_value() + + try: + self._api.open_patterns(file_path, file_type=file_type) + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception('Open Dataset', err) + + def get_widget(self) -> QWizardPage: + return self._page + + def _check_if_complete(self) -> None: + file_path = self._settings.file_path.get_value() + self._page._set_complete(file_path.is_file()) + + def restart(self) -> None: + self._check_if_complete() + + def _handle_file_type_changed(self) -> None: + name_filters = self._file_type_view_controller.get_name_filters() + self._file_path_view_controller.set_name_filters(name_filters) + + def _sync_model_to_view(self) -> None: + file_path = self._settings.file_path.get_value() + + if file_path.exists(): + self._file_dialog_factory.set_open_working_directory(file_path) + + self._handle_file_type_changed() + + def _update(self, observable: Observable) -> None: + if observable is self._settings.file_path: + self._check_if_complete() + elif observable is self._file_type_view_controller: + self._handle_file_type_changed() diff --git a/src/ptychodus/controller/patterns/wizard/metadata.py b/src/ptychodus/controller/patterns/wizard/metadata.py new file mode 100644 index 00000000..9cc74996 --- /dev/null +++ b/src/ptychodus/controller/patterns/wizard/metadata.py @@ -0,0 +1,81 @@ +from PyQt5.QtWidgets import QWizardPage + +from ptychodus.api.observer import Observable, Observer + +from ....model.metadata import MetadataPresenter +from ....view.patterns import OpenDatasetWizardMetadataPage + + +class OpenDatasetWizardMetadataViewController(Observer): + def __init__(self, presenter: MetadataPresenter) -> None: + super().__init__() + self._presenter = presenter + self._page = OpenDatasetWizardMetadataPage() + + presenter.add_observer(self) + self._sync_model_to_view() + self._page._set_complete(True) + + def import_metadata(self) -> None: + if self._page.detector_extent_check_box.isChecked(): + self._presenter.sync_detector_extent() + + if self._page.detector_pixel_size_check_box.isChecked(): + self._presenter.sync_detector_pixel_size() + + if self._page.detector_bit_depth_check_box.isChecked(): + self._presenter.sync_detector_bit_depth() + + if self._page.detector_distance_check_box.isChecked(): + self._presenter.sync_detector_distance() + + self._presenter.sync_pattern_crop( + sync_center=self._page.pattern_crop_center_check_box.isChecked(), + sync_extent=self._page.pattern_crop_extent_check_box.isChecked(), + ) + + if self._page.probe_photon_count_check_box.isChecked(): + self._presenter.sync_probe_photon_count() + + if self._page.probe_energy_check_box.isChecked(): + self._presenter.sync_probe_energy() + + def _sync_model_to_view(self) -> None: + can_sync_detector_extent = self._presenter.can_sync_detector_extent() + self._page.detector_extent_check_box.setVisible(can_sync_detector_extent) + self._page.detector_extent_check_box.setChecked(can_sync_detector_extent) + + can_sync_detector_pixel_size = self._presenter.can_sync_detector_pixel_size() + self._page.detector_pixel_size_check_box.setVisible(can_sync_detector_pixel_size) + self._page.detector_pixel_size_check_box.setChecked(can_sync_detector_pixel_size) + + can_sync_detector_bit_depth = self._presenter.can_sync_detector_bit_depth() + self._page.detector_bit_depth_check_box.setVisible(can_sync_detector_bit_depth) + self._page.detector_bit_depth_check_box.setChecked(can_sync_detector_bit_depth) + + can_sync_detector_distance = self._presenter.can_sync_detector_distance() + self._page.detector_distance_check_box.setVisible(can_sync_detector_distance) + self._page.detector_distance_check_box.setChecked(can_sync_detector_distance) + + can_sync_pattern_crop_center = self._presenter.can_sync_pattern_crop_center() + self._page.pattern_crop_center_check_box.setVisible(can_sync_pattern_crop_center) + self._page.pattern_crop_center_check_box.setChecked(can_sync_pattern_crop_center) + + can_sync_pattern_crop_extent = self._presenter.can_sync_pattern_crop_extent() + self._page.pattern_crop_extent_check_box.setVisible(can_sync_pattern_crop_extent) + self._page.pattern_crop_extent_check_box.setChecked(can_sync_pattern_crop_extent) + + can_sync_probe_photon_count = self._presenter.can_sync_probe_photon_count() + self._page.probe_photon_count_check_box.setVisible(can_sync_probe_photon_count) + self._page.probe_photon_count_check_box.setChecked(can_sync_probe_photon_count) + + can_sync_probe_energy = self._presenter.can_sync_probe_energy() + self._page.probe_energy_check_box.setVisible(can_sync_probe_energy) + self._page.probe_energy_check_box.setChecked(can_sync_probe_energy) + + def get_widget(self) -> QWizardPage: + return self._page + + def _update(self, observable: Observable) -> None: + if observable is self._presenter: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/patterns/wizard/patterns.py b/src/ptychodus/controller/patterns/wizard/patterns.py new file mode 100644 index 00000000..f9d7c0b6 --- /dev/null +++ b/src/ptychodus/controller/patterns/wizard/patterns.py @@ -0,0 +1,303 @@ +from typing import Final + +from PyQt5.QtCore import Qt +from PyQt5.QtWidgets import ( + QCheckBox, + QFormLayout, + QGridLayout, + QGroupBox, + QLabel, + QSpinBox, + QVBoxLayout, + QWidget, + QWizardPage, +) + +from ptychodus.api.observer import Observable + +from ....model.patterns import PatternSettings, PatternSizer +from ....view.patterns import OpenDatasetWizardPage + +from ...data import FileDialogFactory +from ...parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ParameterViewController, + PathParameterViewController, + SpinBoxParameterViewController, +) + + +class PatternLoadViewController(ParameterViewController): + def __init__(self, settings: PatternSettings) -> None: + super().__init__() + self._view_controller = SpinBoxParameterViewController( + settings.num_data_threads, + ) + self._widget = QGroupBox('Load') + + layout = QFormLayout() + layout.addRow('Number of Data Threads:', self._view_controller.get_widget()) + self._widget.setLayout(layout) + + def get_widget(self) -> QWidget: + return self._widget + + +class PatternMemoryMapViewController(CheckableGroupBoxParameterViewController): + def __init__(self, settings: PatternSettings, file_dialog_factory: FileDialogFactory) -> None: + super().__init__(settings.is_memmap_enabled, 'Memory Map Diffraction Data') + self._view_controller = PathParameterViewController.create_directory_chooser( + settings.scratch_directory, file_dialog_factory + ) + + layout = QFormLayout() + layout.addRow('Scratch Directory:', self._view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PatternCropViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PatternSettings, + sizer: PatternSizer, + ) -> None: + super().__init__(settings.is_crop_enabled, 'Crop') + self._settings = settings + self._sizer = sizer + + self._center_x_spin_box = QSpinBox() + self._center_y_spin_box = QSpinBox() + self._width_spin_box = QSpinBox() + self._height_spin_box = QSpinBox() + self._flip_x_check_box = QCheckBox('Flip X') + self._flip_y_check_box = QCheckBox('Flip Y') + + layout = QGridLayout() + layout.addWidget(QLabel('Center:'), 0, 0) + layout.addWidget(self._center_x_spin_box, 0, 1) + layout.addWidget(self._center_y_spin_box, 0, 2) + layout.addWidget(QLabel('Extent:'), 1, 0) + layout.addWidget(self._width_spin_box, 1, 1) + layout.addWidget(self._height_spin_box, 1, 2) + layout.addWidget(QLabel('Axes:'), 2, 0) + layout.addWidget(self._flip_x_check_box, 2, 1, Qt.AlignmentFlag.AlignHCenter) + layout.addWidget(self._flip_y_check_box, 2, 2, Qt.AlignmentFlag.AlignHCenter) + layout.setColumnStretch(1, 1) + layout.setColumnStretch(2, 1) + self.get_widget().setLayout(layout) + + self._sync_model_to_view() + + self._center_x_spin_box.valueChanged.connect(settings.crop_center_x_px.set_value) + self._center_y_spin_box.valueChanged.connect(settings.crop_center_y_px.set_value) + self._width_spin_box.valueChanged.connect(settings.crop_width_px.set_value) + self._height_spin_box.valueChanged.connect(settings.crop_height_px.set_value) + self._flip_x_check_box.toggled.connect(settings.is_flip_x_enabled.set_value) + self._flip_y_check_box.toggled.connect(settings.is_flip_y_enabled.set_value) + + sizer.add_observer(self) + + def _sync_model_to_view(self) -> None: + center_x = self._sizer.axis_x.get_crop_center() + center_y = self._sizer.axis_y.get_crop_center() + width = self._sizer.axis_x.get_crop_size() + height = self._sizer.axis_y.get_crop_size() + + center_x_limits = self._sizer.axis_x.get_crop_center_limits() + center_y_limits = self._sizer.axis_y.get_crop_center_limits() + width_limits = self._sizer.axis_x.get_crop_size_limits() + height_limits = self._sizer.axis_y.get_crop_size_limits() + + self._center_x_spin_box.blockSignals(True) + self._center_x_spin_box.setRange(center_x_limits.lower, center_x_limits.upper) + self._center_x_spin_box.setValue(center_x) + self._center_x_spin_box.blockSignals(False) + + self._center_y_spin_box.blockSignals(True) + self._center_y_spin_box.setRange(center_y_limits.lower, center_y_limits.upper) + self._center_y_spin_box.setValue(center_y) + self._center_y_spin_box.blockSignals(False) + + self._width_spin_box.blockSignals(True) + self._width_spin_box.setRange(width_limits.lower, width_limits.upper) + self._width_spin_box.setValue(width) + self._width_spin_box.blockSignals(False) + + self._height_spin_box.blockSignals(True) + self._height_spin_box.setRange(height_limits.lower, height_limits.upper) + self._height_spin_box.setValue(height) + self._height_spin_box.blockSignals(False) + + self._flip_x_check_box.setChecked(self._settings.is_flip_x_enabled.get_value()) + self._flip_y_check_box.setChecked(self._settings.is_flip_y_enabled.get_value()) + + def _update(self, observable: Observable) -> None: + if observable is self._sizer: + self._sync_model_to_view() + else: + super()._update(observable) + + +class PatternBinningViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PatternSettings, + sizer: PatternSizer, + ) -> None: + super().__init__(settings.is_binning_enabled, 'Bin Pixels') + self._settings = settings + self._sizer = sizer + + self._bin_size_x_spin_box = QSpinBox() + self._bin_size_y_spin_box = QSpinBox() + + layout = QGridLayout() + layout.addWidget(QLabel('Bin Size:'), 0, 0) + layout.addWidget(self._bin_size_x_spin_box, 0, 1) + layout.addWidget(self._bin_size_y_spin_box, 0, 2) + layout.setColumnStretch(1, 1) + layout.setColumnStretch(2, 1) + self.get_widget().setLayout(layout) + + self._sync_model_to_view() + + self._bin_size_x_spin_box.valueChanged.connect(settings.bin_size_x.set_value) + self._bin_size_y_spin_box.valueChanged.connect(settings.bin_size_y.set_value) + + sizer.add_observer(self) + + def _sync_model_to_view(self) -> None: + bin_size_x = self._sizer.axis_x.get_bin_size() + bin_size_y = self._sizer.axis_y.get_bin_size() + + bin_size_x_limits = self._sizer.axis_x.get_bin_size_limits() + bin_size_y_limits = self._sizer.axis_y.get_bin_size_limits() + + self._bin_size_x_spin_box.blockSignals(True) + self._bin_size_x_spin_box.setRange(bin_size_x_limits.lower, bin_size_x_limits.upper) + self._bin_size_x_spin_box.setValue(bin_size_x) + self._bin_size_x_spin_box.blockSignals(False) + + self._bin_size_y_spin_box.blockSignals(True) + self._bin_size_y_spin_box.setRange(bin_size_y_limits.lower, bin_size_y_limits.upper) + self._bin_size_y_spin_box.setValue(bin_size_y) + self._bin_size_y_spin_box.blockSignals(False) + + def _update(self, observable: Observable) -> None: + if observable is self._sizer: + self._sync_model_to_view() + else: + super()._update(observable) + + +class PatternPaddingViewController(CheckableGroupBoxParameterViewController): + MAX_INT: Final[int] = 0x7FFFFFFF + + def __init__( + self, + settings: PatternSettings, + sizer: PatternSizer, + ) -> None: + super().__init__(settings.is_padding_enabled, 'Pad') + self._settings = settings + self._sizer = sizer + + self._pad_x_spin_box = QSpinBox() + self._pad_y_spin_box = QSpinBox() + + layout = QGridLayout() + layout.addWidget(QLabel('Padding:'), 0, 0) + layout.addWidget(self._pad_x_spin_box, 0, 1) + layout.addWidget(self._pad_y_spin_box, 0, 2) + layout.setColumnStretch(1, 1) + layout.setColumnStretch(2, 1) + self.get_widget().setLayout(layout) + + self._sync_model_to_view() + + self._pad_x_spin_box.valueChanged.connect(settings.pad_x.set_value) + self._pad_y_spin_box.valueChanged.connect(settings.pad_y.set_value) + + sizer.add_observer(self) + + def _sync_model_to_view(self) -> None: + pad_x = self._sizer.axis_x.get_pad_size() + pad_y = self._sizer.axis_y.get_pad_size() + + self._pad_x_spin_box.blockSignals(True) + self._pad_x_spin_box.setRange(0, self.MAX_INT) + self._pad_x_spin_box.setValue(pad_x) + self._pad_x_spin_box.blockSignals(False) + + self._pad_y_spin_box.blockSignals(True) + self._pad_y_spin_box.setRange(0, self.MAX_INT) + self._pad_y_spin_box.setValue(pad_y) + self._pad_y_spin_box.blockSignals(False) + + def _update(self, observable: Observable) -> None: + if observable is self._sizer: + self._sync_model_to_view() + else: + super()._update(observable) + + +class PatternTransformViewController: + def __init__(self, settings: PatternSettings) -> None: + self._lower_bound_enabled_view_controller = CheckBoxParameterViewController( + settings.is_value_lower_bound_enabled, 'Value Lower Bound:' + ) + self._lower_bound_view_controller = SpinBoxParameterViewController( + settings.value_lower_bound + ) + self._upper_bound_enabled_view_controller = CheckBoxParameterViewController( + settings.is_value_upper_bound_enabled, 'Value upper Bound:' + ) + self._upper_bound_view_controller = SpinBoxParameterViewController( + settings.value_upper_bound + ) + + layout = QGridLayout() + layout.addWidget(self._lower_bound_enabled_view_controller.get_widget(), 0, 0) + layout.addWidget(self._lower_bound_view_controller.get_widget(), 0, 1, 1, 2) + layout.addWidget(self._upper_bound_enabled_view_controller.get_widget(), 1, 0) + layout.addWidget(self._upper_bound_view_controller.get_widget(), 1, 1, 1, 2) + layout.setColumnStretch(2, 1) + layout.setColumnStretch(3, 1) + + self._widget = QGroupBox('Transform') + self._widget.setLayout(layout) + + def get_widget(self) -> QWidget: + return self._widget + + +class OpenDatasetWizardPatternsViewController(ParameterViewController): + def __init__( + self, settings: PatternSettings, sizer: PatternSizer, file_dialog_factory: FileDialogFactory + ) -> None: + self._load_view_controller = PatternLoadViewController(settings) + self._memory_map_view_controller = PatternMemoryMapViewController( + settings, file_dialog_factory + ) + self._crop_view_controller = PatternCropViewController(settings, sizer) + self._binning_view_controller = PatternBinningViewController(settings, sizer) + self._padding_view_controller = PatternPaddingViewController(settings, sizer) + self._transform_view_controller = PatternTransformViewController(settings) + + layout = QVBoxLayout() + layout.addWidget(self._load_view_controller.get_widget()) + layout.addWidget(self._memory_map_view_controller.get_widget()) + layout.addWidget(self._crop_view_controller.get_widget()) + layout.addWidget(self._binning_view_controller.get_widget()) + layout.addWidget(self._padding_view_controller.get_widget()) + layout.addWidget(self._transform_view_controller.get_widget()) + layout.addStretch() + + self._page = OpenDatasetWizardPage() + self._page.setTitle('Pattern Processing') + self._page._set_complete(True) + self._page.setLayout(layout) + + def get_widget(self) -> QWizardPage: + return self._page diff --git a/src/ptychodus/controller/probe/core.py b/src/ptychodus/controller/probe/core.py index 509b7ef9..077c847b 100644 --- a/src/ptychodus/controller/probe/core.py +++ b/src/ptychodus/controller/probe/core.py @@ -7,7 +7,7 @@ from ptychodus.api.observer import SequenceObserver from ...model.analysis import ( - ExposureAnalyzer, + IlluminationMapper, ProbePropagator, STXMSimulator, ) @@ -23,12 +23,12 @@ ) from ..data import FileDialogFactory from ..image import ImageController -from .editorFactory import ProbeEditorViewControllerFactory -from .exposure import ExposureViewController +from .editor_factory import ProbeEditorViewControllerFactory +from .illumination import IlluminationViewController from .fluorescence import FluorescenceViewController from .propagator import ProbePropagationViewController from .stxm import STXMViewController -from .treeModel import ProbeTreeModel +from .tree_model import ProbeTreeModel logger = logging.getLogger(__name__) @@ -38,277 +38,245 @@ def __init__( self, repository: ProbeRepository, api: ProbeAPI, - imageController: ImageController, + image_controller: ImageController, propagator: ProbePropagator, - propagatorVisualizationEngine: VisualizationEngine, - stxmSimulator: STXMSimulator, - stxmVisualizationEngine: VisualizationEngine, - exposureAnalyzer: ExposureAnalyzer, - exposureVisualizationEngine: VisualizationEngine, - fluorescenceEnhancer: FluorescenceEnhancer, - fluorescenceVisualizationEngine: VisualizationEngine, + propagator_visualization_engine: VisualizationEngine, + stxm_simulator: STXMSimulator, + stxm_visualization_engine: VisualizationEngine, + illumination_mapper: IlluminationMapper, + illumination_visualization_engine: VisualizationEngine, + fluorescence_enhancer: FluorescenceEnhancer, + fluorescence_visualization_engine: VisualizationEngine, view: RepositoryTreeView, - fileDialogFactory: FileDialogFactory, - treeModel: ProbeTreeModel, + file_dialog_factory: FileDialogFactory, + *, + is_developer_mode_enabled: bool, ) -> None: super().__init__() self._repository = repository self._api = api - self._imageController = imageController + self._image_controller = image_controller self._view = view - self._fileDialogFactory = fileDialogFactory - self._treeModel = treeModel - self._editorFactory = ProbeEditorViewControllerFactory() + self._file_dialog_factory = file_dialog_factory + self._tree_model = ProbeTreeModel(repository, api) + self._editor_factory = ProbeEditorViewControllerFactory() - self._propagationViewController = ProbePropagationViewController( - propagator, propagatorVisualizationEngine, fileDialogFactory + self._propagation_view_controller = ProbePropagationViewController( + propagator, propagator_visualization_engine, file_dialog_factory ) - self._stxmViewController = STXMViewController( - stxmSimulator, stxmVisualizationEngine, fileDialogFactory + self._stxm_view_controller = STXMViewController( + stxm_simulator, stxm_visualization_engine, file_dialog_factory ) - self._exposureViewController = ExposureViewController( - exposureAnalyzer, exposureVisualizationEngine, fileDialogFactory + self._illumination_view_controller = IlluminationViewController( + illumination_mapper, + illumination_visualization_engine, + file_dialog_factory, + is_developer_mode_enabled=is_developer_mode_enabled, ) - self._fluorescenceViewController = FluorescenceViewController( - fluorescenceEnhancer, fluorescenceVisualizationEngine, fileDialogFactory + self._fluorescence_view_controller = FluorescenceViewController( + fluorescence_enhancer, fluorescence_visualization_engine, file_dialog_factory ) - @classmethod - def createInstance( - cls, - repository: ProbeRepository, - api: ProbeAPI, - imageController: ImageController, - propagator: ProbePropagator, - propagatorVisualizationEngine: VisualizationEngine, - stxmSimulator: STXMSimulator, - stxmVisualizationEngine: VisualizationEngine, - exposureAnalyzer: ExposureAnalyzer, - exposureVisualizationEngine: VisualizationEngine, - fluorescenceEnhancer: FluorescenceEnhancer, - fluorescenceVisualizationEngine: VisualizationEngine, - view: RepositoryTreeView, - fileDialogFactory: FileDialogFactory, - ) -> ProbeController: # TODO figure out good fix when saving NPY file without suffix (numpy adds suffix) - treeModel = ProbeTreeModel(repository, api) - controller = cls( - repository, - api, - imageController, - propagator, - propagatorVisualizationEngine, - stxmSimulator, - stxmVisualizationEngine, - exposureAnalyzer, - exposureVisualizationEngine, - fluorescenceEnhancer, - fluorescenceVisualizationEngine, - view, - fileDialogFactory, - treeModel, - ) - repository.addObserver(controller) - - builderListModel = QStringListModel() - builderListModel.setStringList([name for name in api.builderNames()]) - builderItemDelegate = ComboBoxItemDelegate(builderListModel, view.treeView) + repository.add_observer(self) - view.treeView.setModel(treeModel) - view.treeView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - powerItemDelegate = ProgressBarItemDelegate(view.treeView) - view.treeView.setItemDelegateForColumn(1, powerItemDelegate) - view.treeView.setItemDelegateForColumn(2, builderItemDelegate) - view.treeView.selectionModel().currentChanged.connect(controller._updateView) - controller._updateView(QModelIndex(), QModelIndex()) + builder_list_model = QStringListModel() + builder_list_model.setStringList([name for name in api.builder_names()]) + builder_item_delegate = ComboBoxItemDelegate(builder_list_model, view.tree_view) - loadFromFileAction = view.buttonBox.loadMenu.addAction('Open File...') - loadFromFileAction.triggered.connect(controller._loadCurrentProbeFromFile) + view.tree_view.setModel(self._tree_model) + view.tree_view.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + power_item_delegate = ProgressBarItemDelegate(view.tree_view) + view.tree_view.setItemDelegateForColumn(1, power_item_delegate) + view.tree_view.setItemDelegateForColumn(2, builder_item_delegate) + view.tree_view.selectionModel().currentChanged.connect(self._update_view) + self._update_view(QModelIndex(), QModelIndex()) - copyAction = view.buttonBox.loadMenu.addAction('Copy...') - copyAction.triggered.connect(controller._copyToCurrentProbe) + load_from_file_action = view.button_box.load_menu.addAction('Open File...') + load_from_file_action.triggered.connect(self._load_current_probe_from_file) - saveToFileAction = view.buttonBox.saveMenu.addAction('Save File...') - saveToFileAction.triggered.connect(controller._saveCurrentProbeToFile) + copy_action = view.button_box.load_menu.addAction('Copy...') + copy_action.triggered.connect(self._copy_to_current_probe) - syncToSettingsAction = view.buttonBox.saveMenu.addAction('Sync To Settings') - syncToSettingsAction.triggered.connect(controller._syncCurrentProbeToSettings) + save_to_file_action = view.button_box.save_menu.addAction('Save File...') + save_to_file_action.triggered.connect(self._save_current_probe_to_file) - view.copierDialog.setWindowTitle('Copy Probe') - view.copierDialog.sourceComboBox.setModel(treeModel) - view.copierDialog.destinationComboBox.setModel(treeModel) - view.copierDialog.finished.connect(controller._finishCopyingProbe) + sync_to_settings_action = view.button_box.save_menu.addAction('Sync To Settings') + sync_to_settings_action.triggered.connect(self._sync_current_probe_to_settings) - view.buttonBox.editButton.clicked.connect(controller._editCurrentProbe) + view.copier_dialog.setWindowTitle('Copy Probe') + view.copier_dialog.source_combo_box.setModel(self._tree_model) + view.copier_dialog.destination_combo_box.setModel(self._tree_model) + view.copier_dialog.finished.connect(self._finish_copying_probe) - propagateAction = view.buttonBox.analyzeMenu.addAction('Propagate...') - propagateAction.triggered.connect(controller._propagateProbe) + view.button_box.edit_button.clicked.connect(self._edit_current_probe) - stxmAction = view.buttonBox.analyzeMenu.addAction('Simulate STXM...') - stxmAction.triggered.connect(controller._simulateSTXM) + propagate_action = view.button_box.analyze_menu.addAction('Propagate...') + propagate_action.triggered.connect(self._propagate_probe) - exposureAction = view.buttonBox.analyzeMenu.addAction('Exposure...') - exposureAction.triggered.connect(controller._analyzeExposure) + stxm_action = view.button_box.analyze_menu.addAction('Simulate STXM...') + stxm_action.triggered.connect(self._simulate_stxm) - fluorescenceAction = view.buttonBox.analyzeMenu.addAction('Enhance Fluorescence...') - fluorescenceAction.triggered.connect(controller._enhanceFluorescence) + illumination_action = view.button_box.analyze_menu.addAction('Map Illumination...') + illumination_action.triggered.connect(self._map_illumination) - return controller + fluorescence_action = view.button_box.analyze_menu.addAction('Enhance Fluorescence...') + fluorescence_action.triggered.connect(self._enhance_fluorescence) - def _getCurrentItemIndex(self) -> int: - modelIndex = self._view.treeView.currentIndex() + def _get_current_item_index(self) -> int: + model_index = self._view.tree_view.currentIndex() - if modelIndex.isValid(): - parent = modelIndex.parent() + if model_index.isValid(): + parent = model_index.parent() while parent.isValid(): - modelIndex = parent - parent = modelIndex.parent() + model_index = parent + parent = model_index.parent() - return modelIndex.row() + return model_index.row() logger.warning('No current index!') return -1 - def _loadCurrentProbeFromFile(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _load_current_probe_from_file(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( + file_path, name_filter = self._file_dialog_factory.get_open_file_path( self._view, 'Open Probe', - nameFilters=self._api.getOpenFileFilterList(), - selectedNameFilter=self._api.getOpenFileFilter(), + name_filters=[nf for nf in self._api.get_open_file_filters()], + selected_name_filter=self._api.get_open_file_filter(), ) - if filePath: + if file_path: try: - self._api.openProbe(itemIndex, filePath, fileType=nameFilter) + self._api.open_probe(item_index, file_path, file_type=name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Reader', err) + ExceptionDialog.show_exception('File Reader', err) - def _copyToCurrentProbe(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _copy_to_current_probe(self) -> None: + item_index = self._get_current_item_index() - if itemIndex >= 0: - self._view.copierDialog.destinationComboBox.setCurrentIndex(itemIndex) - self._view.copierDialog.open() + if item_index >= 0: + self._view.copier_dialog.destination_combo_box.setCurrentIndex(item_index) + self._view.copier_dialog.open() - def _finishCopyingProbe(self, result: int) -> None: + def _finish_copying_probe(self, result: int) -> None: if result == QDialog.DialogCode.Accepted: - sourceIndex = self._view.copierDialog.sourceComboBox.currentIndex() - destinationIndex = self._view.copierDialog.destinationComboBox.currentIndex() - self._api.copyProbe(sourceIndex, destinationIndex) + source_index = self._view.copier_dialog.source_combo_box.currentIndex() + destination_index = self._view.copier_dialog.destination_combo_box.currentIndex() + self._api.copy_probe(source_index, destination_index) - def _editCurrentProbe(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _edit_current_probe(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - itemName = self._repository.getName(itemIndex) - item = self._repository[itemIndex] - dialog = self._editorFactory.createEditorDialog(itemName, item, self._view) + item_name = self._repository.get_name(item_index) + item = self._repository[item_index] + dialog = self._editor_factory.create_editor_dialog(item_name, item, self._view) dialog.open() - def _saveCurrentProbeToFile(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _save_current_probe_to_file(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + file_path, name_filter = self._file_dialog_factory.get_save_file_path( self._view, 'Save Probe', - nameFilters=self._api.getSaveFileFilterList(), - selectedNameFilter=self._api.getSaveFileFilter(), + name_filters=[nf for nf in self._api.get_save_file_filters()], + selected_name_filter=self._api.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._api.saveProbe(itemIndex, filePath, nameFilter) + self._api.save_probe(item_index, file_path, name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Writer', err) + ExceptionDialog.show_exception('File Writer', err) - def _syncCurrentProbeToSettings(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _sync_current_probe_to_settings(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - item = self._repository[itemIndex] - item.syncToSettings() + item = self._repository[item_index] + item.sync_to_settings() - def _propagateProbe(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _propagate_probe(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - self._propagationViewController.launch(itemIndex) + self._propagation_view_controller.launch(item_index) - def _simulateSTXM(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _simulate_stxm(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - self._stxmViewController.launch(itemIndex) + self._stxm_view_controller.simulate(item_index) - def _analyzeExposure(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _map_illumination(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - self._exposureViewController.analyze(itemIndex) + self._illumination_view_controller.map(item_index) - def _enhanceFluorescence(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _enhance_fluorescence(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - self._fluorescenceViewController.launch(itemIndex) + self._fluorescence_view_controller.launch(item_index) - def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: + def _update_view(self, current: QModelIndex, previous: QModelIndex) -> None: enabled = current.isValid() - self._view.buttonBox.loadButton.setEnabled(enabled) - self._view.buttonBox.saveButton.setEnabled(enabled) - self._view.buttonBox.editButton.setEnabled(enabled) - self._view.buttonBox.analyzeButton.setEnabled(enabled) + self._view.button_box.load_button.setEnabled(enabled) + self._view.button_box.save_button.setEnabled(enabled) + self._view.button_box.edit_button.setEnabled(enabled) + self._view.button_box.analyze_button.setEnabled(enabled) - itemIndex = self._getCurrentItemIndex() + item_index = self._get_current_item_index() - if itemIndex < 0: - self._imageController.clearArray() + if item_index < 0: + self._image_controller.clear_array() else: try: - item = self._repository[itemIndex] + item = self._repository[item_index] except IndexError: logger.warning('Unable to access item for visualization!') else: - probe = item.getProbe() + probe = item.get_probes().get_probe_no_opr() # TODO OPR array = ( - probe.getMode(current.row()) + probe.get_incoherent_mode(current.row()) if current.parent().isValid() - else probe.getModesFlattened() + else probe.get_incoherent_modes_flattened() ) - self._imageController.setArray(array, probe.getPixelGeometry()) + self._image_controller.set_array(array, probe.get_pixel_geometry()) - def handleItemInserted(self, index: int, item: ProbeRepositoryItem) -> None: - self._treeModel.insertItem(index, item) + def handle_item_inserted(self, index: int, item: ProbeRepositoryItem) -> None: + self._tree_model.insert_item(index, item) - def handleItemChanged(self, index: int, item: ProbeRepositoryItem) -> None: - self._treeModel.updateItem(index, item) + def handle_item_changed(self, index: int, item: ProbeRepositoryItem) -> None: + self._tree_model.update_item(index, item) - if index == self._getCurrentItemIndex(): - currentIndex = self._view.treeView.currentIndex() - self._updateView(currentIndex, currentIndex) + if index == self._get_current_item_index(): + current_index = self._view.tree_view.currentIndex() + self._update_view(current_index, current_index) - def handleItemRemoved(self, index: int, item: ProbeRepositoryItem) -> None: - self._treeModel.removeItem(index, item) + def handle_item_removed(self, index: int, item: ProbeRepositoryItem) -> None: + self._tree_model.remove_item(index, item) diff --git a/src/ptychodus/controller/probe/editorFactory.py b/src/ptychodus/controller/probe/editorFactory.py deleted file mode 100644 index 453e3256..00000000 --- a/src/ptychodus/controller/probe/editorFactory.py +++ /dev/null @@ -1,284 +0,0 @@ -from PyQt5.QtWidgets import ( - QButtonGroup, - QDialog, - QFormLayout, - QGroupBox, - QHBoxLayout, - QMessageBox, - QRadioButton, - QSpinBox, - QTableView, - QWidget, -) - -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.parametric import StringParameter - -from ...model.product.probe import ( - AveragePatternProbeBuilder, - DiskProbeBuilder, - FresnelZonePlateProbeBuilder, - MultimodalProbeBuilder, - ProbeModeDecayType, - ProbeRepositoryItem, - RectangularProbeBuilder, - SuperGaussianProbeBuilder, - ZernikeProbeBuilder, -) -from ...view.widgets import GroupBoxWithPresets -from ..parametric import ( - LengthWidgetParameterViewController, - ParameterViewBuilder, - ParameterViewController, -) -from .zernike import ZernikeTableModel - -__all__ = [ - 'ProbeEditorViewControllerFactory', -] - - -class FresnelZonePlateViewController(ParameterViewController): - def __init__(self, title: str, probeBuilder: FresnelZonePlateProbeBuilder) -> None: - super().__init__() - self._widget = GroupBoxWithPresets(title) - - for index, presetsLabel in enumerate(probeBuilder.labelsForPresets()): - action = self._widget.presetsMenu.addAction(presetsLabel) - action.triggered.connect(lambda _, index=index: probeBuilder.applyPresets(index)) - - self._zonePlateDiameterViewController = LengthWidgetParameterViewController( - probeBuilder.zonePlateDiameterInMeters - ) - self._outermostZoneWidthInMetersViewController = LengthWidgetParameterViewController( - probeBuilder.outermostZoneWidthInMeters - ) - self._centralBeamstopDiameterInMetersViewController = LengthWidgetParameterViewController( - probeBuilder.centralBeamstopDiameterInMeters - ) - self._defocusDistanceInMetersViewController = LengthWidgetParameterViewController( - probeBuilder.defocusDistanceInMeters - ) - - layout = QFormLayout() - layout.setContentsMargins(0, 0, 0, 0) - layout.addRow('Zone Plate Diameter:', self._zonePlateDiameterViewController.getWidget()) - layout.addRow( - 'Outermost Zone Width:', - self._outermostZoneWidthInMetersViewController.getWidget(), - ) - layout.addRow( - 'Central Beamstop Diameter:', - self._centralBeamstopDiameterInMetersViewController.getWidget(), - ) - layout.addRow('Defocus Distance:', self._defocusDistanceInMetersViewController.getWidget()) - self._widget.contents.setLayout(layout) - - def getWidget(self) -> QWidget: - return self._widget - - -class ZernikeViewController(ParameterViewController, Observer): - def __init__(self, title: str, probeBuilder: ZernikeProbeBuilder) -> None: - super().__init__() - self._widget = QGroupBox(title) - self._probeBuilder = probeBuilder - self._orderSpinBox = QSpinBox() - self._coefficientsTableModel = ZernikeTableModel(probeBuilder) - self._coefficientsTableView = QTableView() - self._diameterViewController = LengthWidgetParameterViewController( - probeBuilder.diameterInMeters - ) - - self._coefficientsTableView.setModel(self._coefficientsTableModel) - - layout = QFormLayout() - layout.addRow('Diameter:', self._diameterViewController.getWidget()) - layout.addRow('Order:', self._orderSpinBox) - layout.addRow(self._coefficientsTableView) - self._widget.setLayout(layout) - - self._syncModelToView() - self._orderSpinBox.valueChanged.connect(probeBuilder.setOrder) - probeBuilder.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._orderSpinBox.setRange(1, 100) - self._orderSpinBox.setValue(self._probeBuilder.getOrder()) - - self._coefficientsTableModel.beginResetModel() # TODO clean up - self._coefficientsTableModel.endResetModel() - - def update(self, observable: Observable) -> None: - if observable is self._probeBuilder: - self._syncModelToView() - - -class DecayTypeParameterViewController(ParameterViewController, Observer): - def __init__(self, parameter: StringParameter) -> None: - super().__init__() - self._parameter = parameter - self._polynomialDecayButton = QRadioButton('Polynomial') - self._exponentialDecayButton = QRadioButton('Exponential') - - self._buttonGroup = QButtonGroup() - self._buttonGroup.addButton( - self._polynomialDecayButton, ProbeModeDecayType.POLYNOMIAL.value - ) - self._buttonGroup.addButton( - self._exponentialDecayButton, ProbeModeDecayType.EXPONENTIAL.value - ) - self._buttonGroup.setExclusive(True) - self._buttonGroup.idToggled.connect(self._syncViewToModel) - - layout = QHBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._polynomialDecayButton) - layout.addWidget(self._exponentialDecayButton) - - self._widget = QWidget() - self._widget.setLayout(layout) - - self._syncModelToView() - parameter.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncViewToModel(self, toolId: int, checked: bool) -> None: - if checked: - decayType = ProbeModeDecayType(toolId) - self._parameter.setValue(decayType.name) - - def _syncModelToView(self) -> None: - try: - decayType = ProbeModeDecayType[self._parameter.getValue().upper()] - except KeyError: - decayType = ProbeModeDecayType.POLYNOMIAL - - button = self._buttonGroup.button(decayType.value) - button.setChecked(True) - - def update(self, observable: Observable) -> None: - if observable is self._parameter: - self._syncModelToView() - - -class ProbeEditorViewControllerFactory: - def _appendAdditionalModes( - self, - dialogBuilder: ParameterViewBuilder, - modesBuilder: MultimodalProbeBuilder, - ) -> None: - additionalModesGroup = 'Additional Modes' - dialogBuilder.addSpinBox( - modesBuilder.numberOfModes, - 'Number of Modes:', - group=additionalModesGroup, - ) - dialogBuilder.addCheckBox( - modesBuilder.isOrthogonalizeModesEnabled, - 'Orthogonalize Modes:', - group=additionalModesGroup, - ) - dialogBuilder.addViewController( - DecayTypeParameterViewController(modesBuilder.modeDecayType), - 'Decay Type:', - group=additionalModesGroup, - ) - dialogBuilder.addDecimalSlider( - modesBuilder.modeDecayRatio, - 'Decay Ratio:', - group=additionalModesGroup, - ) - - def createEditorDialog( - self, itemName: str, item: ProbeRepositoryItem, parent: QWidget - ) -> QDialog: - probeBuilder = item.getBuilder() - builderName = probeBuilder.getName() - modesBuilder = item.getAdditionalModesBuilder() - primaryModeGroup = 'Primary Mode' - title = f'{itemName} [{builderName}]' - - if isinstance(probeBuilder, AveragePatternProbeBuilder): - dialogBuilder = ParameterViewBuilder() - self._appendAdditionalModes(dialogBuilder, modesBuilder) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(probeBuilder, DiskProbeBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addLengthWidget( - probeBuilder.diameterInMeters, - 'Diameter:', - group=primaryModeGroup, - ) - dialogBuilder.addLengthWidget( - probeBuilder.defocusDistanceInMeters, - 'Defocus Distance:', - group=primaryModeGroup, - ) - self._appendAdditionalModes(dialogBuilder, modesBuilder) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(probeBuilder, FresnelZonePlateProbeBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addViewControllerToTop( - FresnelZonePlateViewController(primaryModeGroup, probeBuilder) - ) - self._appendAdditionalModes(dialogBuilder, modesBuilder) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(probeBuilder, RectangularProbeBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addLengthWidget( - probeBuilder.widthInMeters, - 'Width:', - group=primaryModeGroup, - ) - dialogBuilder.addLengthWidget( - probeBuilder.heightInMeters, - 'Height:', - group=primaryModeGroup, - ) - dialogBuilder.addLengthWidget( - probeBuilder.defocusDistanceInMeters, - 'Defocus Distance:', - group=primaryModeGroup, - ) - self._appendAdditionalModes(dialogBuilder, modesBuilder) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(probeBuilder, SuperGaussianProbeBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addLengthWidget( - probeBuilder.annularRadiusInMeters, - 'Annular Radius:', - group=primaryModeGroup, - ) - dialogBuilder.addLengthWidget( - probeBuilder.fwhmInMeters, - 'Full Width at Half Maximum:', - group=primaryModeGroup, - ) - dialogBuilder.addDecimalLineEdit( - probeBuilder.orderParameter, - 'Order Parameter:', - group=primaryModeGroup, - ) - self._appendAdditionalModes(dialogBuilder, modesBuilder) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(probeBuilder, ZernikeProbeBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addViewControllerToTop( - ZernikeViewController(primaryModeGroup, probeBuilder) - ) - self._appendAdditionalModes(dialogBuilder, modesBuilder) - return dialogBuilder.buildDialog(title, parent) - - return QMessageBox( - QMessageBox.Icon.Information, - title, - f'"{builderName}" has no editable parameters!', - QMessageBox.Ok, - parent, - ) diff --git a/src/ptychodus/controller/probe/editor_factory.py b/src/ptychodus/controller/probe/editor_factory.py new file mode 100644 index 00000000..cb2bf341 --- /dev/null +++ b/src/ptychodus/controller/probe/editor_factory.py @@ -0,0 +1,296 @@ +from PyQt5.QtWidgets import ( + QButtonGroup, + QDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QMessageBox, + QRadioButton, + QSpinBox, + QTableView, + QWidget, +) + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import StringParameter + +from ...model.product.probe import ( + AveragePatternProbeBuilder, + DiskProbeBuilder, + FresnelZonePlateProbeBuilder, + MultimodalProbeBuilder, + ProbeModeDecayType, + ProbeRepositoryItem, + RectangularProbeBuilder, + SuperGaussianProbeBuilder, + ZernikeProbeBuilder, +) +from ...view.widgets import GroupBoxWithPresets +from ..parametric import ( + LengthWidgetParameterViewController, + ParameterViewBuilder, + ParameterViewController, +) +from .zernike import ZernikeTableModel + +__all__ = [ + 'ProbeEditorViewControllerFactory', +] + + +class FresnelZonePlateViewController(ParameterViewController): + def __init__(self, title: str, probe_builder: FresnelZonePlateProbeBuilder) -> None: + super().__init__() + self._widget = GroupBoxWithPresets(title) + + for label in probe_builder.labels_for_presets(): + action = self._widget.presets_menu.addAction(label) + action.triggered.connect(lambda _, label=label: probe_builder.apply_presets(label)) + + self._zone_plate_diameter_view_controller = LengthWidgetParameterViewController( + probe_builder.zone_plate_diameter_m + ) + self._outermost_zone_width_view_controller = LengthWidgetParameterViewController( + probe_builder.outermost_zone_width_m + ) + self._central_beamstop_diameter_view_controller = LengthWidgetParameterViewController( + probe_builder.central_beamstop_diameter_m + ) + self._defocus_distance_view_controller = LengthWidgetParameterViewController( + probe_builder.defocus_distance_m + ) + + layout = QFormLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addRow( + 'Zone Plate Diameter:', self._zone_plate_diameter_view_controller.get_widget() + ) + layout.addRow( + 'Outermost Zone Width:', + self._outermost_zone_width_view_controller.get_widget(), + ) + layout.addRow( + 'Central Beamstop Diameter:', + self._central_beamstop_diameter_view_controller.get_widget(), + ) + layout.addRow('Defocus Distance:', self._defocus_distance_view_controller.get_widget()) + self._widget.contents.setLayout(layout) + + def get_widget(self) -> QWidget: + return self._widget + + +class ZernikeViewController(ParameterViewController, Observer): + def __init__(self, title: str, probe_builder: ZernikeProbeBuilder) -> None: + super().__init__() + self._widget = QGroupBox(title) + self._probe_builder = probe_builder + self._order_spin_box = QSpinBox() + self._coefficients_table_model = ZernikeTableModel(probe_builder) + self._coefficients_table_view = QTableView() + self._diameter_view_controller = LengthWidgetParameterViewController( + probe_builder.diameter_m + ) + + self._coefficients_table_view.setModel(self._coefficients_table_model) + + layout = QFormLayout() + layout.addRow('Diameter:', self._diameter_view_controller.get_widget()) + layout.addRow('Order:', self._order_spin_box) + layout.addRow(self._coefficients_table_view) + self._widget.setLayout(layout) + + self._sync_model_to_view() + self._order_spin_box.valueChanged.connect(probe_builder.set_order) + probe_builder.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_model_to_view(self) -> None: + self._order_spin_box.setRange(1, 100) + self._order_spin_box.setValue(self._probe_builder.get_order()) + + self._coefficients_table_model.beginResetModel() # TODO clean up + self._coefficients_table_model.endResetModel() + + def _update(self, observable: Observable) -> None: + if observable is self._probe_builder: + self._sync_model_to_view() + + +class DecayTypeParameterViewController(ParameterViewController, Observer): + def __init__(self, parameter: StringParameter) -> None: + super().__init__() + self._parameter = parameter + self._polynomial_decay_button = QRadioButton('Polynomial') + self._exponential_decay_button = QRadioButton('Exponential') + + self._button_group = QButtonGroup() + self._button_group.addButton( + self._polynomial_decay_button, ProbeModeDecayType.POLYNOMIAL.value + ) + self._button_group.addButton( + self._exponential_decay_button, ProbeModeDecayType.EXPONENTIAL.value + ) + self._button_group.setExclusive(True) + self._button_group.idToggled.connect(self._sync_view_to_model) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._polynomial_decay_button) + layout.addWidget(self._exponential_decay_button) + + self._widget = QWidget() + self._widget.setLayout(layout) + + self._sync_model_to_view() + parameter.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_view_to_model(self, tool_id: int, checked: bool) -> None: + if checked: + decay_type = ProbeModeDecayType(tool_id) + self._parameter.set_value(decay_type.name) + + def _sync_model_to_view(self) -> None: + try: + decay_type = ProbeModeDecayType[self._parameter.get_value().upper()] + except KeyError: + decay_type = ProbeModeDecayType.POLYNOMIAL + + button = self._button_group.button(decay_type.value) + button.setChecked(True) + + def _update(self, observable: Observable) -> None: + if observable is self._parameter: + self._sync_model_to_view() + + +class ProbeEditorViewControllerFactory: + def _append_additional_modes( + self, + dialog_builder: ParameterViewBuilder, + additional_modes_builder: MultimodalProbeBuilder | None, + ) -> None: + if additional_modes_builder is None: + return + + incoherent_modes_group = 'Incoherent (Mixed State) Modes' + dialog_builder.add_spin_box( + additional_modes_builder.num_incoherent_modes, + 'Number of Modes:', + group=incoherent_modes_group, + ) + dialog_builder.add_check_box( + additional_modes_builder.orthogonalize_incoherent_modes, + 'Orthogonalize Modes:', + group=incoherent_modes_group, + ) + dialog_builder.add_view_controller( + DecayTypeParameterViewController(additional_modes_builder.incoherent_mode_decay_type), + 'Decay Type:', + group=incoherent_modes_group, + ) + dialog_builder.add_decimal_slider( + additional_modes_builder.incoherent_mode_decay_ratio, + 'Decay Ratio:', + group=incoherent_modes_group, + ) + + coherent_modes_group = 'Coherent (OPR) Modes' + dialog_builder.add_spin_box( + additional_modes_builder.num_coherent_modes, + 'Number of Modes:', + group=coherent_modes_group, + ) + + def create_editor_dialog( + self, item_name: str, item: ProbeRepositoryItem, parent: QWidget + ) -> QDialog: + probe_builder = item.get_builder() + builder_name = probe_builder.get_name() + additional_modes_builder = item.get_additional_modes_builder() + primary_mode_group = 'Primary Mode' + title = f'{item_name} [{builder_name}]' + + if isinstance(probe_builder, AveragePatternProbeBuilder): + dialog_builder = ParameterViewBuilder() + self._append_additional_modes(dialog_builder, additional_modes_builder) + return dialog_builder.build_dialog(title, parent) + elif isinstance(probe_builder, DiskProbeBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_length_widget( + probe_builder.diameter_m, + 'Diameter:', + group=primary_mode_group, + ) + dialog_builder.add_length_widget( + probe_builder.defocus_distance_m, + 'Defocus Distance:', + group=primary_mode_group, + ) + self._append_additional_modes(dialog_builder, additional_modes_builder) + return dialog_builder.build_dialog(title, parent) + elif isinstance(probe_builder, FresnelZonePlateProbeBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_view_controller_to_top( + FresnelZonePlateViewController(primary_mode_group, probe_builder) + ) + self._append_additional_modes(dialog_builder, additional_modes_builder) + return dialog_builder.build_dialog(title, parent) + elif isinstance(probe_builder, RectangularProbeBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_length_widget( + probe_builder.width_m, + 'Width:', + group=primary_mode_group, + ) + dialog_builder.add_length_widget( + probe_builder.height_m, + 'Height:', + group=primary_mode_group, + ) + dialog_builder.add_length_widget( + probe_builder.defocus_distance_m, + 'Defocus Distance:', + group=primary_mode_group, + ) + self._append_additional_modes(dialog_builder, additional_modes_builder) + return dialog_builder.build_dialog(title, parent) + elif isinstance(probe_builder, SuperGaussianProbeBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_length_widget( + probe_builder.annular_radius_m, + 'Annular Radius:', + group=primary_mode_group, + ) + dialog_builder.add_length_widget( + probe_builder.fwhm_m, + 'Full Width at Half Maximum:', + group=primary_mode_group, + ) + dialog_builder.add_decimal_line_edit( + probe_builder.order_parameter, + 'Order Parameter:', + group=primary_mode_group, + ) + self._append_additional_modes(dialog_builder, additional_modes_builder) + return dialog_builder.build_dialog(title, parent) + elif isinstance(probe_builder, ZernikeProbeBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_view_controller_to_top( + ZernikeViewController(primary_mode_group, probe_builder) + ) + self._append_additional_modes(dialog_builder, additional_modes_builder) + return dialog_builder.build_dialog(title, parent) + + return QMessageBox( + QMessageBox.Icon.Information, + title, + f'"{builder_name}" has no editable parameters!', + QMessageBox.Ok, + parent, + ) diff --git a/src/ptychodus/controller/probe/exposure.py b/src/ptychodus/controller/probe/exposure.py deleted file mode 100644 index 0ee2fb46..00000000 --- a/src/ptychodus/controller/probe/exposure.py +++ /dev/null @@ -1,70 +0,0 @@ -import logging - -from ...model.analysis import ExposureAnalyzer, ExposureMap -from ...model.visualization import VisualizationEngine -from ...view.probe import ExposureDialog -from ...view.widgets import ExceptionDialog -from ..data import FileDialogFactory -from ..visualization import ( - VisualizationParametersController, - VisualizationWidgetController, -) - -logger = logging.getLogger(__name__) - - -class ExposureViewController: - def __init__( - self, - analyzer: ExposureAnalyzer, - engine: VisualizationEngine, - fileDialogFactory: FileDialogFactory, - ) -> None: - super().__init__() - self._analyzer = analyzer - self._engine = engine - self._fileDialogFactory = fileDialogFactory - self._dialog = ExposureDialog() - self._dialog.saveButton.clicked.connect(self._saveResult) - - self._visualizationWidgetController = VisualizationWidgetController( - engine, - self._dialog.visualizationWidget, - self._dialog.statusBar, - fileDialogFactory, - ) - self._visualizationParametersController = VisualizationParametersController.createInstance( - engine, self._dialog.visualizationParametersView - ) - self._result: ExposureMap | None = None - - def analyze(self, itemIndex: int) -> None: - try: - result = self._analyzer.analyze(itemIndex) - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Exposure Analysis', err) - return - - self._result = result - self._dialog.open() - - def _saveResult(self) -> None: - if self._result is None: - logger.debug('No result to save!') - return - - title = 'Save Result' - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( - self._dialog, - title, - nameFilters=self._analyzer.getSaveFileFilterList(), - selectedNameFilter=self._analyzer.getSaveFileFilter(), - ) - - if filePath: - try: - self._analyzer.saveResult(filePath, self._result) - except Exception as err: - logger.exception(err) - ExceptionDialog.showException(title, err) diff --git a/src/ptychodus/controller/probe/fluorescence.py b/src/ptychodus/controller/probe/fluorescence.py index a2a49c1b..34b16884 100644 --- a/src/ptychodus/controller/probe/fluorescence.py +++ b/src/ptychodus/controller/probe/fluorescence.py @@ -36,11 +36,11 @@ def __init__(self, enhancer: FluorescenceEnhancer, parent: QObject | None = None def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: # TODO make this a table model and show measured/enhanced count statistics if index.isValid() and role == Qt.ItemDataRole.DisplayRole: - emap = self._enhancer.getMeasuredElementMap(index.row()) + emap = self._enhancer.get_measured_element_map(index.row()) return emap.name - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: - return self._enhancer.getNumberOfChannels() + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + return self._enhancer.get_num_channels() class FluorescenceTwoStepViewController(Observer): @@ -49,33 +49,37 @@ def __init__(self, algorithm: TwoStepFluorescenceEnhancingAlgorithm) -> None: self._algorithm = algorithm self._view = FluorescenceTwoStepParametersView() - self._upscalingModel = QStringListModel() - self._upscalingModel.setStringList(self._algorithm.getUpscalingStrategyList()) - self._view.upscalingStrategyComboBox.setModel(self._upscalingModel) - self._view.upscalingStrategyComboBox.textActivated.connect(algorithm.setUpscalingStrategy) + self._upscaling_model = QStringListModel() + self._upscaling_model.setStringList(self._algorithm.get_upscaling_strategies()) + self._view.upscaling_strategy_combo_box.setModel(self._upscaling_model) + self._view.upscaling_strategy_combo_box.textActivated.connect( + algorithm.set_upscaling_strategy + ) - self._deconvolutionModel = QStringListModel() - self._deconvolutionModel.setStringList(self._algorithm.getDeconvolutionStrategyList()) - self._view.deconvolutionStrategyComboBox.setModel(self._deconvolutionModel) - self._view.deconvolutionStrategyComboBox.textActivated.connect( - algorithm.setDeconvolutionStrategy + self._deconvolution_model = QStringListModel() + self._deconvolution_model.setStringList(self._algorithm.get_deconvolution_strategies()) + self._view.deconvolution_strategy_combo_box.setModel(self._deconvolution_model) + self._view.deconvolution_strategy_combo_box.textActivated.connect( + algorithm.set_deconvolution_strategy ) - self._syncModelToView() - algorithm.addObserver(self) + self._sync_model_to_view() + algorithm.add_observer(self) - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._view - def _syncModelToView(self) -> None: - self._view.upscalingStrategyComboBox.setCurrentText(self._algorithm.getUpscalingStrategy()) - self._view.deconvolutionStrategyComboBox.setCurrentText( - self._algorithm.getDeconvolutionStrategy() + def _sync_model_to_view(self) -> None: + self._view.upscaling_strategy_combo_box.setCurrentText( + self._algorithm.get_upscaling_strategy() + ) + self._view.deconvolution_strategy_combo_box.setCurrentText( + self._algorithm.get_deconvolution_strategy() ) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._algorithm: - self._syncModelToView() + self._sync_model_to_view() class FluorescenceVSPIViewController(Observer): @@ -86,26 +90,30 @@ def __init__(self, algorithm: VSPIFluorescenceEnhancingAlgorithm) -> None: self._algorithm = algorithm self._view = FluorescenceVSPIParametersView() - self._view.dampingFactorLineEdit.valueChanged.connect(self._syncDampingFactorToModel) - self._view.maxIterationsSpinBox.setRange(1, self.MAX_INT) - self._view.maxIterationsSpinBox.valueChanged.connect(algorithm.setMaxIterations) + self._view.damping_factor_line_edit.value_changed.connect( + self._sync_damping_factor_to_model + ) + self._view.max_iterations_spin_box.setRange(1, self.MAX_INT) + self._view.max_iterations_spin_box.valueChanged.connect(algorithm.set_max_iterations) - algorithm.addObserver(self) - self._syncModelToView() + algorithm.add_observer(self) + self._sync_model_to_view() - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return self._view - def _syncDampingFactorToModel(self, value: Decimal) -> None: - self._algorithm.setDampingFactor(float(value)) + def _sync_damping_factor_to_model(self, value: Decimal) -> None: + self._algorithm.set_damping_factor(float(value)) - def _syncModelToView(self) -> None: - self._view.dampingFactorLineEdit.setValue(Decimal(repr(self._algorithm.getDampingFactor()))) - self._view.maxIterationsSpinBox.setValue(self._algorithm.getMaxIterations()) + def _sync_model_to_view(self) -> None: + self._view.damping_factor_line_edit.set_value( + Decimal(repr(self._algorithm.get_damping_factor())) + ) + self._view.max_iterations_spin_box.setValue(self._algorithm.get_max_iterations()) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._algorithm: - self._syncModelToView() + self._sync_model_to_view() class FluorescenceViewController(Observer): @@ -113,168 +121,172 @@ def __init__( self, enhancer: FluorescenceEnhancer, engine: VisualizationEngine, - fileDialogFactory: FileDialogFactory, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() self._enhancer = enhancer self._engine = engine - self._fileDialogFactory = fileDialogFactory + self._file_dialog_factory = file_dialog_factory self._dialog = FluorescenceDialog() - self._enhancementModel = QStringListModel() - self._enhancementModel.setStringList(self._enhancer.getAlgorithmList()) - self._channelListModel = FluorescenceChannelListModel(enhancer) + self._enhancement_model = QStringListModel() + self._enhancement_model.setStringList(self._enhancer.algorithms()) + self._channel_list_model = FluorescenceChannelListModel(enhancer) - self._dialog.fluorescenceParametersView.openButton.clicked.connect( - self._openMeasuredDataset + self._dialog.fluorescence_parameters_view.open_button.clicked.connect( + self._open_measured_dataset ) - twoStepViewController = FluorescenceTwoStepViewController( - enhancer.twoStepEnhancingAlgorithm + two_step_view_controller = FluorescenceTwoStepViewController( + enhancer.two_step_enhancing_algorithm ) - self._dialog.fluorescenceParametersView.algorithmComboBox.addItem( + self._dialog.fluorescence_parameters_view.algorithm_combo_box.addItem( TwoStepFluorescenceEnhancingAlgorithm.DISPLAY_NAME, - self._dialog.fluorescenceParametersView.algorithmComboBox.count(), + self._dialog.fluorescence_parameters_view.algorithm_combo_box.count(), ) - self._dialog.fluorescenceParametersView.stackedWidget.addWidget( - twoStepViewController.getWidget() + self._dialog.fluorescence_parameters_view.stacked_widget.addWidget( + two_step_view_controller.get_widget() ) - vspiViewController = FluorescenceVSPIViewController(enhancer.vspiEnhancingAlgorithm) - self._dialog.fluorescenceParametersView.algorithmComboBox.addItem( + vspi_view_controller = FluorescenceVSPIViewController(enhancer.vspi_enhancing_algorithm) + self._dialog.fluorescence_parameters_view.algorithm_combo_box.addItem( VSPIFluorescenceEnhancingAlgorithm.DISPLAY_NAME, - self._dialog.fluorescenceParametersView.algorithmComboBox.count(), + self._dialog.fluorescence_parameters_view.algorithm_combo_box.count(), ) - self._dialog.fluorescenceParametersView.stackedWidget.addWidget( - vspiViewController.getWidget() + self._dialog.fluorescence_parameters_view.stacked_widget.addWidget( + vspi_view_controller.get_widget() ) - self._dialog.fluorescenceParametersView.algorithmComboBox.textActivated.connect( - enhancer.setAlgorithm + self._dialog.fluorescence_parameters_view.algorithm_combo_box.textActivated.connect( + enhancer.set_algorithm + ) + self._dialog.fluorescence_parameters_view.algorithm_combo_box.currentIndexChanged.connect( + self._dialog.fluorescence_parameters_view.stacked_widget.setCurrentIndex ) - self._dialog.fluorescenceParametersView.algorithmComboBox.currentIndexChanged.connect( - self._dialog.fluorescenceParametersView.stackedWidget.setCurrentIndex + self._dialog.fluorescence_parameters_view.algorithm_combo_box.setModel( + self._enhancement_model ) - self._dialog.fluorescenceParametersView.algorithmComboBox.setModel(self._enhancementModel) - self._dialog.fluorescenceParametersView.algorithmComboBox.textActivated.connect( - enhancer.setAlgorithm + self._dialog.fluorescence_parameters_view.algorithm_combo_box.textActivated.connect( + enhancer.set_algorithm ) - self._dialog.fluorescenceParametersView.enhanceButton.clicked.connect( - self._enhanceFluorescence + self._dialog.fluorescence_parameters_view.enhance_button.clicked.connect( + self._enhance_fluorescence ) - self._dialog.fluorescenceParametersView.saveButton.clicked.connect( - self._saveEnhancedDataset + self._dialog.fluorescence_parameters_view.save_button.clicked.connect( + self._save_enhanced_dataset ) - self._dialog.fluorescenceChannelListView.setModel(self._channelListModel) - self._dialog.fluorescenceChannelListView.selectionModel().currentChanged.connect( - self._updateView + self._dialog.fluorescence_channel_list_view.setModel(self._channel_list_model) + self._dialog.fluorescence_channel_list_view.selectionModel().currentChanged.connect( + self._update_view ) - self._measuredWidgetController = VisualizationWidgetController( + self._measured_widget_controller = VisualizationWidgetController( engine, - self._dialog.measuredWidget, - self._dialog.statusBar, - fileDialogFactory, + self._dialog.measured_widget, + self._dialog.status_bar, + file_dialog_factory, ) - self._enhancedWidgetController = VisualizationWidgetController( + self._enhanced_widget_controller = VisualizationWidgetController( engine, - self._dialog.enhancedWidget, - self._dialog.statusBar, - fileDialogFactory, + self._dialog.enhanced_widget, + self._dialog.status_bar, + file_dialog_factory, ) - self._visualizationParametersController = VisualizationParametersController.createInstance( - engine, self._dialog.visualizationParametersView + self._visualization_parameters_controller = ( + VisualizationParametersController.create_instance( + engine, self._dialog.visualization_parameters_view + ) ) - enhancer.addObserver(self) + enhancer.add_observer(self) - def _openMeasuredDataset(self) -> None: + def _open_measured_dataset(self) -> None: title = 'Open Measured Fluorescence Dataset' - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( + file_path, name_filter = self._file_dialog_factory.get_open_file_path( self._dialog, title, - nameFilters=self._enhancer.getOpenFileFilterList(), - selectedNameFilter=self._enhancer.getOpenFileFilter(), + name_filters=[nf for nf in self._enhancer.get_open_file_filters()], + selected_name_filter=self._enhancer.get_open_file_filter(), ) - if filePath: + if file_path: try: - self._enhancer.openMeasuredDataset(filePath, nameFilter) + self._enhancer.open_measured_dataset(file_path, name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException(title, err) + ExceptionDialog.show_exception(title, err) - def _enhanceFluorescence(self) -> None: + def _enhance_fluorescence(self) -> None: try: - self._enhancer.enhanceFluorescence() + self._enhancer.enhance_fluorescence() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Enhance Fluorescence', err) + ExceptionDialog.show_exception('Enhance Fluorescence', err) - def launch(self, productIndex: int) -> None: - self._enhancer.setProduct(productIndex) + def launch(self, product_index: int) -> None: + self._enhancer.set_product(product_index) try: - itemName = self._enhancer.getProductName() + item_name = self._enhancer.get_product_name() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Launch', err) + ExceptionDialog.show_exception('Launch', err) else: - self._dialog.setWindowTitle(f'Enhance Fluorescence: {itemName}') + self._dialog.setWindowTitle(f'Enhance Fluorescence: {item_name}') self._dialog.open() - def _saveEnhancedDataset(self) -> None: + def _save_enhanced_dataset(self) -> None: title = 'Save Enhanced Fluorescence Dataset' - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + file_path, name_filter = self._file_dialog_factory.get_save_file_path( self._dialog, title, - nameFilters=self._enhancer.getSaveFileFilterList(), - selectedNameFilter=self._enhancer.getSaveFileFilter(), + name_filters=[nf for nf in self._enhancer.get_save_file_filters()], + selected_name_filter=self._enhancer.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._enhancer.saveEnhancedDataset(filePath, nameFilter) + self._enhancer.save_enhanced_dataset(file_path, name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException(title, err) + ExceptionDialog.show_exception(title, err) - def _syncModelToView(self) -> None: - self._dialog.fluorescenceParametersView.algorithmComboBox.setCurrentText( - self._enhancer.getAlgorithm() + def _sync_model_to_view(self) -> None: + self._dialog.fluorescence_parameters_view.algorithm_combo_box.setCurrentText( + self._enhancer.get_algorithm() ) - self._channelListModel.beginResetModel() - self._channelListModel.endResetModel() + self._channel_list_model.beginResetModel() + self._channel_list_model.endResetModel() - def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: + def _update_view(self, current: QModelIndex, previous: QModelIndex) -> None: if not current.isValid(): - self._measuredWidgetController.clearArray() - self._enhancedWidgetController.clearArray() + self._measured_widget_controller.clear_array() + self._enhanced_widget_controller.clear_array() return try: - emap_measured = self._enhancer.getMeasuredElementMap(current.row()) + emap_measured = self._enhancer.get_measured_element_map(current.row()) except Exception as err: logger.exception(err) - self._measuredWidgetController.clearArray() - ExceptionDialog.showException('Render Measured Element Map', err) + self._measured_widget_controller.clear_array() + ExceptionDialog.show_exception('Render Measured Element Map', err) else: - self._measuredWidgetController.setArray( - emap_measured.counts_per_second, self._enhancer.getPixelGeometry() + self._measured_widget_controller.set_array( + emap_measured.counts_per_second, self._enhancer.get_pixel_geometry() ) try: - emap_enhanced = self._enhancer.getEnhancedElementMap(current.row()) + emap_enhanced = self._enhancer.get_enhanced_element_map(current.row()) except Exception as err: logger.exception(err) - self._enhancedWidgetController.clearArray() - ExceptionDialog.showException('Render Enhanced Element Map', err) + self._enhanced_widget_controller.clear_array() + ExceptionDialog.show_exception('Render Enhanced Element Map', err) else: - self._enhancedWidgetController.setArray( - emap_enhanced.counts_per_second, self._enhancer.getPixelGeometry() + self._enhanced_widget_controller.set_array( + emap_enhanced.counts_per_second, self._enhancer.get_pixel_geometry() ) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._enhancer: - self._syncModelToView() + self._sync_model_to_view() diff --git a/src/ptychodus/controller/probe/illumination.py b/src/ptychodus/controller/probe/illumination.py new file mode 100644 index 00000000..be5f364a --- /dev/null +++ b/src/ptychodus/controller/probe/illumination.py @@ -0,0 +1,95 @@ +import logging + +from ptychodus.api.observer import Observable, Observer + +from ...model.analysis import IlluminationMapper +from ...model.visualization import VisualizationEngine +from ...view.probe import IlluminationDialog +from ...view.widgets import ExceptionDialog +from ..data import FileDialogFactory +from ..visualization import ( + VisualizationParametersController, + VisualizationWidgetController, +) + +logger = logging.getLogger(__name__) + + +class IlluminationViewController(Observer): + def __init__( + self, + mapper: IlluminationMapper, + engine: VisualizationEngine, + file_dialog_factory: FileDialogFactory, + *, + is_developer_mode_enabled: bool, + ) -> None: + super().__init__() + self._mapper = mapper + self._file_dialog_factory = file_dialog_factory + self._dialog = IlluminationDialog() + self._dialog.exposure_parameters_view.setVisible(is_developer_mode_enabled) + self._dialog.exposure_quantity_view.setVisible(is_developer_mode_enabled) + self._dialog.save_button.clicked.connect(self._save_data) + self._visualization_widget_controller = VisualizationWidgetController( + engine, + self._dialog.visualization_widget, + self._dialog.status_bar, + file_dialog_factory, + ) + self._visualization_parameters_controller = ( + VisualizationParametersController.create_instance( + engine, self._dialog.visualization_parameters_view + ) + ) + + mapper.add_observer(self) + + def map(self, product_index: int) -> None: + self._mapper.set_product(product_index) + + try: + product_name = self._mapper.get_product_name() + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception('Illumination Mapper', err) + else: + self._dialog.setWindowTitle(f'Illumination Map: {product_name}') + self._dialog.open() + + try: + self._mapper.map() + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception('Illumination Mapper', err) + + def _save_data(self) -> None: + title = 'Save Illumination Map' + file_path, _ = self._file_dialog_factory.get_save_file_path( + self._dialog, + title, + name_filters=self._mapper.get_save_file_filters(), + selected_name_filter=self._mapper.get_save_file_filter(), + ) + + if file_path: + try: + self._mapper.save_data(file_path) + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception(title, err) + + def _sync_model_to_view(self) -> None: + try: + data = self._mapper.get_data() + except ValueError: + self._visualization_widget_controller.clear_array() + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception('Update Views', err) + else: + self._visualization_widget_controller.set_array(data.photon_number, data.pixel_geometry) + + def _update(self, observable: Observable) -> None: + if observable is self._mapper: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/probe/propagator.py b/src/ptychodus/controller/probe/propagator.py index 2fb95cec..a32f5fe6 100644 --- a/src/ptychodus/controller/probe/propagator.py +++ b/src/ptychodus/controller/probe/propagator.py @@ -21,142 +21,153 @@ def __init__( self, propagator: ProbePropagator, engine: VisualizationEngine, - fileDialogFactory: FileDialogFactory, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() self._propagator = propagator - self._fileDialogFactory = fileDialogFactory + self._file_dialog_factory = file_dialog_factory self._dialog = ProbePropagationDialog() - self._dialog.propagateButton.clicked.connect(self._propagate) - self._dialog.saveButton.clicked.connect(self._savePropagatedProbe) - self._dialog.coordinateSlider.valueChanged.connect(self._updateCurrentCoordinate) - self._dialog.parametersView.numberOfStepsSpinBox.setRange(1, 999) + self._dialog.propagate_button.clicked.connect(self._propagate) + self._dialog.save_button.clicked.connect(self._save_propagated_probe) + self._dialog.coordinate_slider.valueChanged.connect(self._update_current_coordinate) + self._dialog.parameters_view.num_steps_spin_box.setRange(1, 999) - self._xyVisualizationWidgetController = VisualizationWidgetController( - engine, self._dialog.xyView, self._dialog.statusBar, fileDialogFactory + self._xy_visualization_widget_controller = VisualizationWidgetController( + engine, self._dialog.xy_view, self._dialog.status_bar, file_dialog_factory ) - self._zxVisualizationWidgetController = VisualizationWidgetController( - engine, self._dialog.zxView, self._dialog.statusBar, fileDialogFactory + self._zx_visualization_widget_controller = VisualizationWidgetController( + engine, self._dialog.zx_view, self._dialog.status_bar, file_dialog_factory ) - self._visualizationParametersController = VisualizationParametersController.createInstance( - engine, self._dialog.parametersView.visualizationParametersView + self._visualization_parameters_controller = ( + VisualizationParametersController.create_instance( + engine, self._dialog.parameters_view.visualization_parameters_view + ) ) - self._zyVisualizationWidgetController = VisualizationWidgetController( - engine, self._dialog.zyView, self._dialog.statusBar, fileDialogFactory + self._zy_visualization_widget_controller = VisualizationWidgetController( + engine, self._dialog.zy_view, self._dialog.status_bar, file_dialog_factory ) - propagator.addObserver(self) - self._syncModelToView() + propagator.add_observer(self) + self._sync_model_to_view() - def _updateCurrentCoordinate(self, step: int) -> None: - lerpValue = 0.0 + def _update_current_coordinate(self, step: int) -> None: + lerp_value = 0.0 - slider = self._dialog.coordinateSlider + slider = self._dialog.coordinate_slider upper = step - slider.minimum() lower = slider.maximum() - slider.minimum() if lower > 0: alpha = upper / lower - z0 = self._propagator.getBeginCoordinateInMeters() - z1 = self._propagator.getEndCoordinateInMeters() - lerpValue = (1 - alpha) * z0 + alpha * z1 + z0 = self._propagator.get_begin_coordinate_m() + z1 = self._propagator.get_end_coordinate_m() + lerp_value = (1 - alpha) * z0 + alpha * z1 else: logger.error('Bad slider range!') try: - xyProjection = self._propagator.getXYProjection(step) + xy_projection = self._propagator.get_xy_projection(step) except (IndexError, ValueError): - self._xyVisualizationWidgetController.clearArray() + self._xy_visualization_widget_controller.clear_array() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Update Current Coordinate', err) + ExceptionDialog.show_exception('Update Current Coordinate', err) else: - self._xyVisualizationWidgetController.setArray( - xyProjection, self._propagator.getPixelGeometry() - ) + pixel_geometry = self._propagator.get_pixel_geometry() + + if pixel_geometry is None: + logger.warning('Missing propagator pixel geometry!') + else: + self._xy_visualization_widget_controller.set_array(xy_projection, pixel_geometry) # TODO auto-units - lerpValue *= 1e6 - self._dialog.coordinateLabel.setText(f'{lerpValue:.1f} \u00b5m') + lerp_value *= 1e6 + self._dialog.coordinate_label.setText(f'{lerp_value:.1f} \u00b5m') def _propagate(self) -> None: - view = self._dialog.parametersView + view = self._dialog.parameters_view try: self._propagator.propagate( - numberOfSteps=view.numberOfStepsSpinBox.value(), - beginCoordinateInMeters=float(view.beginCoordinateWidget.getLengthInMeters()), - endCoordinateInMeters=float(view.endCoordinateWidget.getLengthInMeters()), + num_steps=view.num_steps_spin_box.value(), + begin_coordinate_m=float(view.begin_coordinate_widget.get_length_m()), + end_coordinate_m=float(view.end_coordinate_widget.get_length_m()), ) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Propagate Probe', err) + ExceptionDialog.show_exception('Propagate Probe', err) - def launch(self, productIndex: int) -> None: - self._propagator.setProduct(productIndex) + def launch(self, product_index: int) -> None: + self._propagator.set_product(product_index) try: - itemName = self._propagator.getProductName() + item_name = self._propagator.get_product_name() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Launch', err) + ExceptionDialog.show_exception('Launch', err) else: - self._dialog.setWindowTitle(f'Propagate Probe: {itemName}') + self._dialog.setWindowTitle(f'Propagate Probe: {item_name}') self._dialog.open() - def _savePropagatedProbe(self) -> None: + def _save_propagated_probe(self) -> None: title = 'Save Propagated Probe' - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + file_path, name_filter = self._file_dialog_factory.get_save_file_path( self._dialog, title, - nameFilters=self._propagator.getSaveFileFilterList(), - selectedNameFilter=self._propagator.getSaveFileFilter(), + name_filters=self._propagator.get_save_file_filters(), + selected_name_filter=self._propagator.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._propagator.savePropagatedProbe(filePath) + self._propagator.save_propagated_probe(file_path) except Exception as err: logger.exception(err) - ExceptionDialog.showException(title, err) + ExceptionDialog.show_exception(title, err) - def _syncModelToView(self) -> None: - view = self._dialog.parametersView - view.beginCoordinateWidget.setLengthInMeters( - Decimal.from_float(self._propagator.getBeginCoordinateInMeters()) + def _sync_model_to_view(self) -> None: + view = self._dialog.parameters_view + view.begin_coordinate_widget.set_length_m( + Decimal(repr(self._propagator.get_begin_coordinate_m())) ) - view.endCoordinateWidget.setLengthInMeters( - Decimal.from_float(self._propagator.getEndCoordinateInMeters()) + view.end_coordinate_widget.set_length_m( + Decimal(repr(self._propagator.get_end_coordinate_m())) ) - view.numberOfStepsSpinBox.setValue(self._propagator.getNumberOfSteps()) + view.num_steps_spin_box.setValue(self._propagator.get_num_steps()) - numberOfSteps = self._propagator.getNumberOfSteps() + num_steps = self._propagator.get_num_steps() - if numberOfSteps > 1: - self._dialog.coordinateSlider.setEnabled(True) - self._dialog.coordinateSlider.setRange(0, numberOfSteps - 1) + if num_steps > 1: + self._dialog.coordinate_slider.setEnabled(True) + self._dialog.coordinate_slider.setRange(0, num_steps - 1) else: - self._dialog.coordinateSlider.setEnabled(False) - self._dialog.coordinateSlider.setRange(0, 1) - self._dialog.coordinateSlider.setValue(0) + self._dialog.coordinate_slider.setEnabled(False) + self._dialog.coordinate_slider.setRange(0, 1) + self._dialog.coordinate_slider.setValue(0) + + self._update_current_coordinate(self._dialog.coordinate_slider.value()) + pixel_geometry = self._propagator.get_pixel_geometry() - self._updateCurrentCoordinate(self._dialog.coordinateSlider.value()) + if pixel_geometry is None: + logger.warning('Missing propagator pixel geometry!') + return try: - self._zxVisualizationWidgetController.setArray( - self._propagator.getZXProjection(), self._propagator.getPixelGeometry() + # vvv TODO display correct pixel geometry for projections vvv + self._zx_visualization_widget_controller.set_array( + self._propagator.get_zx_projection(), pixel_geometry ) - self._zyVisualizationWidgetController.setArray( - self._propagator.getZYProjection(), self._propagator.getPixelGeometry() + self._zy_visualization_widget_controller.set_array( + self._propagator.get_zy_projection(), pixel_geometry ) except ValueError: - self._zxVisualizationWidgetController.clearArray() - self._zyVisualizationWidgetController.clearArray() + self._zx_visualization_widget_controller.clear_array() + self._zy_visualization_widget_controller.clear_array() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Update Views', err) + ExceptionDialog.show_exception('Update Views', err) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._propagator: - self._syncModelToView() + self._sync_model_to_view() diff --git a/src/ptychodus/controller/probe/stxm.py b/src/ptychodus/controller/probe/stxm.py index 6799353d..9e88da86 100644 --- a/src/ptychodus/controller/probe/stxm.py +++ b/src/ptychodus/controller/probe/stxm.py @@ -20,68 +20,72 @@ def __init__( self, simulator: STXMSimulator, engine: VisualizationEngine, - fileDialogFactory: FileDialogFactory, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() self._simulator = simulator - self._fileDialogFactory = fileDialogFactory - + self._file_dialog_factory = file_dialog_factory self._dialog = STXMDialog() - self._dialog.saveButton.clicked.connect(self._saveResult) - - self._visualizationWidgetController = VisualizationWidgetController( + self._dialog.save_button.clicked.connect(self._save_data) + self._visualization_widget_controller = VisualizationWidgetController( engine, - self._dialog.visualizationWidget, - self._dialog.statusBar, - fileDialogFactory, + self._dialog.visualization_widget, + self._dialog.status_bar, + file_dialog_factory, ) - self._visualizationParametersController = VisualizationParametersController.createInstance( - engine, self._dialog.visualizationParametersView + self._visualization_parameters_controller = ( + VisualizationParametersController.create_instance( + engine, self._dialog.visualization_parameters_view + ) ) - simulator.addObserver(self) + simulator.add_observer(self) - def launch(self, productIndex: int) -> None: - self._simulator.setProduct(productIndex) + def simulate(self, product_index: int) -> None: + self._simulator.set_product(product_index) try: - itemName = self._simulator.getProductName() + product_name = self._simulator.get_product_name() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Launch', err) + ExceptionDialog.show_exception('Simulate STXM', err) else: - self._dialog.setWindowTitle(f'Simulate STXM: {itemName}') + self._dialog.setWindowTitle(f'Simulate STXM: {product_name}') self._dialog.open() - self._simulator.simulate() + try: + self._simulator.simulate() + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception('Simulate STXM', err) - def _saveResult(self) -> None: - title = 'Save STXM Image' - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + def _save_data(self) -> None: + title = 'Save STXM Data' + file_path, _ = self._file_dialog_factory.get_save_file_path( self._dialog, title, - nameFilters=self._simulator.getSaveFileFilterList(), - selectedNameFilter=self._simulator.getSaveFileFilter(), + name_filters=self._simulator.get_save_file_filters(), + selected_name_filter=self._simulator.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._simulator.saveImage(filePath) + self._simulator.save_data(file_path) except Exception as err: logger.exception(err) - ExceptionDialog.showException(title, err) + ExceptionDialog.show_exception(title, err) - def _syncModelToView(self) -> None: + def _sync_model_to_view(self) -> None: try: - image = self._simulator.getImage() + data = self._simulator.get_data() except ValueError: - self._visualizationWidgetController.clearArray() + self._visualization_widget_controller.clear_array() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Update Views', err) + ExceptionDialog.show_exception('Update Views', err) else: - self._visualizationWidgetController.setArray(image.intensity, image.pixel_geometry) + self._visualization_widget_controller.set_array(data.intensity, data.pixel_geometry) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._simulator: - self._syncModelToView() + self._sync_model_to_view() diff --git a/src/ptychodus/controller/probe/treeModel.py b/src/ptychodus/controller/probe/treeModel.py deleted file mode 100644 index 3c2f38ec..00000000 --- a/src/ptychodus/controller/probe/treeModel.py +++ /dev/null @@ -1,214 +0,0 @@ -from __future__ import annotations -from typing import Any, overload - -import numpy - -from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject - -from ...model.product import ProbeAPI, ProbeRepository -from ...model.product.probe import ProbeRepositoryItem - - -class ProbeTreeNode: - def __init__(self, parent: ProbeTreeNode | None = None) -> None: - self.parent = parent - self.children: list[ProbeTreeNode] = list() - - def insertNode(self, index: int = -1) -> ProbeTreeNode: - node = ProbeTreeNode(self) - self.children.insert(index, node) - return node - - def removeNode(self, index: int = -1) -> ProbeTreeNode: - return self.children.pop(index) - - def row(self) -> int: - return 0 if self.parent is None else self.parent.children.index(self) - - -class ProbeTreeModel(QAbstractItemModel): - def __init__( - self, repository: ProbeRepository, api: ProbeAPI, parent: QObject | None = None - ) -> None: - super().__init__(parent) - self._repository = repository - self._api = api - self._treeRoot = ProbeTreeNode() - self._header = [ - 'Name', - 'Relative Power', - 'Builder', - 'Data Type', - 'Width [px]', - 'Height [px]', - 'Size [MB]', - ] - - for index, item in enumerate(repository): - self.insertItem(index, item) - - @staticmethod - def _appendModes(node: ProbeTreeNode, item: ProbeRepositoryItem) -> None: - object_ = item.getProbe() - - for layer in range(object_.numberOfModes): - node.insertNode() - - def insertItem(self, index: int, item: ProbeRepositoryItem) -> None: - self.beginInsertRows(QModelIndex(), index, index) - ProbeTreeModel._appendModes(self._treeRoot.insertNode(index), item) - self.endInsertRows() - - def updateItem(self, index: int, item: ProbeRepositoryItem) -> None: - topLeft = self.index(index, 0) - bottomRight = self.index(index, len(self._header)) - self.dataChanged.emit(topLeft, bottomRight) - - node = self._treeRoot.children[index] - numModesOld = len(node.children) - numModesNew = item.getProbe().numberOfModes - - if numModesOld < numModesNew: - self.beginInsertRows(topLeft, numModesOld, numModesNew) - - while len(node.children) < numModesNew: - node.insertNode() - - self.endInsertRows() - elif numModesOld > numModesNew: - self.beginRemoveRows(topLeft, numModesNew, numModesOld) - - while len(node.children) > numModesNew: - node.removeNode() - - self.endRemoveRows() - - childTopLeft = self.index(0, 0, topLeft) - childBottomRight = self.index(numModesNew, len(self._header), topLeft) - self.dataChanged.emit(childTopLeft, childBottomRight) - - def removeItem(self, index: int, item: ProbeRepositoryItem) -> None: - self.beginRemoveRows(QModelIndex(), index, index) - self._treeRoot.removeNode(index) - self.endRemoveRows() - - def headerData( - self, - section: int, - orientation: Qt.Orientation, - role: int = Qt.ItemDataRole.DisplayRole, - ) -> Any: - if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: - return self._header[section] - - @overload - def parent(self, index: QModelIndex) -> QModelIndex: ... - - @overload - def parent(self) -> QObject: ... - - def parent(self, index: QModelIndex | None = None) -> QModelIndex | QObject: - if index is None: - return super().parent() - elif index.isValid(): - node = index.internalPointer() - parentNode = node.parent - return ( - QModelIndex() - if parentNode is self._treeRoot - else self.createIndex(parentNode.row(), 0, parentNode) - ) - - return QModelIndex() - - def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: - if self.hasIndex(row, column, parent): - if parent.isValid(): - parentNode = parent.internalPointer() - node = parentNode.children[row] - else: - node = self._treeRoot.children[row] - - return self.createIndex(row, column, node) - - return QModelIndex() - - def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: - if not index.isValid(): - return None - - parent = index.parent() - - if parent.isValid(): - item = self._repository[parent.row()] - - if role == Qt.ItemDataRole.DisplayRole and index.column() == 0: - return f'Mode {index.row() + 1}' - elif role == Qt.ItemDataRole.UserRole and index.column() == 1: - probe = item.getProbe() - - try: - relativePower = probe.getModeRelativePower(index.row()) - except IndexError: - return -1 - - if numpy.isfinite(relativePower): - return int(100.0 * relativePower) - else: - item = self._repository[index.row()] - probe = item.getProbe() - - if role == Qt.ItemDataRole.DisplayRole: - if index.column() == 0: - return self._repository.getName(index.row()) - elif index.column() == 1: - return None - elif index.column() == 2: - return item.getBuilder().getName() - elif index.column() == 3: - return str(probe.dataType) - elif index.column() == 4: - return probe.widthInPixels - elif index.column() == 5: - return probe.heightInPixels - elif index.column() == 6: - return f'{probe.sizeInBytes / (1024 * 1024):.2f}' - elif role == Qt.ItemDataRole.UserRole and index.column() == 1: - probe = item.getProbe() - coherence = probe.getCoherence() - return int(100.0 * coherence) if numpy.isfinite(coherence) else -1 - - def flags(self, index: QModelIndex) -> Qt.ItemFlags: - value = super().flags(index) - - if index.isValid(): - parent = index.parent() - - if not parent.isValid() and index.column() in (0, 2): - value |= Qt.ItemFlag.ItemIsEditable - - return value - - def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: - if index.isValid() and role == Qt.ItemDataRole.EditRole: - parent = index.parent() - - if not parent.isValid(): - if index.column() == 0: - self._repository.setName(index.row(), str(value)) - return True - elif index.column() == 2: - self._api.buildProbe(index.row(), str(value)) - return True - - return False - - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: - if parent.column() > 0: - return 0 - - node = parent.internalPointer() if parent.isValid() else self._treeRoot - return len(node.children) - - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: - return len(self._header) diff --git a/src/ptychodus/controller/probe/tree_model.py b/src/ptychodus/controller/probe/tree_model.py new file mode 100644 index 00000000..654008ee --- /dev/null +++ b/src/ptychodus/controller/probe/tree_model.py @@ -0,0 +1,235 @@ +from __future__ import annotations +from typing import Any, overload + +import numpy + +from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject + +from ptychodus.api.probe import Probe +from ptychodus.api.units import BYTES_PER_MEGABYTE + +from ...model.product import ProbeAPI, ProbeRepository +from ...model.product.probe import ProbeRepositoryItem + + +class ProbeTreeNode: + def __init__(self, parent: ProbeTreeNode | None = None) -> None: + self.parent = parent + self.children: list[ProbeTreeNode] = list() + + def insert_node(self, index: int = -1) -> ProbeTreeNode: + node = ProbeTreeNode(self) + self.children.insert(index, node) + return node + + def remove_node(self, index: int = -1) -> ProbeTreeNode: + return self.children.pop(index) + + def row(self) -> int: + return 0 if self.parent is None else self.parent.children.index(self) + + +def calc_relative_power_percent(probe: Probe, imode: int) -> int: + try: + relative_power = probe.get_incoherent_mode_relative_power(imode) + except IndexError: + return -1 + + if numpy.isfinite(relative_power): + return int(100.0 * relative_power) + + return -1 + + +def calc_coherent_percent(probe: Probe) -> int: + coherence = probe.get_coherence() + return int(100.0 * coherence) if numpy.isfinite(coherence) else -1 + + +class ProbeTreeModel(QAbstractItemModel): + def __init__( + self, repository: ProbeRepository, api: ProbeAPI, parent: QObject | None = None + ) -> None: + super().__init__(parent) + self._repository = repository + self._api = api + self._tree_root = ProbeTreeNode() + self._header = [ + 'Name', + 'Relative Power', + 'Builder', + 'Data Type', + 'Width [px]', + 'Height [px]', + 'Size [MB]', + ] + + for index, item in enumerate(repository): + self.insert_item(index, item) + + @staticmethod + def _append_modes(node: ProbeTreeNode, item: ProbeRepositoryItem) -> None: + probe = item.get_probes() + + for layer in range(probe.num_incoherent_modes): + node.insert_node() + + def insert_item(self, index: int, item: ProbeRepositoryItem) -> None: + self.beginInsertRows(QModelIndex(), index, index) + ProbeTreeModel._append_modes(self._tree_root.insert_node(index), item) + self.endInsertRows() + + def update_item(self, index: int, item: ProbeRepositoryItem) -> None: + top_left = self.index(index, 0) + bottom_right = self.index(index, len(self._header)) + self.dataChanged.emit(top_left, bottom_right) + + node = self._tree_root.children[index] + num_modes_old = len(node.children) + num_modes_new = item.get_probes().num_incoherent_modes + + if num_modes_old < num_modes_new: + self.beginInsertRows(top_left, num_modes_old, num_modes_new) + + while len(node.children) < num_modes_new: + node.insert_node() + + self.endInsertRows() + elif num_modes_old > num_modes_new: + self.beginRemoveRows(top_left, num_modes_new, num_modes_old) + + while len(node.children) > num_modes_new: + node.remove_node() + + self.endRemoveRows() + + child_top_left = self.index(0, 0, top_left) + child_bottom_right = self.index(num_modes_new, len(self._header), top_left) + self.dataChanged.emit(child_top_left, child_bottom_right) + + def remove_item(self, index: int, item: ProbeRepositoryItem) -> None: + self.beginRemoveRows(QModelIndex(), index, index) + self._tree_root.remove_node(index) + self.endRemoveRows() + + def headerData( # noqa: N802 + self, + section: int, + orientation: Qt.Orientation, + role: int = Qt.ItemDataRole.DisplayRole, + ) -> Any: + if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: + return self._header[section] + + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + + @overload + def parent(self) -> QObject: ... + + def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: + if child is None: + return super().parent() + elif child.isValid(): + node = child.internalPointer() + parent_node = node.parent + return ( + QModelIndex() + if parent_node is self._tree_root + else self.createIndex(parent_node.row(), 0, parent_node) + ) + + return QModelIndex() + + def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: + if self.hasIndex(row, column, parent): + if parent.isValid(): + parent_node = parent.internalPointer() + node = parent_node.children[row] + else: + node = self._tree_root.children[row] + + return self.createIndex(row, column, node) + + return QModelIndex() + + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: + if not index.isValid(): + return None + + parent = index.parent() + + if parent.isValid(): + item = self._repository[parent.row()] + + if role == Qt.ItemDataRole.DisplayRole: + match index.column(): + case 0: + return f'Mode {index.row() + 1}' + case 1: + probe = item.get_probes().get_probe_no_opr() # TODO OPR + power_percent = calc_relative_power_percent(probe, index.row()) + return f'{power_percent}%' + elif role == Qt.ItemDataRole.UserRole and index.column() == 1: + probe = item.get_probes().get_probe_no_opr() # TODO OPR + return calc_relative_power_percent(probe, index.row()) + else: + item = self._repository[index.row()] + probes = item.get_probes() + probe = probes.get_probe_no_opr() # TODO OPR + + if role == Qt.ItemDataRole.DisplayRole: + match index.column(): + case 0: + return self._repository.get_name(index.row()) + case 1: + coherent_percent = calc_coherent_percent(probe) + return f'{coherent_percent}%' + case 2: + return item.get_builder().get_name() + case 3: + return str(probe.dtype) + case 4: + return probe.width_px + case 5: + return probe.height_px + case 6: + return f'{probes.nbytes / BYTES_PER_MEGABYTE:.2f}' + elif role == Qt.ItemDataRole.UserRole and index.column() == 1: + probe = item.get_probes().get_probe_no_opr() # TODO OPR + return calc_coherent_percent(probe) + + def flags(self, index: QModelIndex) -> Qt.ItemFlags: + value = super().flags(index) + + if index.isValid(): + parent = index.parent() + + if not parent.isValid() and index.column() in (0, 2): + value |= Qt.ItemFlag.ItemIsEditable + + return value + + def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: # noqa: N802 + if index.isValid() and role == Qt.ItemDataRole.EditRole: + parent = index.parent() + + if not parent.isValid(): + if index.column() == 0: + self._repository.set_name(index.row(), str(value)) + return True + elif index.column() == 2: + self._api.build_probe(index.row(), str(value)) + return True + + return False + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + if parent.column() > 0: + return 0 + + node = parent.internalPointer() if parent.isValid() else self._tree_root + return len(node.children) + + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + return len(self._header) diff --git a/src/ptychodus/controller/probe/zernike.py b/src/ptychodus/controller/probe/zernike.py index 4e6ae036..127f91f9 100644 --- a/src/ptychodus/controller/probe/zernike.py +++ b/src/ptychodus/controller/probe/zernike.py @@ -29,7 +29,7 @@ def flags(self, index: QModelIndex) -> Qt.ItemFlags: return value - def headerData( + def headerData( # noqa: N802 self, section: int, orientation: Qt.Orientation, @@ -43,8 +43,8 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A return None try: - poly = self._builder.getPolynomial(index.row()) - coef = self._builder.getCoefficient(index.row()) + poly = self._builder.get_polynomial(index.row()) + coef = self._builder.get_coefficient(index.row()) except IndexError as err: logger.exception(err) return None @@ -59,7 +59,7 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A elif index.column() == 3: return f'{numpy.angle(coef):.6g}' - def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: + def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: # noqa: N802 if not index.isValid(): return False @@ -71,16 +71,16 @@ def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.Ed return False try: - coef = self._builder.getCoefficient(index.row()) + coef = self._builder.get_coefficient(index.row()) except IndexError: return False try: - complexValue = amplitude * coef / numpy.absolute(coef) + complex_value = amplitude * coef / numpy.absolute(coef) except ZeroDivisionError: - complexValue = amplitude + 0j + complex_value = amplitude + 0j - self._builder.setCoefficient(index.row(), complexValue) + self._builder.set_coefficient(index.row(), complex_value) return True elif index.column() == 3: try: @@ -89,18 +89,18 @@ def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.Ed return False try: - coef = self._builder.getCoefficient(index.row()) + coef = self._builder.get_coefficient(index.row()) except IndexError: return False - complexValue = numpy.absolute(coef) * numpy.exp(2j * numpy.pi * phase) - self._builder.setCoefficient(index.row(), complexValue) + complex_value = numpy.absolute(coef) * numpy.exp(2j * numpy.pi * phase) + self._builder.set_coefficient(index.row(), complex_value) return True return False - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._builder) - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._header) diff --git a/src/ptychodus/controller/product/core.py b/src/ptychodus/controller/product/core.py index fae48798..39ec637b 100644 --- a/src/ptychodus/controller/product/core.py +++ b/src/ptychodus/controller/product/core.py @@ -12,12 +12,15 @@ ) from PyQt5.QtWidgets import QAbstractItemView, QAction +from ptychodus.api.units import BYTES_PER_MEGABYTE + from ...model.product import ( ProductAPI, ProductRepository, ProductRepositoryItem, ProductRepositoryObserver, ) +from ...model.patterns import AssembledDiffractionDataset from ...model.product.metadata import MetadataRepositoryItem from ...model.product.object import ObjectRepositoryItem from ...model.product.probe import ProbeRepositoryItem @@ -38,8 +41,7 @@ def __init__(self, repository: ProductRepository, parent: QObject | None = None) 'Name', 'Detector-Object\nDistance [m]', 'Probe Energy\n[keV]', - 'Probe Photon\nFlux [ph/s]', - 'Exposure\nTime [s]', + 'Probe Photon\nCount', 'Pixel Width\n[nm]', 'Pixel Height\n[nm]', 'Size\n[MB]', @@ -48,12 +50,12 @@ def __init__(self, repository: ProductRepository, parent: QObject | None = None) def flags(self, index: QModelIndex) -> Qt.ItemFlags: value = super().flags(index) - if index.isValid() and index.column() < 5: + if index.isValid() and index.column() < 4: value |= Qt.ItemFlag.ItemIsEditable return value - def headerData( + def headerData( # noqa: N802 self, section: int, orientation: Qt.Orientation, @@ -70,29 +72,28 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A logger.exception(err) return None - metadata = item.getMetadata() - geometry = item.getGeometry() + metadata_item = item.get_metadata_item() + geometry = item.get_geometry() if role == Qt.ItemDataRole.DisplayRole or role == Qt.ItemDataRole.EditRole: - if index.column() == 0: - return metadata.getName() - elif index.column() == 1: - return f'{metadata.detectorDistanceInMeters.getValue():.4g}' - elif index.column() == 2: - return f'{metadata.probeEnergyInElectronVolts.getValue() / 1e3:.4g}' - elif index.column() == 3: - return f'{metadata.probePhotonsPerSecond.getValue():.4g}' - elif index.column() == 4: - return f'{metadata.exposureTimeInSeconds.getValue():.4g}' - elif index.column() == 5: - return f'{geometry.objectPlanePixelWidthInMeters * 1e9:.4g}' - elif index.column() == 6: - return f'{geometry.objectPlanePixelHeightInMeters * 1e9:.4g}' - elif index.column() == 7: - product = item.getProduct() - return f'{product.sizeInBytes / (1024 * 1024):.2f}' - - def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: + match index.column(): + case 0: + return metadata_item.name.get_value() + case 1: + return f'{metadata_item.detector_distance_m.get_value():.4g}' + case 2: + return f'{metadata_item.probe_energy_eV.get_value() / 1e3:.4g}' + case 3: + return f'{metadata_item.probe_photon_count.get_value():.4g}' + case 4: + return f'{geometry.object_plane_pixel_width_m * 1e9:.4g}' + case 5: + return f'{geometry.object_plane_pixel_height_m * 1e9:.4g}' + case 6: + product = item.get_product() + return f'{product.nbytes / BYTES_PER_MEGABYTE:.2f}' + + def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: # noqa: N802 if index.isValid() and role == Qt.ItemDataRole.EditRole: try: item = self._repository[index.row()] @@ -100,245 +101,242 @@ def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.Ed logger.exception(err) return False - metadata = item.getMetadata() + metadata_item = item.get_metadata_item() if index.column() == 0: - metadata.setName(str(value)) + metadata_item.name.set_value(str(value)) return True elif index.column() == 1: try: - distanceInM = float(value) + distance_m = float(value) except ValueError: return False - metadata.detectorDistanceInMeters.setValue(distanceInM) + metadata_item.detector_distance_m.set_value(distance_m) return True elif index.column() == 2: try: - energyInKEV = float(value) + energy_keV = float(value) # noqa: N806 except ValueError: return False - metadata.probeEnergyInElectronVolts.setValue(energyInKEV * 1e3) + metadata_item.probe_energy_eV.set_value(energy_keV * 1e3) return True elif index.column() == 3: try: - photonsPerSecond = float(value) - except ValueError: - return False - - metadata.probePhotonsPerSecond.setValue(photonsPerSecond) - return True - elif index.column() == 4: - try: - exposureTimeInSeconds = float(value) + photon_count = float(value) except ValueError: return False - metadata.exposureTimeInSeconds.setValue(exposureTimeInSeconds) + metadata_item.probe_photon_count.set_value(photon_count) return True return False - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._repository) - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._header) class ProductController(ProductRepositoryObserver): def __init__( self, + dataset: AssembledDiffractionDataset, repository: ProductRepository, api: ProductAPI, view: ProductView, - fileDialogFactory: FileDialogFactory, - duplicateAction: QAction, - tableModel: ProductRepositoryTableModel, - tableProxyModel: QSortFilterProxyModel, + file_dialog_factory: FileDialogFactory, + duplicate_action: QAction, + table_model: ProductRepositoryTableModel, + table_proxy_model: QSortFilterProxyModel, ) -> None: super().__init__() + self._dataset = dataset self._repository = repository self._api = api self._view = view - self._fileDialogFactory = fileDialogFactory - self._duplicateAction = duplicateAction - self._tableModel = tableModel - self._tableProxyModel = tableProxyModel + self._file_dialog_factory = file_dialog_factory + self._duplicate_action = duplicate_action + self._table_model = table_model + self._table_proxy_model = table_proxy_model @classmethod - def createInstance( + def create_instance( cls, + dataset: AssembledDiffractionDataset, repository: ProductRepository, api: ProductAPI, view: ProductView, - fileDialogFactory: FileDialogFactory, + file_dialog_factory: FileDialogFactory, ) -> ProductController: - openFileAction = view.buttonBox.insertMenu.addAction('Open File...') - createNewAction = view.buttonBox.insertMenu.addAction('Create New') - duplicateAction = view.buttonBox.insertMenu.addAction('Duplicate') - saveFileAction = view.buttonBox.saveMenu.addAction('Save File...') - syncToSettingsAction = view.buttonBox.saveMenu.addAction('Sync To Settings') + open_file_action = view.button_box.insert_menu.addAction('Open File...') + create_new_action = view.button_box.insert_menu.addAction('Create New') + duplicate_action = view.button_box.insert_menu.addAction('Duplicate') + save_file_action = view.button_box.save_menu.addAction('Save File...') + sync_to_settings_action = view.button_box.save_menu.addAction('Sync To Settings') - tableModel = ProductRepositoryTableModel(repository) - tableProxyModel = QSortFilterProxyModel() - tableProxyModel.setSourceModel(tableModel) + table_model = ProductRepositoryTableModel(repository) + table_proxy_model = QSortFilterProxyModel() + table_proxy_model.setSourceModel(table_model) controller = cls( + dataset, repository, api, view, - fileDialogFactory, - duplicateAction, - tableModel, - tableProxyModel, + file_dialog_factory, + duplicate_action, + table_model, + table_proxy_model, ) - repository.addObserver(controller) - controller._updateInfoText() - - view.tableView.setModel(tableProxyModel) - view.tableView.setSortingEnabled(True) - view.tableView.sortByColumn(0, Qt.SortOrder.AscendingOrder) - view.tableView.verticalHeader().hide() - view.tableView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - view.tableView.selectionModel().currentChanged.connect(controller._updateEnabledButtons) - controller._updateEnabledButtons(QModelIndex(), QModelIndex()) - - openFileAction.triggered.connect(controller._openProductFromFile) - createNewAction.triggered.connect(controller._createNewProduct) - duplicateAction.triggered.connect(controller._duplicateCurrentProduct) - saveFileAction.triggered.connect(controller._saveCurrentProductToFile) - syncToSettingsAction.triggered.connect(controller._syncCurrentProductToSettings) - - view.buttonBox.editButton.clicked.connect(controller._editCurrentProduct) - view.buttonBox.removeButton.clicked.connect(controller._removeCurrentProduct) + repository.add_observer(controller) + controller._update_info_text() + + view.table_view.setModel(table_proxy_model) + view.table_view.setSortingEnabled(True) + view.table_view.sortByColumn(0, Qt.SortOrder.AscendingOrder) + view.table_view.verticalHeader().hide() + view.table_view.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + view.table_view.selectionModel().currentChanged.connect(controller._update_enabled_buttons) + controller._update_enabled_buttons(QModelIndex(), QModelIndex()) + + open_file_action.triggered.connect(controller._open_product_from_file) + create_new_action.triggered.connect(controller._create_new_product) + duplicate_action.triggered.connect(controller._duplicate_current_product) + save_file_action.triggered.connect(controller._save_current_product_to_file) + sync_to_settings_action.triggered.connect(controller._sync_current_product_to_settings) + + view.button_box.edit_button.clicked.connect(controller._edit_current_product) + view.button_box.remove_button.clicked.connect(controller._remove_current_product) return controller @property - def tableModel(self) -> QAbstractTableModel: - return self._tableModel + def table_model(self) -> QAbstractTableModel: + return self._table_model - def _getCurrentItemIndex(self) -> int: - proxyIndex = self._view.tableView.currentIndex() + def _get_current_item_index(self) -> int: + proxy_index = self._view.table_view.currentIndex() - if proxyIndex.isValid(): - modelIndex = self._tableProxyModel.mapToSource(proxyIndex) - return modelIndex.row() + if proxy_index.isValid(): + model_index = self._table_proxy_model.mapToSource(proxy_index) + return model_index.row() logger.warning('No current index!') return -1 - def _openProductFromFile(self) -> None: - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( + def _open_product_from_file(self) -> None: + file_path, name_filter = self._file_dialog_factory.get_open_file_path( self._view, 'Open Product', - nameFilters=self._api.getOpenFileFilterList(), - selectedNameFilter=self._api.getOpenFileFilter(), + name_filters=[nf for nf in self._api.get_open_file_filters()], + selected_name_filter=self._api.get_open_file_filter(), ) - if filePath: + if file_path: try: - self._api.openProduct(filePath, fileType=nameFilter) + self._api.open_product(file_path, file_type=name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Reader', err) + ExceptionDialog.show_exception('File Reader', err) - def _createNewProduct(self) -> None: - self._api.insertNewProduct() + def _create_new_product(self) -> None: + self._api.insert_new_product() - def _saveCurrentProductToFile(self) -> None: - current = self._tableProxyModel.mapToSource(self._view.tableView.currentIndex()) + def _save_current_product_to_file(self) -> None: + current = self._table_proxy_model.mapToSource(self._view.table_view.currentIndex()) if current.isValid(): - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + file_path, name_filter = self._file_dialog_factory.get_save_file_path( self._view, 'Save Product', - nameFilters=self._api.getSaveFileFilterList(), - selectedNameFilter=self._api.getSaveFileFilter(), + name_filters=[nf for nf in self._api.get_save_file_filters()], + selected_name_filter=self._api.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._api.saveProduct(current.row(), filePath, fileType=nameFilter) + self._api.save_product(current.row(), file_path, file_type=name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Writer', err) + ExceptionDialog.show_exception('File Writer', err) else: logger.error('No current item!') - def _syncCurrentProductToSettings(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _sync_current_product_to_settings(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - item = self._repository[itemIndex] - item.syncToSettings() + item = self._repository[item_index] + item.sync_to_settings() - def _duplicateCurrentProduct(self) -> None: - current = self._tableProxyModel.mapToSource(self._view.tableView.currentIndex()) + def _duplicate_current_product(self) -> None: + current = self._table_proxy_model.mapToSource(self._view.table_view.currentIndex()) if current.isValid(): - self._api.insertNewProduct(likeIndex=current.row()) + like_item = self._repository[current.row()] + self._api.insert_product(like_item.get_product()) else: logger.error('No current item!') - def _editCurrentProduct(self) -> None: - current = self._tableProxyModel.mapToSource(self._view.tableView.currentIndex()) + def _edit_current_product(self) -> None: + current = self._table_proxy_model.mapToSource(self._view.table_view.currentIndex()) if current.isValid(): product = self._repository[current.row()] - ProductEditorViewController.editProduct(product, self._view) + ProductEditorViewController.edit_product(self._dataset, product, self._view) else: logger.error('No current item!') - def _removeCurrentProduct(self) -> None: - current = self._tableProxyModel.mapToSource(self._view.tableView.currentIndex()) + def _remove_current_product(self) -> None: + current = self._table_proxy_model.mapToSource(self._view.table_view.currentIndex()) if current.isValid(): - self._repository.removeProduct(current.row()) + self._repository.remove_product(current.row()) else: logger.error('No current item!') - def _updateEnabledButtons(self, current: QModelIndex, previous: QModelIndex) -> None: + def _update_enabled_buttons(self, current: QModelIndex, previous: QModelIndex) -> None: enabled = current.isValid() - self._duplicateAction.setEnabled(enabled) - self._view.buttonBox.saveButton.setEnabled(enabled) - self._view.buttonBox.editButton.setEnabled(enabled) - self._view.buttonBox.removeButton.setEnabled(enabled) + self._duplicate_action.setEnabled(enabled) + self._view.button_box.save_button.setEnabled(enabled) + self._view.button_box.edit_button.setEnabled(enabled) + self._view.button_box.remove_button.setEnabled(enabled) - def _updateInfoText(self) -> None: - infoText = self._repository.getInfoText() - self._view.infoLabel.setText(infoText) + def _update_info_text(self) -> None: + info_text = self._repository.get_info_text() + self._view.info_label.setText(info_text) - def handleItemInserted(self, index: int, item: ProductRepositoryItem) -> None: + def handle_item_inserted(self, index: int, item: ProductRepositoryItem) -> None: parent = QModelIndex() - self._tableModel.beginInsertRows(parent, index, index) - self._tableModel.endInsertRows() - self._updateInfoText() + self._table_model.beginInsertRows(parent, index, index) + self._table_model.endInsertRows() + self._update_info_text() - def handleMetadataChanged(self, index: int, item: MetadataRepositoryItem) -> None: - topLeft = self._tableModel.index(index, 0) - bottomRight = self._tableModel.index(index, self._tableModel.columnCount() - 1) - self._tableModel.dataChanged.emit(topLeft, bottomRight, [Qt.ItemDataRole.DisplayRole]) - self._updateInfoText() + def handle_metadata_changed(self, index: int, item: MetadataRepositoryItem) -> None: + top_left = self._table_model.index(index, 0) + bottom_right = self._table_model.index(index, self._table_model.columnCount() - 1) + self._table_model.dataChanged.emit(top_left, bottom_right, [Qt.ItemDataRole.DisplayRole]) + self._update_info_text() - def handleScanChanged(self, index: int, item: ScanRepositoryItem) -> None: - self._updateInfoText() + def handle_scan_changed(self, index: int, item: ScanRepositoryItem) -> None: + self._update_info_text() - def handleProbeChanged(self, index: int, item: ProbeRepositoryItem) -> None: - self._updateInfoText() + def handle_probe_changed(self, index: int, item: ProbeRepositoryItem) -> None: + self._update_info_text() - def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: - self._updateInfoText() + def handle_object_changed(self, index: int, item: ObjectRepositoryItem) -> None: + self._update_info_text() - def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: - self._updateInfoText() + def handle_costs_changed(self, index: int, costs: Sequence[float]) -> None: + self._update_info_text() - def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: + def handle_item_removed(self, index: int, item: ProductRepositoryItem) -> None: parent = QModelIndex() - self._tableModel.beginRemoveRows(parent, index, index) - self._tableModel.endRemoveRows() - self._updateInfoText() + self._table_model.beginRemoveRows(parent, index, index) + self._table_model.endRemoveRows() + self._update_info_text() diff --git a/src/ptychodus/controller/product/editor.py b/src/ptychodus/controller/product/editor.py index 2a8acb07..dcefa004 100644 --- a/src/ptychodus/controller/product/editor.py +++ b/src/ptychodus/controller/product/editor.py @@ -1,34 +1,49 @@ from typing import Any from PyQt5.QtCore import ( - Qt, QAbstractTableModel, QModelIndex, QObject, QSortFilterProxyModel, + Qt, ) from PyQt5.QtWidgets import QWidget from ptychodus.api.observer import Observable, Observer +from ...model.patterns import AssembledDiffractionDataset from ...model.product import ProductRepositoryItem from ...view.product import ProductEditorDialog class ProductPropertyTableModel(QAbstractTableModel): - def __init__(self, product: ProductRepositoryItem, parent: QObject | None = None) -> None: + def __init__(self, product_item: ProductRepositoryItem, parent: QObject | None = None) -> None: super().__init__(parent) - self._product = product + self._product_item = product_item self._header = ['Property', 'Value'] self._properties = [ 'Probe Wavelength [nm]', + 'Probe Wavenumber [1/nm]', + 'Probe Angular Wavenumber [rad/nm]', + 'Probe Photon Flux [ph/s]', 'Probe Power [W]', 'Object Plane Pixel Width [nm]', 'Object Plane Pixel Height [nm]', 'Fresnel Number', + 'Exposure Time [s]', + 'Mass Attenuation [m\u00b2/kg]', + 'Tomography Angle [deg]', ] - def headerData( + def flags(self, index: QModelIndex) -> Qt.ItemFlags: + value = super().flags(index) + + if index.isValid() and index.row() in (8, 9, 10): + value |= Qt.ItemFlag.ItemIsEditable + + return value + + def headerData( # noqa: N802 self, section: int, orientation: Qt.Orientation, @@ -39,78 +54,143 @@ def headerData( def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: if index.isValid() and role == Qt.ItemDataRole.DisplayRole: - if index.column() == 0: - return self._properties[index.row()] - elif index.column() == 1: - geometry = self._product.getGeometry() - - if index.row() == 0: - return f'{geometry.probeWavelengthInMeters * 1e9:.4g}' - elif index.row() == 1: - return f'{geometry.probePowerInWatts:.4g}' - elif index.row() == 2: - return f'{geometry.objectPlanePixelWidthInMeters * 1e9:.4g}' - elif index.row() == 3: - return f'{geometry.objectPlanePixelHeightInMeters * 1e9:.4g}' - elif index.row() == 4: - return f'{geometry.fresnelNumber:.4g}' - - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + match index.column(): + case 0: + return self._properties[index.row()] + case 1: + metadata_item = self._product_item.get_metadata_item() + geometry = self._product_item.get_geometry() + + match index.row(): + case 0: + return f'{geometry.probe_wavelength_m * 1e9:.4g}' + case 1: + return f'{geometry.probe_wavelengths_per_m * 1e-9:.4g}' + case 2: + return f'{geometry.probe_radians_per_m * 1e-9:.4g}' + case 3: + return f'{geometry.probe_photons_per_s:.4g}' + case 4: + return f'{geometry.probe_power_W:.4g}' + case 5: + return f'{geometry.object_plane_pixel_width_m * 1e9:.4g}' + case 6: + return f'{geometry.object_plane_pixel_height_m * 1e9:.4g}' + case 7: + try: + return f'{geometry.fresnel_number:.4g}' + except ZeroDivisionError: + return 'inf' + case 8: + return f'{metadata_item.exposure_time_s.get_value():.4g}' + case 9: + return f'{metadata_item.mass_attenuation_m2_kg.get_value():.4g}' + case 10: + return f'{metadata_item.tomography_angle_deg.get_value():.4g}' + + def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: # noqa: N802 + if index.isValid() and role == Qt.ItemDataRole.EditRole: + metadata_item = self._product_item.get_metadata_item() + + match index.row(): + case 8: + try: + exposure_time_s = float(value) + except ValueError: + return False + + metadata_item.exposure_time_s.set_value(exposure_time_s) + return True + case 9: + try: + mass_attenuation_m2_kg = float(value) + except ValueError: + return False + + metadata_item.mass_attenuation_m2_kg.set_value(mass_attenuation_m2_kg) + return True + case 10: + try: + tomography_angle_deg = float(value) + except ValueError: + return False + + metadata_item.tomography_angle_deg.set_value(tomography_angle_deg) + return True + + return False + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._properties) - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._header) class ProductEditorViewController(Observer): def __init__( self, + dataset: AssembledDiffractionDataset, product: ProductRepositoryItem, - tableModel: ProductPropertyTableModel, + table_model: ProductPropertyTableModel, dialog: ProductEditorDialog, ) -> None: super().__init__() + self._dataset = dataset self._product = product - self._tableModel = tableModel + self._table_model = table_model self._dialog = dialog @classmethod - def editProduct(cls, product: ProductRepositoryItem, parent: QWidget) -> None: - tableModel = ProductPropertyTableModel(product) - tableProxyModel = QSortFilterProxyModel() - tableProxyModel.setSourceModel(tableModel) + def edit_product( + cls, dataset: AssembledDiffractionDataset, product: ProductRepositoryItem, parent: QWidget + ) -> None: + table_model = ProductPropertyTableModel(product) + table_proxy_model = QSortFilterProxyModel() + table_proxy_model.setSourceModel(table_model) dialog = ProductEditorDialog(parent) - dialog.setWindowTitle(f'Edit Product: {product.getName()}') - dialog.tableView.setModel(tableProxyModel) - dialog.tableView.setSortingEnabled(True) - dialog.tableView.verticalHeader().hide() - dialog.tableView.resizeColumnsToContents() - dialog.tableView.resizeRowsToContents() - - viewController = cls(product, tableModel, dialog) - product.addObserver(viewController) - dialog.textEdit.textChanged.connect(viewController._syncViewToModel) - - viewController._syncModelToView() - dialog.finished.connect(viewController._finish) + dialog.setWindowTitle(f'Edit Product: {product.get_name()}') + dialog.table_view.setModel(table_proxy_model) + dialog.table_view.setSortingEnabled(True) + dialog.table_view.verticalHeader().hide() + dialog.table_view.resizeColumnsToContents() + dialog.table_view.resizeRowsToContents() + + view_controller = cls(dataset, product, table_model, dialog) + product.add_observer(view_controller) + dialog.text_edit.textChanged.connect(view_controller._sync_view_to_model) + + view_controller._sync_model_to_view() + + dialog.actions_view.estimate_probe_photon_count_button.clicked.connect( + view_controller._estimate_probe_photon_count + ) + dialog.finished.connect(view_controller._finish) dialog.open() dialog.adjustSize() - def _syncViewToModel(self) -> None: - metadata = self._product.getMetadata() - metadata.comments.setValue(self._dialog.textEdit.toPlainText()) + def _sync_view_to_model(self) -> None: + metadata = self._product.get_metadata_item() + metadata.comments.set_value(self._dialog.text_edit.toPlainText()) + + def _sync_model_to_view(self) -> None: + self._table_model.beginResetModel() + self._table_model.endResetModel() + + metadata = self._product.get_metadata_item() + self._dialog.text_edit.setPlainText(metadata.comments.get_value()) - def _syncModelToView(self) -> None: - self._tableModel.beginResetModel() - self._tableModel.endResetModel() + def _estimate_probe_photon_count(self) -> None: + metadata = self._product.get_metadata_item() + metadata.probe_photon_count.set_value(self._dataset.get_maximum_pattern_counts()) - metadata = self._product.getMetadata() - self._dialog.textEdit.setPlainText(metadata.comments.getValue()) + self._table_model.beginResetModel() + self._table_model.endResetModel() def _finish(self, result: int) -> None: - self._product.removeObserver(self) + self._product.remove_observer(self) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._product: - self._syncModelToView() + self._sync_model_to_view() diff --git a/src/ptychodus/controller/ptychi/__init__.py b/src/ptychodus/controller/ptychi/__init__.py new file mode 100644 index 00000000..f85a0981 --- /dev/null +++ b/src/ptychodus/controller/ptychi/__init__.py @@ -0,0 +1,5 @@ +from .core import PtyChiViewControllerFactory + +__all__ = [ + 'PtyChiViewControllerFactory', +] diff --git a/src/ptychodus/controller/ptychi/core.py b/src/ptychodus/controller/ptychi/core.py new file mode 100644 index 00000000..9f07ecf9 --- /dev/null +++ b/src/ptychodus/controller/ptychi/core.py @@ -0,0 +1,97 @@ +from PyQt5.QtWidgets import QVBoxLayout, QWidget + +from ...model.ptychi import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiLSQMLSettings, + PtyChiPIESettings, + PtyChiReconstructorLibrary, +) + +from ..reconstructor import ReconstructorViewControllerFactory +from .object import PtyChiObjectViewController +from .opr import PtyChiOPRViewController +from .positions import PtyChiProbePositionsViewController +from .probe import PtyChiProbeViewController +from .reconstructor import PtyChiReconstructorViewController + +__all__ = ['PtyChiViewControllerFactory'] + + +class PtyChiViewController(QWidget): + def __init__( + self, + model: PtyChiReconstructorLibrary, + reconstructor_name: str, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + autodiff_settings: PtyChiAutodiffSettings | None = None + dm_settings: PtyChiDMSettings | None = None + lsqml_settings: PtyChiLSQMLSettings | None = None + pie_settings: PtyChiPIESettings | None = None + + match reconstructor_name: + case 'Autodiff': + autodiff_settings = model.autodiff_settings + case 'DM': + dm_settings = model.dm_settings + case 'LSQML': + lsqml_settings = model.lsqml_settings + case 'PIE' | 'ePIE' | 'rPIE': + pie_settings = model.pie_settings + + self._reconstructor_view_controller = PtyChiReconstructorViewController( + model.settings, + autodiff_settings, + dm_settings, + lsqml_settings, + model.enumerators, + model.device_repository, + ) + self._object_view_controller = PtyChiObjectViewController( + model.object_settings, + dm_settings, + lsqml_settings, + pie_settings, + model.settings.num_epochs, + model.enumerators, + ) + self._probe_view_controller = PtyChiProbeViewController( + model.probe_settings, + dm_settings, + lsqml_settings, + pie_settings, + model.settings.num_epochs, + model.enumerators, + ) + self._probe_positions_view_controller = PtyChiProbePositionsViewController( + model.probe_position_settings, + model.settings.num_epochs, + model.enumerators, + ) + self._opr_view_controller = PtyChiOPRViewController( + model.opr_settings, model.settings.num_epochs, model.enumerators + ) + + layout = QVBoxLayout() + layout.addWidget(self._reconstructor_view_controller.get_widget()) + layout.addWidget(self._object_view_controller.get_widget()) + layout.addWidget(self._probe_view_controller.get_widget()) + layout.addWidget(self._probe_positions_view_controller.get_widget()) + layout.addWidget(self._opr_view_controller.get_widget()) + layout.addStretch() + self.setLayout(layout) + + +class PtyChiViewControllerFactory(ReconstructorViewControllerFactory): + def __init__(self, model: PtyChiReconstructorLibrary) -> None: + super().__init__() + self._model = model + + @property + def backend_name(self) -> str: + return 'pty-chi' + + def create_view_controller(self, reconstructor_name: str) -> QWidget: + return PtyChiViewController(self._model, reconstructor_name) diff --git a/src/ptychodus/controller/ptychi/object.py b/src/ptychodus/controller/ptychi/object.py new file mode 100644 index 00000000..516a45df --- /dev/null +++ b/src/ptychodus/controller/ptychi/object.py @@ -0,0 +1,472 @@ +from PyQt5.QtWidgets import QFormLayout + +from ptychodus.api.parametric import ( + BooleanParameter, + IntegerParameter, + RealParameter, + StringParameter, +) + +from ...model.ptychi import ( + PtyChiDMSettings, + PtyChiEnumerators, + PtyChiLSQMLSettings, + PtyChiObjectSettings, + PtyChiPIESettings, +) +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + LengthWidgetParameterViewController, + SpinBoxParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + +__all__ = ['PtyChiObjectViewController'] + + +class PtyChiOptimizeSliceSpacingViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + optimize_slice_spacing: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + optimizer: StringParameter, + step_size: RealParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + optimize_slice_spacing, + 'Optimize Slice Spacing', + tool_tip='Whether to optimize the slice spacing', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, + stop, + stride, + num_epochs, + ) + self._optimizer_view_controller = PtyChiOptimizerParameterViewController( + optimizer, enumerators + ) + self._step_size_view_controller = DecimalLineEditParameterViewController( + step_size, tool_tip='Optimizer step size' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Optimizer:', self._optimizer_view_controller.get_widget()) + layout.addRow('Step Size:', self._step_size_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiConstrainL1NormViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrain_l1_norm: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + weight: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrain_l1_norm, + 'Constrain L\u2081 Norm', + tool_tip='Whether to constrain the L\u2081 norm', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._weight_view_controller = DecimalLineEditParameterViewController( + weight, + tool_tip='Weight of the L\u2081 norm constraint', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Weight:', self._weight_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiConstrainL2NormViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrain_l2_norm: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + weight: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrain_l2_norm, + 'Constrain L\u2082 Norm', + tool_tip='Whether to constrain the L\u2082 norm', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._weight_view_controller = DecimalLineEditParameterViewController( + weight, + tool_tip='Weight of the L\u2082 norm constraint', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Weight:', self._weight_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiConstrainSmoothnessViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrain_smoothness: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + alpha: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrain_smoothness, + 'Constrain Smoothness', + tool_tip='Whether to constrain smoothness in the magnitude (but not phase) of the object', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._alpha_view_controller = DecimalSliderParameterViewController( + alpha, tool_tip='Relaxation smoothing constant' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Alpha:', self._alpha_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiConstrainTotalVariationViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrain_total_variation: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + weight: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrain_total_variation, + 'Constrain Total Variation', + tool_tip='Whether to constrain the total variation', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._weight_view_controller = DecimalLineEditParameterViewController( + weight, + tool_tip='Weight of the total variation constraint', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Weight:', self._weight_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiRemoveGridArtifactsViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + remove_grid_artifacts: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + period_x_m: RealParameter, + period_y_m: RealParameter, + window_size_px: IntegerParameter, + direction: StringParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + remove_grid_artifacts, + 'Remove Grid Artifacts', + tool_tip="Whether to remove grid artifacts in the object's phase at the end of an epoch.", + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._period_x_view_controller = LengthWidgetParameterViewController( + period_x_m, tool_tip='Horizontal period of grid artifacts in meters' + ) + self._period_y_view_controller = LengthWidgetParameterViewController( + period_y_m, tool_tip='Vertical period of grid artifacts in meters' + ) + self._window_size_view_controller = SpinBoxParameterViewController( + window_size_px, tool_tip='Window size for grid artifact removal in pixels' + ) + self._direction_view_controller = ComboBoxParameterViewController( + direction, enumerators.directions(), tool_tip='Direction of grid artifact removal' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Period X:', self._period_x_view_controller.get_widget()) + layout.addRow('Period Y:', self._period_y_view_controller.get_widget()) + layout.addRow('Window Size [px]:', self._window_size_view_controller.get_widget()) + layout.addRow('Direction:', self._direction_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiRegularizeMultisliceViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + regularize_multislice: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + weight: RealParameter, + unwrap_phase: BooleanParameter, + gradient_method: StringParameter, + integration_method: StringParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + regularize_multislice, + 'Regularize Multislice', + tool_tip='Whether to regularize multislice objects using cross-slice smoothing', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._weight_view_controller = DecimalLineEditParameterViewController( + weight, + tool_tip='Weight for multislice regularization', + ) + self._unwrap_phase_view_controller = CheckBoxParameterViewController( + unwrap_phase, + 'Unwrap Phase', + tool_tip='Whether to unwrap the phase of the object during multislice regularization', + ) + self._gradient_method_view_controller = ComboBoxParameterViewController( + gradient_method, + enumerators.image_gradient_methods(), + tool_tip='Method for calculating the phase gradient during phase unwrapping', + ) + self._integration_method_view_controller = ComboBoxParameterViewController( + integration_method, + enumerators.image_integration_methods(), + tool_tip='Method for integrating the phase gradient during phase unwrapping', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Weight:', self._weight_view_controller.get_widget()) + layout.addRow(self._unwrap_phase_view_controller.get_widget()) + layout.addRow('Gradient Method:', self._gradient_method_view_controller.get_widget()) + layout.addRow('Integration Method:', self._integration_method_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiRemoveObjectProbeAmbiguityViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + remove_object_probe_ambiguity: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + remove_object_probe_ambiguity, + 'Remove Object Probe Ambiguity', + tool_tip='Whether to remove object-probe ambiguity', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiObjectViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiObjectSettings, + dm_settings: PtyChiDMSettings | None, + lsqml_settings: PtyChiLSQMLSettings | None, + pie_settings: PtyChiPIESettings | None, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.is_optimizable, + 'Optimize Object', + tool_tip='Whether the object is optimizable', + ) + self._optimization_plan_view_controller = PtyChiOptimizationPlanViewController( + settings.optimization_plan_start, + settings.optimization_plan_stop, + settings.optimization_plan_stride, + num_epochs, + ) + self._optimizer_view_controller = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._step_size_view_controller = DecimalLineEditParameterViewController( + settings.step_size, tool_tip='Optimizer step size' + ) + self._optimize_slice_spacing_view_controller = PtyChiOptimizeSliceSpacingViewController( + settings.optimize_slice_spacing, + settings.optimize_slice_spacing_start, + settings.optimize_slice_spacing_stop, + settings.optimize_slice_spacing_stride, + settings.optimize_slice_spacing_optimizer, + settings.optimize_slice_spacing_step_size, + num_epochs, + enumerators, + ) + self._constrain_l1_norm_view_controller = PtyChiConstrainL1NormViewController( + settings.constrain_l1_norm, + settings.constrain_l1_norm_start, + settings.constrain_l1_norm_stop, + settings.constrain_l1_norm_stride, + settings.constrain_l1_norm_weight, + num_epochs, + ) + self._constrain_l2_norm_view_controller = PtyChiConstrainL2NormViewController( + settings.constrain_l2_norm, + settings.constrain_l2_norm_start, + settings.constrain_l2_norm_stop, + settings.constrain_l2_norm_stride, + settings.constrain_l2_norm_weight, + num_epochs, + ) + self._constrain_smoothness_view_controller = PtyChiConstrainSmoothnessViewController( + settings.constrain_smoothness, + settings.constrain_smoothness_start, + settings.constrain_smoothness_stop, + settings.constrain_smoothness_stride, + settings.constrain_smoothness_alpha, + num_epochs, + ) + self._constrain_total_variation_view_controller = ( + PtyChiConstrainTotalVariationViewController( + settings.constrain_total_variation, + settings.constrain_total_variation_start, + settings.constrain_total_variation_stop, + settings.constrain_total_variation_stride, + settings.constrain_total_variation_weight, + num_epochs, + ) + ) + self._remove_grid_artifacts_view_controller = PtyChiRemoveGridArtifactsViewController( + settings.remove_grid_artifacts, + settings.remove_grid_artifacts_start, + settings.remove_grid_artifacts_stop, + settings.remove_grid_artifacts_stride, + settings.remove_grid_artifacts_period_x_m, + settings.remove_grid_artifacts_period_y_m, + settings.remove_grid_artifacts_window_size_px, + settings.remove_grid_artifacts_direction, + num_epochs, + enumerators, + ) + self._regularize_multislice_view_controller = PtyChiRegularizeMultisliceViewController( + settings.regularize_multislice, + settings.regularize_multislice_start, + settings.regularize_multislice_stop, + settings.regularize_multislice_stride, + settings.regularize_multislice_weight, + settings.regularize_multislice_unwrap_phase, + settings.regularize_multislice_unwrap_phase_image_gradient_method, + settings.regularize_multislice_unwrap_phase_image_integration_method, + num_epochs, + enumerators, + ) + self._patch_interpolator_view_controller = ComboBoxParameterViewController( + settings.patch_interpolator, + enumerators.patch_interpolation_methods(), + tool_tip='Interpolation method used for extracting and updating patches of the object', + ) + self._remove_object_probe_ambiguity_view_controller = ( + PtyChiRemoveObjectProbeAmbiguityViewController( + settings.remove_object_probe_ambiguity, + settings.remove_object_probe_ambiguity_start, + settings.remove_object_probe_ambiguity_stop, + settings.remove_object_probe_ambiguity_stride, + num_epochs, + ) + ) + self._build_preconditioner_with_all_modes_view_controller = CheckBoxParameterViewController( + settings.build_preconditioner_with_all_modes, + 'Build Preconditioner with All Modes', + tool_tip='Whether to build the preconditioner using all modes', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimization_plan_view_controller.get_widget()) + layout.addRow('Optimizer:', self._optimizer_view_controller.get_widget()) + layout.addRow('Step Size:', self._step_size_view_controller.get_widget()) + layout.addRow(self._optimize_slice_spacing_view_controller.get_widget()) + layout.addRow(self._constrain_l1_norm_view_controller.get_widget()) + layout.addRow(self._constrain_l2_norm_view_controller.get_widget()) + layout.addRow(self._constrain_smoothness_view_controller.get_widget()) + layout.addRow(self._constrain_total_variation_view_controller.get_widget()) + layout.addRow(self._remove_grid_artifacts_view_controller.get_widget()) + layout.addRow(self._regularize_multislice_view_controller.get_widget()) + layout.addRow('Patch Interpolator:', self._patch_interpolator_view_controller.get_widget()) + layout.addRow(self._remove_object_probe_ambiguity_view_controller.get_widget()) + layout.addRow(self._build_preconditioner_with_all_modes_view_controller.get_widget()) + + if dm_settings is not None: + self._amplitude_clamp_limit_view_controller = DecimalLineEditParameterViewController( + dm_settings.object_amplitude_clamp_limit, + tool_tip='Maximum amplitude value for the object', + ) + layout.addRow( + 'Amplitude Clamp Limit:', self._amplitude_clamp_limit_view_controller.get_widget() + ) + + self._inertia_view_controller = DecimalLineEditParameterViewController( + dm_settings.object_inertia, + tool_tip='Inertia for the object update', + ) + layout.addRow('Inertia:', self._inertia_view_controller.get_widget()) + + if lsqml_settings is not None: + self._object_optimal_step_size_scaler_view_controller = ( + DecimalLineEditParameterViewController( + lsqml_settings.object_optimal_step_size_scaler, + tool_tip='Optimal step size scaler for the object update', + ) + ) + layout.addRow( + 'Optimal Step Size Scaler:', + self._object_optimal_step_size_scaler_view_controller.get_widget(), + ) + + self._object_multimodal_update_view_controller = CheckBoxParameterViewController( + lsqml_settings.object_multimodal_update, + 'Multimodal Update', + tool_tip='When checked, the object update direction is calculated and summed over all probe modes rather than only the first mode', + ) + layout.addRow(self._object_multimodal_update_view_controller.get_widget()) + + if pie_settings is not None: + self._alpha_view_controller = DecimalSliderParameterViewController( + pie_settings.object_alpha, + tool_tip='Relaxation factor for the object update', + ) + layout.addRow('Alpha:', self._alpha_view_controller.get_widget()) + + self.get_widget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/opr.py b/src/ptychodus/controller/ptychi/opr.py new file mode 100644 index 00000000..ca1a820d --- /dev/null +++ b/src/ptychodus/controller/ptychi/opr.py @@ -0,0 +1,113 @@ +from PyQt5.QtWidgets import QFormLayout + +from ptychodus.api.parametric import BooleanParameter, IntegerParameter, StringParameter + +from ...model.ptychi import PtyChiEnumerators, PtyChiOPRSettings +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + SpinBoxParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + + +class PtyChiSmoothOPRModeWeightsViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + smooth_mode_weights: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + smoothing_method: StringParameter, + polynomial_smoothing_degree: IntegerParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + smooth_mode_weights, + 'Smooth OPR Mode Weights', + tool_tip='Smooth the OPR mode weights', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._smoothing_method_view_controller = ComboBoxParameterViewController( + smoothing_method, + enumerators.opr_weight_smoothing_methods(), + tool_tip='Method for smoothing OPR mode weights', + ) + self._polynomial_smoothing_degree_view_controller = SpinBoxParameterViewController( + polynomial_smoothing_degree, + tool_tip='Degree of the polynomial used for smoothing OPR mode weights', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Smoothing Method:', self._smoothing_method_view_controller.get_widget()) + layout.addRow( + 'Polynomial Degree:', self._polynomial_smoothing_degree_view_controller.get_widget() + ) + self.get_widget().setLayout(layout) + + +class PtyChiOPRViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiOPRSettings, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.is_optimizable, + 'Orthogonal Probe Relaxation', + tool_tip='Whether OPR modes are optimizable', + ) + self._optimization_plan_view_controller = PtyChiOptimizationPlanViewController( + settings.optimization_plan_start, + settings.optimization_plan_stop, + settings.optimization_plan_stride, + num_epochs, + ) + self._optimizer_view_controller = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._step_size_view_controller = DecimalLineEditParameterViewController( + settings.step_size, tool_tip='Optimizer step size' + ) + self._optimize_eigenmode_weights_view_controller = CheckBoxParameterViewController( + settings.optimize_eigenmode_weights, + 'Optimize Eigenmode Weights', + tool_tip='Whether to optimize eigenmode weights (i.e., the weights of the second and following OPR modes)', + ) + self._optimize_intensities_view_controller = CheckBoxParameterViewController( + settings.optimize_intensities, + 'Optimize Intensities', + tool_tip='Whether to optimize intensity variation (i.e., the weight of the first OPR mode)', + ) + self._smooth_mode_weights_view_controller = PtyChiSmoothOPRModeWeightsViewController( + settings.smooth_mode_weights, + settings.smooth_mode_weights_start, + settings.smooth_mode_weights_stop, + settings.smooth_mode_weights_stride, + settings.smoothing_method, + settings.polynomial_smoothing_degree, + num_epochs, + enumerators, + ) + self._relax_update_view_controller = DecimalSliderParameterViewController( + settings.relax_update, + tool_tip='Whether to relax the update of the OPR mode weights', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimization_plan_view_controller.get_widget()) + layout.addRow('Optimizer:', self._optimizer_view_controller.get_widget()) + layout.addRow('Step Size:', self._step_size_view_controller.get_widget()) + layout.addRow(self._optimize_intensities_view_controller.get_widget()) + layout.addRow(self._optimize_eigenmode_weights_view_controller.get_widget()) + layout.addRow(self._smooth_mode_weights_view_controller.get_widget()) + layout.addRow('Relax Update:', self._relax_update_view_controller.get_widget()) + self.get_widget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/optimizer.py b/src/ptychodus/controller/ptychi/optimizer.py new file mode 100644 index 00000000..9e97faff --- /dev/null +++ b/src/ptychodus/controller/ptychi/optimizer.py @@ -0,0 +1,93 @@ +from PyQt5.QtWidgets import QHBoxLayout, QLabel, QSpinBox, QWidget + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import IntegerParameter, StringParameter + +from ...model.ptychi import PtyChiEnumerators +from ..parametric import ( + ComboBoxParameterViewController, + SpinBoxParameterViewController, + ParameterViewController, +) + +__all__ = [ + 'PtyChiOptimizationPlanViewController', + 'PtyChiOptimizerParameterViewController', +] + + +class PtyChiStopSpinBoxParameterViewController(ParameterViewController, Observer): + def __init__( + self, stop: IntegerParameter, num_epochs: IntegerParameter, *, tool_tip: str = '' + ) -> None: + super().__init__() + self._stop = stop + self._num_epochs = num_epochs + self._widget = QSpinBox() + + if tool_tip: + self._widget.setToolTip(tool_tip) + + self._sync_model_to_view() + self._widget.valueChanged.connect(self._sync_view_to_model) + stop.add_observer(self) + num_epochs.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_view_to_model(self, value: int) -> None: + num_epochs = self._num_epochs.get_value() + self._stop.set_value(value if value < num_epochs else -1) + + def _sync_model_to_view(self) -> None: + num_epochs = self._num_epochs.get_value() + stop = self._stop.get_value() + + self._widget.blockSignals(True) + self._widget.setRange(0, num_epochs) + self._widget.setValue(num_epochs if stop < 0 else stop) + self._widget.blockSignals(False) + + def _update(self, observable: Observable) -> None: + if observable in (self._stop, self._num_epochs): + self._sync_model_to_view() + + +class PtyChiOptimizationPlanViewController(ParameterViewController): + def __init__( + self, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__() + self._start_view_controller = SpinBoxParameterViewController( + start, tool_tip='Iteration to start optimizing' + ) + self._stop_view_controller = PtyChiStopSpinBoxParameterViewController( + stop, num_epochs, tool_tip='Iteration to stop optimizing' + ) + self._stride_view_controller = SpinBoxParameterViewController( + stride, tool_tip='Number of iterations between updates' + ) + self._widget = QWidget() + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(QLabel('Start'), 0) + layout.addWidget(self._start_view_controller.get_widget(), 1) + layout.addWidget(QLabel('Stop'), 0) + layout.addWidget(self._stop_view_controller.get_widget(), 1) + layout.addWidget(QLabel('Stride'), 0) + layout.addWidget(self._stride_view_controller.get_widget(), 1) + self._widget.setLayout(layout) + + def get_widget(self) -> QWidget: + return self._widget + + +class PtyChiOptimizerParameterViewController(ComboBoxParameterViewController): + def __init__(self, parameter: StringParameter, enumerators: PtyChiEnumerators) -> None: + super().__init__(parameter, enumerators.optimizers(), tool_tip='Name of the optimizer.') diff --git a/src/ptychodus/controller/ptychi/positions.py b/src/ptychodus/controller/ptychi/positions.py new file mode 100644 index 00000000..bd3b2109 --- /dev/null +++ b/src/ptychodus/controller/ptychi/positions.py @@ -0,0 +1,296 @@ +from typing import Any + +from PyQt5.QtCore import Qt, QAbstractListModel, QModelIndex, QObject +from PyQt5.QtWidgets import QFormLayout, QFrame, QListView, QWidget + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import ( + BooleanParameter, + IntegerParameter, + RealParameter, + StringParameter, +) + +from ...model.ptychi import ( + PtyChiAffineDegreesOfFreedomBitField, + PtyChiEnumerators, + PtyChiProbePositionSettings, +) +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + ParameterViewController, + SpinBoxParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + +__all__ = ['PtyChiProbePositionsViewController'] + + +class PtyChiCrossCorrelationViewController(ParameterViewController, Observer): + def __init__( + self, + algorithm: StringParameter, + scale: IntegerParameter, + real_space_width: RealParameter, + probe_threshold: RealParameter, + ) -> None: + super().__init__() + self._algorithm = algorithm + self._scale_view_controller = SpinBoxParameterViewController( + scale, tool_tip='Upsampling factor of the cross-correlation in real space' + ) + self._real_space_width_view_controller = DecimalLineEditParameterViewController( + real_space_width, tool_tip='Width of the cross-correlation in real-space' + ) + self._probe_threshold_view_controller = DecimalSliderParameterViewController( + probe_threshold, tool_tip='Probe intensity threshold used to calculate the probe mask' + ) + self._widget = QFrame() + self._widget.setFrameShape(QFrame.StyledPanel) + + layout = QFormLayout() + layout.addRow('Scale:', self._scale_view_controller.get_widget()) + layout.addRow('Real Space Width:', self._real_space_width_view_controller.get_widget()) + layout.addRow('Probe Threshold:', self._probe_threshold_view_controller.get_widget()) + self._widget.setLayout(layout) + + algorithm.add_observer(self) + self._sync_model_to_view() + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_model_to_view(self) -> None: + self._widget.setVisible(self._algorithm.get_value().upper() == 'CROSS_CORRELATION') + + def _update(self, observable: Observable) -> None: + if observable is self._algorithm: + self._sync_model_to_view() + + +class PtyChiUpdateMagnitudeLimitViewController(ParameterViewController, Observer): + def __init__( + self, + limit_update_magnitude: BooleanParameter, + update_magnitude_limit: RealParameter, + ) -> None: + self._limit_update_magnitude = limit_update_magnitude + self._limit_update_magnitude_view_controller = CheckBoxParameterViewController( + limit_update_magnitude, + 'Limit Update Magnitude:', + tool_tip='Whether to limit the update magnitude', + ) + self._update_magnitude_limit_view_controller = DecimalLineEditParameterViewController( + update_magnitude_limit, + tool_tip='Maximum allowed magnitude of position update in each axis', + ) + + limit_update_magnitude.add_observer(self) + self._sync_model_to_view() + + def get_label(self) -> QWidget: + return self._limit_update_magnitude_view_controller.get_widget() + + def get_widget(self) -> QWidget: + return self._update_magnitude_limit_view_controller.get_widget() + + def _sync_model_to_view(self) -> None: + self._update_magnitude_limit_view_controller.get_widget().setEnabled( + self._limit_update_magnitude.get_value() + ) + + def _update(self, observable: Observable) -> None: + if observable is self._limit_update_magnitude: + self._sync_model_to_view() + + +class PtyChiAffineDegreesOfFreedomListModel(QAbstractListModel): + def __init__(self, parameter: IntegerParameter, parent: QObject | None = None) -> None: + super().__init__(parent) + self._dof = PtyChiAffineDegreesOfFreedomBitField(parameter) + + def flags(self, index: QModelIndex) -> Qt.ItemFlags: + value = super().flags(index) + + if index.isValid(): + value |= Qt.ItemFlag.ItemIsUserCheckable + + return value + + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: + if index.isValid(): + if role == Qt.ItemDataRole.DisplayRole: + return self._dof[index.row()] + elif role == Qt.ItemDataRole.CheckStateRole: + return ( + Qt.CheckState.Checked + if self._dof.is_bit_set(index.row()) + else Qt.CheckState.Unchecked + ) + + def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: # noqa: N802 + if index.isValid() and role == Qt.ItemDataRole.CheckStateRole: + self._dof.set_bit(index.row(), value == Qt.CheckState.Checked) + self.dataChanged.emit(index, index) + return True + + return False + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + return len(self._dof) + + +class PtyChiAffineDegreesOfFreedomViewController(ParameterViewController, Observer): + def __init__(self, parameter: IntegerParameter) -> None: + super().__init__() + self._parameter = parameter + self._list_model = PtyChiAffineDegreesOfFreedomListModel(parameter) + self._widget = QListView() + self._widget.setModel(self._list_model) + + parameter.add_observer(self) + self._sync_model_to_view() + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_model_to_view(self) -> None: + self._list_model.beginResetModel() + self._list_model.endResetModel() + + def _update(self, observable: Observable) -> None: + if observable is self._parameter: + self._sync_model_to_view() + + +class PtyChiConstrainAffineTransformViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + is_optimizable: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + degrees_of_freedom: IntegerParameter, + position_weight_update_interval: IntegerParameter, + apply_constraint: BooleanParameter, + max_expected_error_px: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + is_optimizable, + 'Constrain Affine Transform', + tool_tip='Constrain the affine transform during position correction', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._degrees_of_freedom_view_controller = PtyChiAffineDegreesOfFreedomViewController( + degrees_of_freedom + ) + self._weight_update_interval_view_controller = SpinBoxParameterViewController( + position_weight_update_interval, + tool_tip='Interval for updating the position weight', + ) + self._apply_constraint_view_controller = CheckBoxParameterViewController( + apply_constraint, 'Apply Constraint', tool_tip='Whether to apply the constraint' + ) + self._max_expected_error_px_view_controller = DecimalLineEditParameterViewController( + max_expected_error_px, tool_tip='Maximum expected error in pixels' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Degrees of Freedom:', self._degrees_of_freedom_view_controller.get_widget()) + layout.addRow( + 'Weight Update Interval:', self._weight_update_interval_view_controller.get_widget() + ) + layout.addRow(self._apply_constraint_view_controller.get_widget()) + layout.addRow( + 'Max Expected Error [px]:', self._max_expected_error_px_view_controller.get_widget() + ) + self.get_widget().setLayout(layout) + + +class PtyChiProbePositionsViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiProbePositionSettings, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.is_optimizable, + 'Optimize Probe Positions', + tool_tip='Whether the probe positions are optimizable', + ) + self._optimization_plan_view_controller = PtyChiOptimizationPlanViewController( + settings.optimization_plan_start, + settings.optimization_plan_stop, + settings.optimization_plan_stride, + num_epochs, + ) + self._optimizer_view_controller = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._step_size_view_controller = DecimalLineEditParameterViewController( + settings.step_size, tool_tip='Optimizer step size' + ) + self._constrain_centroid_view_controller = CheckBoxParameterViewController( + settings.constrain_centroid, + 'Constrain Centroid', + tool_tip='Whether to subtract the mean from positions after updating positions', + ) + self._correction_type_view_controller = ComboBoxParameterViewController( + settings.correction_type, + enumerators.position_correction_types(), + tool_tip='Algorithm used to calculate the position correction update', + ) + self._differentiation_method_view_controller = ComboBoxParameterViewController( + settings.differentiation_method, + enumerators.image_gradient_methods(), + tool_tip='Method for calculating the object gradient', + ) + self._cross_correlation_view_controller = PtyChiCrossCorrelationViewController( + settings.correction_type, + settings.cross_correlation_scale, + settings.cross_correlation_real_space_width, + settings.cross_correlation_probe_threshold, + ) + self._update_magnitude_limit_view_controller = PtyChiUpdateMagnitudeLimitViewController( + settings.limit_update_magnitude, + settings.update_magnitude_limit, + ) + self._constrain_affine_transform_view_controller = ( + PtyChiConstrainAffineTransformViewController( + settings.constrain_affine_transform, + settings.constrain_affine_transform_start, + settings.constrain_affine_transform_stop, + settings.constrain_affine_transform_stride, + settings.constrain_affine_transform_degrees_of_freedom, + settings.constrain_affine_transform_position_weight_update_interval, + settings.constrain_affine_transform_apply_constraint, + settings.constrain_affine_transform_max_expected_error_px, + num_epochs, + ) + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimization_plan_view_controller.get_widget()) + layout.addRow('Optimizer:', self._optimizer_view_controller.get_widget()) + layout.addRow('Step Size:', self._step_size_view_controller.get_widget()) + layout.addRow(self._constrain_centroid_view_controller.get_widget()) + layout.addRow('Correction Type:', self._correction_type_view_controller.get_widget()) + layout.addRow( + 'Differentiation Method:', self._differentiation_method_view_controller.get_widget() + ) + layout.addRow(self._cross_correlation_view_controller.get_widget()) + layout.addRow( + self._update_magnitude_limit_view_controller.get_label(), + self._update_magnitude_limit_view_controller.get_widget(), + ) + layout.addRow(self._constrain_affine_transform_view_controller.get_widget()) + self.get_widget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/probe.py b/src/ptychodus/controller/ptychi/probe.py new file mode 100644 index 00000000..4f03aada --- /dev/null +++ b/src/ptychodus/controller/ptychi/probe.py @@ -0,0 +1,259 @@ +from PyQt5.QtWidgets import QFormLayout + +from ptychodus.api.parametric import ( + BooleanParameter, + IntegerParameter, + RealParameter, + StringParameter, +) + +from ...model.ptychi import ( + PtyChiDMSettings, + PtyChiEnumerators, + PtyChiLSQMLSettings, + PtyChiPIESettings, + PtyChiProbeSettings, +) +from ..parametric import ( + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + +__all__ = ['PtyChiProbeViewController'] + + +class PtyChiConstrainProbePowerViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrain_power: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrain_power, 'Constrain Power', tool_tip='Whether to constrain probe power' + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiOrthogonalizeIncoherentModesViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + orthogonalize_modes: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + method: StringParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + orthogonalize_modes, + 'Orthogonalize Incoherent Modes', + tool_tip='Whether to orthogonalize incoherent probe modes', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._method_view_controller = ComboBoxParameterViewController( + method, + enumerators.orthogonalization_methods(), + tool_tip='Method to use for incoherent mode orthogonalization', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Method:', self._method_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiOrthogonalizeOPRModesViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + orthogonalize_modes: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + orthogonalize_modes, + 'Orthogonalize OPR Modes', + tool_tip='Whether to orthogonalize OPR modes', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiConstrainSupportViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrain_support: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + threshold: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrain_support, + 'Constrain Support', + tool_tip='When enabled, the probe will be shrinkwrapped so that small values are set to zero', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._threshold_view_controller = DecimalLineEditParameterViewController( + threshold, tool_tip='Threshold for the probe support constraint' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + layout.addRow('Threshold:', self._threshold_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiConstrainCenterViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrain_center: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrain_center, + 'Constrain Center', + tool_tip='When enabled, the probe center of mass will be constrained to the center of the probe array', + ) + self._plan_view_controller = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._plan_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiProbeViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiProbeSettings, + dm_settings: PtyChiDMSettings | None, + lsqml_settings: PtyChiLSQMLSettings | None, + pie_settings: PtyChiPIESettings | None, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.is_optimizable, 'Optimize Probe', tool_tip='Whether the probe is optimizable' + ) + self._optimization_plan_view_controller = PtyChiOptimizationPlanViewController( + settings.optimization_plan_start, + settings.optimization_plan_stop, + settings.optimization_plan_stride, + num_epochs, + ) + self._optimizer_view_controller = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._step_size_view_controller = DecimalLineEditParameterViewController( + settings.step_size, tool_tip='Optimizer step size' + ) + self._constrain_probe_power_view_controller = PtyChiConstrainProbePowerViewController( + settings.constrain_probe_power, + settings.constrain_probe_power_start, + settings.constrain_probe_power_stop, + settings.constrain_probe_power_stride, + num_epochs, + ) + self._orthogonalize_incoherent_modes_view_controller = ( + PtyChiOrthogonalizeIncoherentModesViewController( + settings.orthogonalize_incoherent_modes, + settings.orthogonalize_incoherent_modes_start, + settings.orthogonalize_incoherent_modes_stop, + settings.orthogonalize_incoherent_modes_stride, + settings.orthogonalize_incoherent_modes_method, + num_epochs, + enumerators, + ) + ) + self._orthogonalize_opr_modes_view_controller = PtyChiOrthogonalizeOPRModesViewController( + settings.orthogonalize_opr_modes, + settings.orthogonalize_opr_modes_start, + settings.orthogonalize_opr_modes_stop, + settings.orthogonalize_opr_modes_stride, + num_epochs, + ) + self._constrain_support_view_controller = PtyChiConstrainSupportViewController( + settings.constrain_support, + settings.constrain_support_start, + settings.constrain_support_stop, + settings.constrain_support_stride, + settings.constrain_support_threshold, + num_epochs, + ) + self._constrain_center_view_controller = PtyChiConstrainCenterViewController( + settings.constrain_center, + settings.constrain_center_start, + settings.constrain_center_stop, + settings.constrain_center_stride, + num_epochs, + ) + self._relax_eigenmode_update_view_controller = DecimalSliderParameterViewController( + settings.relax_eigenmode_update, + tool_tip='Relaxation factor for the eigenmode update', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimization_plan_view_controller.get_widget()) + layout.addRow('Optimizer:', self._optimizer_view_controller.get_widget()) + layout.addRow('Step Size:', self._step_size_view_controller.get_widget()) + layout.addRow(self._constrain_probe_power_view_controller.get_widget()) + layout.addRow(self._orthogonalize_incoherent_modes_view_controller.get_widget()) + layout.addRow(self._orthogonalize_opr_modes_view_controller.get_widget()) + layout.addRow(self._constrain_support_view_controller.get_widget()) + layout.addRow(self._constrain_center_view_controller.get_widget()) + layout.addRow( + 'Relax Eigenmode Update:', self._relax_eigenmode_update_view_controller.get_widget() + ) + + if dm_settings is not None: + self._inertia_view_controller = DecimalLineEditParameterViewController( + dm_settings.probe_inertia, tool_tip='Inertia for the probe update' + ) + layout.addRow('Inertia:', self._inertia_view_controller.get_widget()) + + if lsqml_settings is not None: + self._optimal_step_size_scaler_view_controller = DecimalLineEditParameterViewController( + lsqml_settings.probe_optimal_step_size_scaler, + tool_tip='Optimal step size scaler for the probe update', + ) + layout.addRow( + 'Optimal Step Size Scaler:', + self._optimal_step_size_scaler_view_controller.get_widget(), + ) + + if pie_settings is not None: + self._alpha = DecimalSliderParameterViewController( + pie_settings.probe_alpha, tool_tip='Relaxation factor for the probe update' + ) + layout.addRow('Alpha:', self._alpha.get_widget()) + + self.get_widget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/reconstructor.py b/src/ptychodus/controller/ptychi/reconstructor.py new file mode 100644 index 00000000..766bf808 --- /dev/null +++ b/src/ptychodus/controller/ptychi/reconstructor.py @@ -0,0 +1,313 @@ +from PyQt5.QtWidgets import ( + QButtonGroup, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QRadioButton, + QVBoxLayout, + QWidget, +) + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import BooleanParameter, RealParameter + +from ...model.ptychi import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiDeviceRepository, + PtyChiEnumerators, + PtyChiLSQMLSettings, + PtyChiSettings, +) +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + ParameterViewController, + SpinBoxParameterViewController, +) + +__all__ = ['PtyChiReconstructorViewController'] + + +class PtyChiDeviceViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + use_devices: BooleanParameter, + repository: PtyChiDeviceRepository, + *, + tool_tip: str = '', + ) -> None: + super().__init__(use_devices, 'Use Devices', tool_tip=tool_tip) + layout = QVBoxLayout() + + for device in repository: + device_label = QLabel(device) + layout.addWidget(device_label) + + self.get_widget().setLayout(layout) + + +class PtyChiPrecisionParameterViewController(ParameterViewController, Observer): + def __init__(self, use_double_precision: BooleanParameter, *, tool_tip: str = '') -> None: + super().__init__() + self._use_double_precision = use_double_precision + self._single_precision_button = QRadioButton('Single') + self._double_precision_button = QRadioButton('Double') + self._button_group = QButtonGroup() + self._widget = QWidget() + + self._single_precision_button.setToolTip('Compute using single precision') + self._double_precision_button.setToolTip('Compute using double precision') + + if tool_tip: + self._widget.setToolTip(tool_tip) + + self._button_group.addButton(self._single_precision_button, 1) + self._button_group.addButton(self._double_precision_button, 2) + self._button_group.setExclusive(True) + self._button_group.idToggled.connect(self._sync_view_to_model) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._single_precision_button) + layout.addWidget(self._double_precision_button) + layout.addStretch() + self._widget.setLayout(layout) + + self._sync_model_to_view() + use_double_precision.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_view_to_model(self, tool_id: int, checked: bool) -> None: + if tool_id == 2: + self._use_double_precision.set_value(checked) + + def _sync_model_to_view(self) -> None: + button = self._button_group.button(2 if self._use_double_precision.get_value() else 1) + button.setChecked(True) + + def _update(self, observable: Observable) -> None: + if observable is self._use_double_precision: + self._sync_model_to_view() + + +class PtyChiMomentumAccelerationGradientMixingFactorViewController( + CheckableGroupBoxParameterViewController +): + def __init__( + self, + use_gradient_mixing_factor: BooleanParameter, + gradient_mixing_factor: RealParameter, + *, + tool_tip: str = '', + ) -> None: + super().__init__( + use_gradient_mixing_factor, + 'Use Gradient Mixing Factor', + tool_tip='Controls how the current gradient is mixed with the accumulated velocity in LSQML momentum acceleration', + ) + self._gradient_mixing_factor_view_controller = DecimalLineEditParameterViewController( + gradient_mixing_factor, tool_tip=tool_tip + ) + + layout = QVBoxLayout() + layout.addWidget(self._gradient_mixing_factor_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class PtyChiReconstructorViewController(ParameterViewController): + def __init__( + self, + settings: PtyChiSettings, + autodiff_settings: PtyChiAutodiffSettings | None, + dm_settings: PtyChiDMSettings | None, + lsqml_settings: PtyChiLSQMLSettings | None, + enumerators: PtyChiEnumerators, + repository: PtyChiDeviceRepository, + ) -> None: + super().__init__() + self._num_epochs_view_controller = SpinBoxParameterViewController( + settings.num_epochs, tool_tip='Number of epochs to run' + ) + self._batch_size_view_controller = SpinBoxParameterViewController( + settings.batch_size, tool_tip='Number of data to process in each minibatch' + ) + self._batching_mode_view_controller = ComboBoxParameterViewController( + settings.batching_mode, enumerators.batching_modes(), tool_tip='Batching mode to use' + ) + self._compact_mode_update_clustering_view_controller = SpinBoxParameterViewController( + settings.compact_mode_update_clustering, + tool_tip='When greater than zero, the number of epochs between updating clusters in compact batching mode', + ) + self._device_view_controller = PtyChiDeviceViewController( + settings.use_devices, repository, tool_tip='Default device to use for computation' + ) + self._compute_precision_view_controller = PtyChiPrecisionParameterViewController( + settings.use_double_precision, + tool_tip='Floating point precision to use for computation', + ) + self._fft_precision_view_controller = PtyChiPrecisionParameterViewController( + settings.use_double_precision_for_fft, + tool_tip='Floating point precision to use for critical FFT operations', + ) + self.allow_nondeterministic_algorithms_view_controller = CheckBoxParameterViewController( + settings.allow_nondeterministic_algorithms, + 'Allow Nondeterministic Algorithms', + tool_tip='When checked, nondeterministic algorithms will be used. This may lead to different results on different runs', + ) + + self._use_low_memory_view_controller = CheckBoxParameterViewController( + settings.use_low_memory_mode, + 'Use Low Memory Mode', + tool_tip='When checked, forward propagation of ptychography will be done using less vectorized code. This reduces the speed, but also lowers memory usage', + ) + self._pad_for_shift_view_controller = SpinBoxParameterViewController( + settings.pad_for_shift, + tool_tip='Number of pixels to pad arrays (with border values) before shifting', + ) + + self._use_far_field_propagation_view_controller = CheckBoxParameterViewController( + settings.use_far_field_propagation, + 'Use Far Field Propagation', + tool_tip='When checked, far field propagation will be used instead of near field propagation', + ) + self._fft_shift_diffraction_patterns_view_controller = CheckBoxParameterViewController( + settings.fft_shift_diffraction_patterns, + 'FFT Shift Diffraction Patterns', + tool_tip='When checked, the diffraction patterns will be FFT-shifted', + ) + self._save_data_on_device_view_controller = CheckBoxParameterViewController( + settings.save_data_on_device, + 'Save Data on Device', + tool_tip='When checked, diffraction data will be saved on the device', + ) + self._widget = QGroupBox('Reconstructor') + + layout = QFormLayout() + layout.addRow('Number of Epochs:', self._num_epochs_view_controller.get_widget()) + + if dm_settings is None: + layout.addRow('Batch Size:', self._batch_size_view_controller.get_widget()) + layout.addRow('Batch Mode:', self._batching_mode_view_controller.get_widget()) + layout.addRow( + 'Update Clustering:', + self._compact_mode_update_clustering_view_controller.get_widget(), + ) + + layout.addRow(self._device_view_controller.get_widget()) + layout.addRow('Compute Precision:', self._compute_precision_view_controller.get_widget()) + layout.addRow('FFT Precision:', self._fft_precision_view_controller.get_widget()) + layout.addRow(self.allow_nondeterministic_algorithms_view_controller.get_widget()) + + layout.addRow(self._use_low_memory_view_controller.get_widget()) + layout.addRow('Pad For Shift:', self._pad_for_shift_view_controller.get_widget()) + + layout.addRow(self._use_far_field_propagation_view_controller.get_widget()) + layout.addRow(self._fft_shift_diffraction_patterns_view_controller.get_widget()) + layout.addRow(self._save_data_on_device_view_controller.get_widget()) + + if autodiff_settings is not None: + self._loss_function_view_controller = ComboBoxParameterViewController( + autodiff_settings.loss_function, + enumerators.loss_functions(), + tool_tip='Loss function to optimize', + ) + layout.addRow('Loss Function:', self._loss_function_view_controller.get_widget()) + + self._forward_model_class_view_controller = ComboBoxParameterViewController( + autodiff_settings.forward_model_class, + enumerators.forward_models(), + tool_tip='Forward model class', + ) + layout.addRow('Forward Model:', self._forward_model_class_view_controller.get_widget()) + + if dm_settings is not None: + self._exit_wave_update_relaxation_view_controller = ( + DecimalSliderParameterViewController( + dm_settings.exit_wave_update_relaxation, + tool_tip='Relaxation multiplier for the exit wave update', + ) + ) + layout.addRow( + 'Exit Wave Update Relaxation:', + self._exit_wave_update_relaxation_view_controller.get_widget(), + ) + + self._chunk_length_view_controller = SpinBoxParameterViewController( + dm_settings.chunk_length, + tool_tip='Number of scan positions used in each chunk of the exit wave update loop', + ) + layout.addRow('Chunk Length:', self._chunk_length_view_controller.get_widget()) + + if lsqml_settings is not None: + self._noise_model_view_controller = ComboBoxParameterViewController( + lsqml_settings.noise_model, + enumerators.noise_models(), + tool_tip='Noise model to use', + ) + layout.addRow('Noise Model:', self._noise_model_view_controller.get_widget()) + + self._gaussian_noise_deviation_view_controller = DecimalLineEditParameterViewController( + lsqml_settings.gaussian_noise_deviation, + tool_tip='Standard deviation of the Gaussian noise', + ) + layout.addRow( + 'Gaussian Noise Deviation:', + self._gaussian_noise_deviation_view_controller.get_widget(), + ) + + self._solve_object_probe_step_size_jointly_for_first_slice_in_multislice_view_controller = CheckBoxParameterViewController( + lsqml_settings.solve_object_probe_step_size_jointly_for_first_slice_in_multislice, + 'Solve Object Probe Step Size Jointly For First Slice In Multislice', + tool_tip='When checked, the object and probe step length calculation will be solved simultaneously', + ) + layout.addRow( + self._solve_object_probe_step_size_jointly_for_first_slice_in_multislice_view_controller.get_widget() + ) + + self._solve_step_sizes_only_using_first_probe_mode_view_controller = CheckBoxParameterViewController( + lsqml_settings.solve_step_sizes_only_using_first_probe_mode, + 'Solve Step Sizes Only Using First Probe Mode', + tool_tip='When checked, the step sizes will be calculated using only the first probe mode', + ) + layout.addRow( + self._solve_step_sizes_only_using_first_probe_mode_view_controller.get_widget() + ) + + self._momentum_acceleration_gain_view_controller = ( + DecimalLineEditParameterViewController( + lsqml_settings.momentum_acceleration_gain, + tool_tip='Gain of momentum accleration', + ) + ) + layout.addRow( + 'Momentum Acceleration Gain:', + self._momentum_acceleration_gain_view_controller.get_widget(), + ) + + self._momentum_acceleration_gradient_mixing_factor_view_controller = PtyChiMomentumAccelerationGradientMixingFactorViewController( + lsqml_settings.use_momentum_acceleration_gradient_mixing_factor, + lsqml_settings.momentum_acceleration_gradient_mixing_factor, + tool_tip='Controls how the current gradient is mixed with the accumulated velocity in LSQML momentum acceleration', + ) + layout.addRow( + self._momentum_acceleration_gradient_mixing_factor_view_controller.get_widget() + ) + + self._rescale_probe_intensity_in_first_epoch_view_controller = CheckBoxParameterViewController( + lsqml_settings.rescale_probe_intensity_in_first_epoch, + 'Rescale Probe Intensity In First Epoch', + tool_tip='When checked, the probe intensity will be rescaled on the first epoch', + ) + layout.addRow(self._rescale_probe_intensity_in_first_epoch_view_controller.get_widget()) + + self._widget.setLayout(layout) + + def get_widget(self) -> QWidget: + return self._widget diff --git a/src/ptychodus/controller/ptychonn.py b/src/ptychodus/controller/ptychonn.py new file mode 100644 index 00000000..0c5ca787 --- /dev/null +++ b/src/ptychodus/controller/ptychonn.py @@ -0,0 +1,51 @@ +from PyQt5.QtWidgets import QWidget + +from ..model.ptychonn import PtychoNNReconstructorLibrary +from .parametric import ParameterViewBuilder +from .reconstructor import ReconstructorViewControllerFactory + + +class PtychoNNViewControllerFactory(ReconstructorViewControllerFactory): + def __init__(self, model: PtychoNNReconstructorLibrary) -> None: + super().__init__() + self._model = model + + @property + def backend_name(self) -> str: + return 'PtychoNN' + + def create_view_controller(self, reconstructor_name: str) -> QWidget: + view_builder = ParameterViewBuilder() + + model_settings = self._model.model_settings + model_group = 'Model Parameters' + view_builder.add_spin_box( + model_settings.num_convolution_kernels, 'Convolution Kernels:', group=model_group + ) + view_builder.add_spin_box(model_settings.batch_size, 'Batch Size:', group=model_group) + view_builder.add_check_box( + model_settings.use_batch_normalization, 'Use Batch Normalization', group=model_group + ) + + training_settings = self._model.training_settings + training_group = 'Training Parameters' + + view_builder.add_decimal_slider( + training_settings.validation_set_fractional_size, + 'Validation Set Fractional Size:', + group=training_group, + ) + view_builder.add_decimal_line_edit( + training_settings.max_learning_rate, 'Max Learning Rate:', group=training_group + ) + view_builder.add_decimal_line_edit( + training_settings.min_learning_rate, 'Min Learning Rate:', group=training_group + ) + view_builder.add_spin_box( + training_settings.training_epochs, 'Training Epochs:', group=training_group + ) + view_builder.add_spin_box( + training_settings.status_interval_in_epochs, 'Status Interval:', group=training_group + ) + + return view_builder.build_widget() diff --git a/src/ptychodus/controller/ptychonn/__init__.py b/src/ptychodus/controller/ptychonn/__init__.py deleted file mode 100644 index 5338b35f..00000000 --- a/src/ptychodus/controller/ptychonn/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .factory import PtychoNNViewControllerFactory - -__all__ = [ - 'PtychoNNViewControllerFactory', -] diff --git a/src/ptychodus/controller/ptychonn/controller.py b/src/ptychodus/controller/ptychonn/controller.py deleted file mode 100644 index 47235840..00000000 --- a/src/ptychodus/controller/ptychonn/controller.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from ...model.ptychonn import PtychoNNModelPresenter, PtychoNNTrainingPresenter -from ...view.ptychonn import PtychoNNParametersView -from ..data import FileDialogFactory -from .model import PtychoNNModelParametersController -from .training import PtychoNNTrainingParametersController - - -class PtychoNNParametersController: - def __init__( - self, - modelPresenter: PtychoNNModelPresenter, - trainingPresenter: PtychoNNTrainingPresenter, - view: PtychoNNParametersView, - fileDialogFactory: FileDialogFactory, - ) -> None: - super().__init__() - self._modelParametersController = PtychoNNModelParametersController.createInstance( - modelPresenter, view.modelParametersView, fileDialogFactory - ) - self._trainingParametersController = PtychoNNTrainingParametersController.createInstance( - trainingPresenter, view.trainingParametersView, fileDialogFactory - ) - - @classmethod - def createInstance( - cls, - modelPresenter: PtychoNNModelPresenter, - trainingPresenter: PtychoNNTrainingPresenter, - view: PtychoNNParametersView, - fileDialogFactory: FileDialogFactory, - ) -> PtychoNNParametersController: - return cls(modelPresenter, trainingPresenter, view, fileDialogFactory) diff --git a/src/ptychodus/controller/ptychonn/factory.py b/src/ptychodus/controller/ptychonn/factory.py deleted file mode 100644 index 87c0358e..00000000 --- a/src/ptychodus/controller/ptychonn/factory.py +++ /dev/null @@ -1,35 +0,0 @@ -from ..reconstructor import ReconstructorViewControllerFactory - -from PyQt5.QtWidgets import QWidget - -from ...model.ptychonn import PtychoNNReconstructorLibrary -from ...view.ptychonn import PtychoNNParametersView -from ..data import FileDialogFactory -from .controller import PtychoNNParametersController - - -class PtychoNNViewControllerFactory(ReconstructorViewControllerFactory): - def __init__( - self, model: PtychoNNReconstructorLibrary, fileDialogFactory: FileDialogFactory - ) -> None: - super().__init__() - self._model = model - self._fileDialogFactory = fileDialogFactory - self._controllerList: list[PtychoNNParametersController] = list() - - @property - def backendName(self) -> str: - return 'PtychoNN' - - def createViewController(self, reconstructorName: str) -> QWidget: - view = PtychoNNParametersView.createInstance() - - controller = PtychoNNParametersController.createInstance( - self._model.modelPresenter, - self._model.trainingPresenter, - view, - self._fileDialogFactory, - ) - self._controllerList.append(controller) - - return view diff --git a/src/ptychodus/controller/ptychonn/model.py b/src/ptychodus/controller/ptychonn/model.py deleted file mode 100644 index d2fe704a..00000000 --- a/src/ptychodus/controller/ptychonn/model.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from ptychodus.api.observer import Observable, Observer - -from ...model.ptychonn import PtychoNNModelPresenter -from ...view.ptychonn import PtychoNNModelParametersView -from ..data import FileDialogFactory - - -class PtychoNNModelParametersController(Observer): - def __init__( - self, - presenter: PtychoNNModelPresenter, - view: PtychoNNModelParametersView, - fileDialogFactory: FileDialogFactory, - ) -> None: - super().__init__() - self._presenter = presenter - self._view = view - self._fileDialogFactory = fileDialogFactory - - @classmethod - def createInstance( - cls, - presenter: PtychoNNModelPresenter, - view: PtychoNNModelParametersView, - fileDialogFactory: FileDialogFactory, - ) -> PtychoNNModelParametersController: - controller = cls(presenter, view, fileDialogFactory) - presenter.addObserver(controller) - - view.numberOfConvolutionKernelsSpinBox.valueChanged.connect( - presenter.setNumberOfConvolutionKernels - ) - view.batchSizeSpinBox.valueChanged.connect(presenter.setBatchSize) - view.useBatchNormalizationCheckBox.toggled.connect(presenter.setBatchNormalizationEnabled) - - controller._syncModelToView() - - return controller - - def _syncModelToView(self) -> None: - self._view.numberOfConvolutionKernelsSpinBox.blockSignals(True) - self._view.numberOfConvolutionKernelsSpinBox.setRange( - self._presenter.getNumberOfConvolutionKernelsLimits().lower, - self._presenter.getNumberOfConvolutionKernelsLimits().upper, - ) - self._view.numberOfConvolutionKernelsSpinBox.setValue( - self._presenter.getNumberOfConvolutionKernels() - ) - self._view.numberOfConvolutionKernelsSpinBox.blockSignals(False) - - self._view.batchSizeSpinBox.blockSignals(True) - self._view.batchSizeSpinBox.setRange( - self._presenter.getBatchSizeLimits().lower, - self._presenter.getBatchSizeLimits().upper, - ) - self._view.batchSizeSpinBox.setValue(self._presenter.getBatchSize()) - self._view.batchSizeSpinBox.blockSignals(False) - - self._view.useBatchNormalizationCheckBox.setChecked( - self._presenter.isBatchNormalizationEnabled() - ) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() diff --git a/src/ptychodus/controller/ptychonn/training.py b/src/ptychodus/controller/ptychonn/training.py deleted file mode 100644 index 33a011dc..00000000 --- a/src/ptychodus/controller/ptychonn/training.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations -from decimal import Decimal -import logging - -from ptychodus.api.observer import Observable, Observer - -from ...model.ptychonn import PtychoNNTrainingPresenter -from ...view.ptychonn import PtychoNNTrainingParametersView -from ..data import FileDialogFactory - -logger = logging.getLogger(__name__) - - -class PtychoNNTrainingParametersController(Observer): - def __init__( - self, - presenter: PtychoNNTrainingPresenter, - view: PtychoNNTrainingParametersView, - fileDialogFactory: FileDialogFactory, - ) -> None: - super().__init__() - self._presenter = presenter - self._view = view - - @classmethod - def createInstance( - cls, - presenter: PtychoNNTrainingPresenter, - view: PtychoNNTrainingParametersView, - fileDialogFactory: FileDialogFactory, - ) -> PtychoNNTrainingParametersController: - controller = cls(presenter, view, fileDialogFactory) - presenter.addObserver(controller) - - view.validationSetFractionalSizeSlider.valueChanged.connect( - presenter.setValidationSetFractionalSize - ) - view.maximumLearningRateLineEdit.valueChanged.connect(presenter.setMaximumLearningRate) - view.minimumLearningRateLineEdit.valueChanged.connect(presenter.setMinimumLearningRate) - view.trainingEpochsSpinBox.valueChanged.connect(presenter.setTrainingEpochs) - view.statusIntervalSpinBox.valueChanged.connect(presenter.setStatusIntervalInEpochs) - - controller._syncModelToView() - - return controller - - def _syncModelToView(self) -> None: - self._view.validationSetFractionalSizeSlider.setValueAndRange( - self._presenter.getValidationSetFractionalSize(), - self._presenter.getValidationSetFractionalSizeLimits(), - blockValueChangedSignal=True, - ) - - self._view.maximumLearningRateLineEdit.setMinimum(Decimal()) - self._view.maximumLearningRateLineEdit.setValue(self._presenter.getMaximumLearningRate()) - - self._view.minimumLearningRateLineEdit.setMinimum(Decimal()) - self._view.minimumLearningRateLineEdit.setValue(self._presenter.getMinimumLearningRate()) - - self._view.trainingEpochsSpinBox.blockSignals(True) - self._view.trainingEpochsSpinBox.setRange( - self._presenter.getTrainingEpochsLimits().lower, - self._presenter.getTrainingEpochsLimits().upper, - ) - self._view.trainingEpochsSpinBox.setValue(self._presenter.getTrainingEpochs()) - self._view.trainingEpochsSpinBox.blockSignals(False) - - self._view.statusIntervalSpinBox.blockSignals(True) - self._view.statusIntervalSpinBox.setRange( - self._presenter.getStatusIntervalInEpochsLimits().lower, - self._presenter.getStatusIntervalInEpochsLimits().upper, - ) - self._view.statusIntervalSpinBox.setValue(self._presenter.getStatusIntervalInEpochs()) - self._view.statusIntervalSpinBox.blockSignals(False) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() diff --git a/src/ptychodus/controller/ptychopinn.py b/src/ptychodus/controller/ptychopinn.py new file mode 100644 index 00000000..223146cf --- /dev/null +++ b/src/ptychodus/controller/ptychopinn.py @@ -0,0 +1,217 @@ +from PyQt5.QtGui import QValidator +from PyQt5.QtWidgets import QSpinBox, QWidget + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import IntegerParameter + +from ..model.ptychopinn.core import PtychoPINNReconstructorLibrary +from .data import FileDialogFactory +from .parametric import ParameterViewBuilder, ParameterViewController +from .reconstructor import ReconstructorViewControllerFactory + + +class PowerTwoSpinBox(QSpinBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + + def stepBy(self, steps: int) -> None: # noqa: N802 + if steps < 0: + self.setValue(self.value() // (1 << -steps)) + elif steps > 0: + self.setValue(self.value() * (1 << steps)) + + def validate(self, input: str, pos: int) -> tuple[QValidator.State, str, int]: + try: + value = int(input) + except ValueError: + pass + else: + if value > 0: + is_pow2 = (value & (value - 1)) == 0 + + if is_pow2: + return QValidator.Acceptable, input, pos + + return QValidator.Intermediate, input, pos + + +class PowerTwoSpinBoxParameterViewController(ParameterViewController, Observer): + def __init__(self, parameter: IntegerParameter, *, tool_tip: str = '') -> None: + super().__init__() + self._parameter = parameter + self._widget = PowerTwoSpinBox() + + if tool_tip: + self._widget.setToolTip(tool_tip) + + self._sync_model_to_view() + self._widget.valueChanged.connect(parameter.set_value) + parameter.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_model_to_view(self) -> None: + minimum = self._parameter.get_minimum() + maximum = self._parameter.get_maximum() + + if minimum is None: + raise ValueError('Minimum not provided!') + + if maximum is None: + raise ValueError('Maximum not provided!') + + self._widget.blockSignals(True) + self._widget.setRange(minimum, maximum) + self._widget.setValue(self._parameter.get_value()) + self._widget.blockSignals(False) + + def _update(self, observable: Observable) -> None: + if observable is self._parameter: + self._sync_model_to_view() + + +class PtychoPINNViewControllerFactory(ReconstructorViewControllerFactory): + def __init__( + self, model: PtychoPINNReconstructorLibrary, file_dialog_factory: FileDialogFactory + ) -> None: + super().__init__() + self._model = model + self._file_dialog_factory = file_dialog_factory + + @property + def backend_name(self) -> str: + return 'PtychoPINN' + + def create_view_controller(self, reconstructor_name: str) -> QWidget: + is_pinn = reconstructor_name.lower() == 'pinn' + builder = ParameterViewBuilder(self._file_dialog_factory) + enumerators = self._model.enumerators + + model_group = 'Model' + model_settings = self._model.model_settings + builder.add_spin_box( + model_settings.n_filters_scale, + 'Num. Filters Scale Factor:', + tool_tip='Scale factor for number of filters', + group=model_group, + ) + + if is_pinn: + builder.add_spin_box( + model_settings.gridsize, + 'Grid Size:', + tool_tip='Controls number of images processed per solution region (e.g., gridsize=2 means 2^2=4 images at a time)', + group=model_group, + ) + builder.add_combo_box( + model_settings.amp_activation, + enumerators.get_amp_activations(), + 'Amplitude Activation Function:', + group=model_group, + ) + builder.add_check_box( + model_settings.object_big, + 'Object Big', + group=model_group, + tool_tip='Enables a separate real-space reconstruction for each input diffraction image ' + 'and an averaging / overlap constraint step. If False, no explicit averaging is performed ' + 'and the decoders return a single real space image instead of `gridsize**2` images. ' + 'Typically left True.', + ) + builder.add_check_box( + model_settings.probe_big, + 'Probe Big', + group=model_group, + tool_tip='If True, enables a low-resolution reconstruction of the outer region of the NxN real-space grid. ' + 'This technically violates the zero-padding / oversampling condition, ' + 'but may be needed if the probe illumination has wide tails. ' + 'Has no effect unless pad_object is True.', + ) + builder.add_check_box( + model_settings.probe_mask, + 'Probe Mask', + group=model_group, + tool_tip='Whether to apply circular mask to the probe function. ' + "If toggling this changes the reconstruction, it's likely that there are edge / real space truncation artifacts. " + 'Should be used with pad_object = False.', + ) + builder.add_check_box( + model_settings.pad_object, + 'Pad Object', + group=model_group, + tool_tip='Whether to reconstruct the full real space grid (False) or restrict to N/2 x N/2 (True). ' + 'True strictly enforces the necessary reciprocal space oversampling, ' + 'but may cause truncation issues for probe amplitudes with long tails. ' + 'This truncation can be mitigated by setting probe_big, ' + 'which uses a small number of CNN filters to generate a low-resolution reconstruction of the outer region. ' + 'Typically left True.', + ) + builder.add_decimal_line_edit( + model_settings.probe_scale, + 'Probe Scale Factor:', + group=model_group, + tool_tip='Scaling factor for the probe amplitude. ', + ) + builder.add_decimal_line_edit( + model_settings.gaussian_smoothing_sigma, + 'Gaussian Smoothing Sigma:', + group=model_group, + tool_tip='Standard deviation for Gaussian smoothing of probe illumination. ' + 'Increase from 0 to reduce noise / artifacts at cost of resolution. ' + 'Beware that abusing this can cause convergence issues.', + ) + + inference_group = 'Inference' + inference_settings = self._model.inference_settings + builder.add_spin_box( + inference_settings.n_nearest_neighbors, 'Number of Neighbors:', group=inference_group + ) + builder.add_spin_box( + inference_settings.n_samples, 'Number of Samples:', group=inference_group + ) + + training_group = 'Training' + training_settings = self._model.training_settings + builder.add_view_controller( + PowerTwoSpinBoxParameterViewController(training_settings.batch_size), + 'Batch Size:', + group=training_group, + ) + builder.add_spin_box(training_settings.nepochs, 'Number of Epochs:', group=training_group) + + if is_pinn: + builder.add_decimal_slider( + training_settings.mae_weight, 'Weight for MAE loss:', group=training_group + ) + builder.add_decimal_slider( + training_settings.nll_weight, 'Weight for NLL loss:', group=training_group + ) + builder.add_decimal_slider( + training_settings.realspace_mae_weight, + 'Realspace MAE Weight:', + group=training_group, + ) + builder.add_decimal_slider( + training_settings.realspace_weight, 'Realspace Weight:', group=training_group + ) + builder.add_check_box( + training_settings.positions_provided, + 'Positions Provided', + group=training_group, + ) + builder.add_check_box( + training_settings.probe_trainable, + 'Probe Trainable', + group=training_group, + tool_tip='Optimizes the probe function during training. Experimental feature.', + ) + builder.add_check_box( + training_settings.intensity_scale_trainable, + 'Intensity Scale Trainable', + group=training_group, + tool_tip="Optimize the model's internal amplitude scaling factor during training. " + 'Typically left True.', + ) + + return builder.build_widget() diff --git a/src/ptychodus/controller/reconstructor.py b/src/ptychodus/controller/reconstructor.py index a9ca3d7c..9045d3c8 100644 --- a/src/ptychodus/controller/reconstructor.py +++ b/src/ptychodus/controller/reconstructor.py @@ -3,8 +3,8 @@ from collections.abc import Iterable, Sequence import logging -from PyQt5.QtCore import Qt, QAbstractItemModel -from PyQt5.QtWidgets import QLabel, QWidget +from PyQt5.QtCore import Qt, QAbstractItemModel, QTimer +from PyQt5.QtWidgets import QActionGroup, QLabel, QWidget from ptychodus.api.observer import Observable, Observer @@ -18,7 +18,7 @@ from ..model.product.probe import ProbeRepositoryItem from ..model.product.scan import ScanRepositoryItem from ..model.reconstructor import ReconstructorPresenter -from ..view.reconstructor import ReconstructorParametersView, ReconstructorPlotView +from ..view.reconstructor import ReconstructorView, ReconstructorPlotView from ..view.widgets import ExceptionDialog from .data import FileDialogFactory @@ -28,11 +28,11 @@ class ReconstructorViewControllerFactory(ABC): @property @abstractmethod - def backendName(self) -> str: + def backend_name(self) -> str: pass @abstractmethod - def createViewController(self, reconstructorName: str) -> QWidget: + def create_view_controller(self, reconstructor_name: str) -> QWidget: pass @@ -40,274 +40,262 @@ class ReconstructorController(ProductRepositoryObserver, Observer): def __init__( self, presenter: ReconstructorPresenter, - productRepository: ProductRepository, - view: ReconstructorParametersView, - plotView: ReconstructorPlotView, - fileDialogFactory: FileDialogFactory, - viewControllerFactoryList: Iterable[ReconstructorViewControllerFactory], + product_repository: ProductRepository, + view: ReconstructorView, + plot_view: ReconstructorPlotView, + product_table_model: QAbstractItemModel, + file_dialog_factory: FileDialogFactory, + view_controller_factories: Iterable[ReconstructorViewControllerFactory], ) -> None: super().__init__() self._presenter = presenter - self._productRepository = productRepository + self._product_repository = product_repository self._view = view - self._plotView = plotView - self._fileDialogFactory = fileDialogFactory - self._viewControllerFactoryDict: dict[str, ReconstructorViewControllerFactory] = { - vcf.backendName: vcf for vcf in viewControllerFactoryList + self._plot_view = plot_view + self._file_dialog_factory = file_dialog_factory + self._view_controller_factories: dict[str, ReconstructorViewControllerFactory] = { + vcf.backend_name: vcf for vcf in view_controller_factories } - @classmethod - def createInstance( - cls, - presenter: ReconstructorPresenter, - productRepository: ProductRepository, - view: ReconstructorParametersView, - plotView: ReconstructorPlotView, - fileDialogFactory: FileDialogFactory, - productTableModel: QAbstractItemModel, - viewControllerFactoryList: list[ReconstructorViewControllerFactory], - ) -> ReconstructorController: - controller = cls( - presenter, - productRepository, - view, - plotView, - fileDialogFactory, - viewControllerFactoryList, + for name in presenter.reconstructors(): + self._add_reconstructor(name) + + view.parameters_view.algorithm_combo_box.textActivated.connect(presenter.set_reconstructor) + view.parameters_view.algorithm_combo_box.currentIndexChanged.connect( + view.stacked_widget.setCurrentIndex ) - presenter.addObserver(controller) - productRepository.addObserver(controller) - for name in presenter.getReconstructorList(): - controller._addReconstructor(name) + view.parameters_view.product_combo_box.textActivated.connect(self._redraw_plot) + view.parameters_view.product_combo_box.setModel(product_table_model) - view.reconstructorView.algorithmComboBox.textActivated.connect(presenter.setReconstructor) - view.reconstructorView.algorithmComboBox.currentIndexChanged.connect( - view.stackedWidget.setCurrentIndex - ) + self._progress_timer = QTimer() + self._progress_timer.timeout.connect(self._update_progress) + self._progress_timer.start(5 * 1000) # TODO customize (in milliseconds) - view.reconstructorView.productComboBox.textActivated.connect(controller._redrawPlot) - view.reconstructorView.productComboBox.setModel(productTableModel) + view.progress_dialog.setModal(True) + view.progress_dialog.setWindowModality(Qt.ApplicationModal) + view.progress_dialog.setWindowFlags(Qt.Window | Qt.WindowTitleHint | Qt.CustomizeWindowHint) + view.progress_dialog.text_edit.setReadOnly(True) - openModelAction = view.reconstructorView.modelMenu.addAction('Open...') - openModelAction.triggered.connect(controller._openModel) - saveModelAction = view.reconstructorView.modelMenu.addAction('Save...') - saveModelAction.triggered.connect(controller._saveModel) + open_model_action = view.parameters_view.reconstructor_menu.addAction('Open Model...') + open_model_action.triggered.connect(self._open_model) + save_model_action = view.parameters_view.reconstructor_menu.addAction('Save Model...') + save_model_action.triggered.connect(self._save_model) - openTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Open Training Data...' - ) - openTrainingDataAction.triggered.connect(controller._openTrainingData) - saveTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Save Training Data...' - ) - saveTrainingDataAction.triggered.connect(controller._saveTrainingData) - ingestTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Ingest Training Data' + self._model_action_group = QActionGroup(view.parameters_view.reconstructor_menu) + self._model_action_group.setExclusive(False) + self._model_action_group.addAction(open_model_action) + self._model_action_group.addAction(save_model_action) + self._model_action_group.addAction(view.parameters_view.reconstructor_menu.addSeparator()) + + reconstruct_transformed_action = view.parameters_view.reconstructor_menu.addAction( + 'Reconstruct Transformed Points' ) - ingestTrainingDataAction.triggered.connect(controller._ingestTrainingData) - clearTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Clear Training Data' + reconstruct_transformed_action.triggered.connect(self._reconstruct_transformed) + reconstruct_split_action = view.parameters_view.reconstructor_menu.addAction( + 'Reconstruct Odd/Even Split' ) - clearTrainingDataAction.triggered.connect(controller._clearTrainingData) - view.reconstructorView.trainerMenu.addSeparator() - trainAction = view.reconstructorView.trainerMenu.addAction('Train') - trainAction.triggered.connect(controller._train) + reconstruct_split_action.triggered.connect(self._reconstruct_split) + reconstruct_action = view.parameters_view.reconstructor_menu.addAction('Reconstruct') + reconstruct_action.triggered.connect(self._reconstruct) - reconstructSplitAction = view.reconstructorView.reconstructorMenu.addAction( - 'Reconstruct Odd/Even Split' + export_training_data_action = view.parameters_view.trainer_menu.addAction( + 'Export Training Data...' ) - reconstructSplitAction.triggered.connect(controller._reconstructSplit) - reconstructAction = view.reconstructorView.reconstructorMenu.addAction('Reconstruct') - reconstructAction.triggered.connect(controller._reconstruct) + export_training_data_action.triggered.connect(self._export_training_data) + train_action = view.parameters_view.trainer_menu.addAction('Train') + train_action.triggered.connect(self._train) + + presenter.add_observer(self) + product_repository.add_observer(self) + self._sync_model_to_view() + + def _update_progress(self) -> None: + is_reconstructing = self._presenter.is_reconstructing + + for button in self._view.progress_dialog.button_box.buttons(): + button.setEnabled(not is_reconstructing) - controller._syncAlgorithmToView() + for text in self._presenter.flush_log(): + self._view.progress_dialog.text_edit.appendPlainText(text) - return controller + self._presenter.process_results(block=False) - def _addReconstructor(self, name: str) -> None: - backendName, reconstructorName = name.split('/') # TODO REDO - self._view.reconstructorView.algorithmComboBox.addItem( - name, self._view.reconstructorView.algorithmComboBox.count() + def _add_reconstructor(self, name: str) -> None: + backend_name, reconstructor_name = name.split('/') # TODO REDO + self._view.parameters_view.algorithm_combo_box.addItem( + name, self._view.parameters_view.algorithm_combo_box.count() ) - if backendName in self._viewControllerFactoryDict: - viewControllerFactory = self._viewControllerFactoryDict[backendName] - widget = viewControllerFactory.createViewController(reconstructorName) + if backend_name in self._view_controller_factories: + view_controller_factory = self._view_controller_factories[backend_name] + widget = view_controller_factory.create_view_controller(reconstructor_name) else: - widget = QLabel(f'{backendName} not found!') + widget = QLabel(f'{backend_name} not found!') widget.setAlignment(Qt.AlignmentFlag.AlignCenter) - self._view.stackedWidget.addWidget(widget) + self._view.stacked_widget.addWidget(widget) def _reconstruct(self) -> None: - outputProductName = self._presenter.getReconstructor() - inputProductIndex = self._view.reconstructorView.productComboBox.currentIndex() + input_product_index = self._view.parameters_view.product_combo_box.currentIndex() - if inputProductIndex < 0: + if input_product_index < 0: return try: - self._presenter.reconstruct(inputProductIndex, outputProductName) + output_product_index = self._presenter.reconstruct(input_product_index) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Reconstructor', err) + ExceptionDialog.show_exception('Reconstructor', err) + else: + self._view.parameters_view.product_combo_box.setCurrentIndex(output_product_index) + self._view.progress_dialog.show() - def _reconstructSplit(self) -> None: - outputProductName = self._presenter.getReconstructor() - inputProductIndex = self._view.reconstructorView.productComboBox.currentIndex() + def _reconstruct_split(self) -> None: + input_product_index = self._view.parameters_view.product_combo_box.currentIndex() - if inputProductIndex < 0: + if input_product_index < 0: return try: - self._presenter.reconstructSplit(inputProductIndex, outputProductName) + self._presenter.reconstruct_split(input_product_index) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Split Reconstructor', err) + ExceptionDialog.show_exception('Split Reconstructor', err) - def _openModel(self) -> None: - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( - self._view, - 'Open Model', - nameFilters=self._presenter.getOpenModelFileFilterList(), - selectedNameFilter=self._presenter.getOpenModelFileFilter(), + self._view.progress_dialog.show() + + def _reconstruct_transformed(self) -> None: + input_product_index = self._view.parameters_view.product_combo_box.currentIndex() + + if input_product_index < 0: + return + + try: + self._presenter.reconstruct_transformed(input_product_index) + except Exception as err: + logger.exception(err) + ExceptionDialog.show_exception('Split Reconstructor', err) + + self._view.progress_dialog.show() + + def _open_model(self) -> None: + name_filter = self._presenter.get_model_file_filter() + file_path, name_filter = self._file_dialog_factory.get_open_file_path( + self._view, 'Open Model', name_filters=[name_filter], selected_name_filter=name_filter ) - if filePath: + if file_path: try: - self._presenter.openModel(filePath) + self._presenter.open_model(file_path) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Model Reader', err) + ExceptionDialog.show_exception('Model Reader', err) - def _saveModel(self) -> None: - filePath, _ = self._fileDialogFactory.getSaveFilePath( - self._view, - 'Save Model', - nameFilters=self._presenter.getSaveModelFileFilterList(), - selectedNameFilter=self._presenter.getSaveModelFileFilter(), + def _save_model(self) -> None: + name_filter = self._presenter.get_model_file_filter() + file_path, _ = self._file_dialog_factory.get_save_file_path( + self._view, 'Save Model', name_filters=[name_filter], selected_name_filter=name_filter ) - if filePath: + if file_path: try: - self._presenter.saveModel(filePath) + self._presenter.save_model(file_path) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Model Writer', err) + ExceptionDialog.show_exception('Model Writer', err) + + def _export_training_data(self) -> None: + input_product_index = self._view.parameters_view.product_combo_box.currentIndex() + + if input_product_index < 0: + return - def _openTrainingData(self) -> None: - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( + name_filter = self._presenter.get_training_data_file_filter() + file_path, _ = self._file_dialog_factory.get_save_file_path( self._view, - 'Open Training Data', - nameFilters=self._presenter.getOpenTrainingDataFileFilterList(), - selectedNameFilter=self._presenter.getOpenTrainingDataFileFilter(), + 'Export Training Data', + name_filters=[name_filter], + selected_name_filter=name_filter, ) - if filePath: + if file_path: try: - self._presenter.openTrainingData(filePath) + self._presenter.export_training_data(file_path, input_product_index) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Training Data Reader', err) + ExceptionDialog.show_exception('Training Data Writer', err) - def _saveTrainingData(self) -> None: - filePath, _ = self._fileDialogFactory.getSaveFilePath( + def _train(self) -> None: + data_path = self._file_dialog_factory.get_existing_directory_path( self._view, - 'Save Training Data', - nameFilters=self._presenter.getSaveTrainingDataFileFilterList(), - selectedNameFilter=self._presenter.getSaveTrainingDataFileFilter(), + 'Choose Training Data Directory', + initial_directory=self._presenter.get_training_data_path(), ) - if filePath: + if data_path: try: - self._presenter.saveTrainingData(filePath) + self._presenter.train(data_path) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Training Data Writer', err) - - def _ingestTrainingData(self) -> None: - inputProductIndex = self._view.reconstructorView.productComboBox.currentIndex() - - if inputProductIndex < 0: - return - - try: - self._presenter.ingestTrainingData(inputProductIndex) - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Ingester', err) - - def _clearTrainingData(self) -> None: - try: - self._presenter.clearTrainingData() - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Clear', err) - - def _train(self) -> None: - try: - self._presenter.train() - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Trainer', err) + ExceptionDialog.show_exception('Trainer', err) - def _redrawPlot(self) -> None: - productIndex = self._view.reconstructorView.productComboBox.currentIndex() + def _redraw_plot(self) -> None: + product_index = self._view.parameters_view.product_combo_box.currentIndex() - if productIndex < 0: - self._plotView.axes.clear() + if product_index < 0: + self._plot_view.axes.clear() return try: - item = self._productRepository[productIndex] + item = self._product_repository[product_index] except IndexError as err: logger.exception(err) return - ax = self._plotView.axes + ax = self._plot_view.axes ax.clear() ax.set_xlabel('Iteration') ax.set_ylabel('Cost') ax.grid(True) - ax.plot(item.getCosts(), '.-', label='Cost', linewidth=1.5) - self._plotView.figureCanvas.draw() + ax.plot(item.get_costs(), '.-', label='Cost', linewidth=1.5) + self._plot_view.figure_canvas.draw() - def _syncAlgorithmToView(self) -> None: - self._view.reconstructorView.algorithmComboBox.setCurrentText( - self._presenter.getReconstructor() + def _sync_model_to_view(self) -> None: + self._view.parameters_view.algorithm_combo_box.setCurrentText( + self._presenter.get_reconstructor() ) - isTrainable = self._presenter.isTrainable - self._view.reconstructorView.modelButton.setVisible(isTrainable) - self._view.reconstructorView.trainerButton.setVisible(isTrainable) + is_trainable = self._presenter.is_trainable + self._model_action_group.setVisible(is_trainable) + self._view.parameters_view.trainer_button.setVisible(is_trainable) - self._redrawPlot() + self._redraw_plot() - def handleItemInserted(self, index: int, item: ProductRepositoryItem) -> None: + def handle_item_inserted(self, index: int, item: ProductRepositoryItem) -> None: pass - def handleMetadataChanged(self, index: int, item: MetadataRepositoryItem) -> None: + def handle_metadata_changed(self, index: int, item: MetadataRepositoryItem) -> None: pass - def handleScanChanged(self, index: int, item: ScanRepositoryItem) -> None: + def handle_scan_changed(self, index: int, item: ScanRepositoryItem) -> None: pass - def handleProbeChanged(self, index: int, item: ProbeRepositoryItem) -> None: + def handle_probe_changed(self, index: int, item: ProbeRepositoryItem) -> None: pass - def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: + def handle_object_changed(self, index: int, item: ObjectRepositoryItem) -> None: pass - def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: - currentIndex = self._view.reconstructorView.productComboBox.currentIndex() + def handle_costs_changed(self, index: int, costs: Sequence[float]) -> None: + current_index = self._view.parameters_view.product_combo_box.currentIndex() - if index == currentIndex: - self._redrawPlot() + if index == current_index: + self._redraw_plot() - def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: + def handle_item_removed(self, index: int, item: ProductRepositoryItem) -> None: pass - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._presenter: - self._syncAlgorithmToView() + self._sync_model_to_view() diff --git a/src/ptychodus/controller/scan/core.py b/src/ptychodus/controller/scan/core.py index cdee2d7e..3486b3c3 100644 --- a/src/ptychodus/controller/scan/core.py +++ b/src/ptychodus/controller/scan/core.py @@ -2,7 +2,7 @@ import logging from PyQt5.QtCore import QModelIndex, QSortFilterProxyModel, QStringListModel -from PyQt5.QtWidgets import QAbstractItemView, QDialog +from PyQt5.QtWidgets import QAbstractItemView, QDialog, QMessageBox from ptychodus.api.observer import SequenceObserver @@ -12,8 +12,8 @@ from ...view.scan import ScanPlotView from ...view.widgets import ComboBoxItemDelegate, ExceptionDialog from ..data import FileDialogFactory -from .editorFactory import ScanEditorViewControllerFactory -from .tableModel import ScanTableModel +from .editor_factory import ScanEditorViewControllerFactory +from .table_model import ScanTableModel logger = logging.getLogger(__name__) @@ -24,201 +24,202 @@ def __init__( repository: ScanRepository, api: ScanAPI, view: RepositoryTableView, - plotView: ScanPlotView, - fileDialogFactory: FileDialogFactory, - tableModel: ScanTableModel, - tableProxyModel: QSortFilterProxyModel, + plot_view: ScanPlotView, + file_dialog_factory: FileDialogFactory, + *, + is_developer_mode_enabled: bool, ) -> None: super().__init__() self._repository = repository self._api = api self._view = view - self._plotView = plotView - self._fileDialogFactory = fileDialogFactory - self._tableModel = tableModel - self._tableProxyModel = tableProxyModel - self._editorFactory = ScanEditorViewControllerFactory() - - @classmethod - def createInstance( - cls, - repository: ScanRepository, - api: ScanAPI, - view: RepositoryTableView, - plotView: ScanPlotView, - fileDialogFactory: FileDialogFactory, - ) -> ScanController: - tableModel = ScanTableModel(repository, api) - tableProxyModel = QSortFilterProxyModel() - tableProxyModel.setSourceModel(tableModel) - controller = cls( - repository, api, view, plotView, fileDialogFactory, tableModel, tableProxyModel + self._plot_view = plot_view + self._file_dialog_factory = file_dialog_factory + self._table_model = ScanTableModel(repository, api) + self._table_proxy_model = QSortFilterProxyModel() + self._editor_factory = ScanEditorViewControllerFactory() + + self._table_proxy_model.setSourceModel(self._table_model) + self._table_proxy_model.dataChanged.connect( + lambda top_left, bottom_right, roles: self._redraw_plot() ) - tableProxyModel.dataChanged.connect( - lambda topLeft, bottomRight, roles: controller._redrawPlot() - ) - repository.addObserver(controller) + repository.add_observer(self) - builderListModel = QStringListModel() - builderListModel.setStringList([name for name in api.builderNames()]) - builderItemDelegate = ComboBoxItemDelegate(builderListModel, view.tableView) + builder_list_model = QStringListModel() + builder_list_model.setStringList([name for name in api.builder_names()]) + builder_item_delegate = ComboBoxItemDelegate(builder_list_model, view.table_view) - view.tableView.setModel(tableProxyModel) - view.tableView.setSortingEnabled(True) - view.tableView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - view.tableView.setItemDelegateForColumn(2, builderItemDelegate) - view.tableView.selectionModel().currentChanged.connect(controller._updateView) - controller._updateView(QModelIndex(), QModelIndex()) + view.table_view.setModel(self._table_proxy_model) + view.table_view.setSortingEnabled(True) + view.table_view.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + view.table_view.setItemDelegateForColumn(2, builder_item_delegate) + view.table_view.selectionModel().currentChanged.connect(self._update_view) + self._update_view(QModelIndex(), QModelIndex()) - view.tableView.horizontalHeader().sectionClicked.connect( - lambda logicalIndex: controller._redrawPlot() + view.table_view.horizontalHeader().sectionClicked.connect( + lambda logical_index: self._redraw_plot() ) - loadFromFileAction = view.buttonBox.loadMenu.addAction('Open File...') - loadFromFileAction.triggered.connect(controller._loadCurrentScanFromFile) + load_from_file_action = view.button_box.load_menu.addAction('Open File...') + load_from_file_action.triggered.connect(self._load_current_scan_from_file) - copyAction = view.buttonBox.loadMenu.addAction('Copy...') - copyAction.triggered.connect(controller._copyToCurrentScan) + copy_action = view.button_box.load_menu.addAction('Copy...') + copy_action.triggered.connect(self._copy_to_current_scan) - saveToFileAction = view.buttonBox.saveMenu.addAction('Save File...') - saveToFileAction.triggered.connect(controller._saveCurrentScanToFile) + save_to_file_action = view.button_box.save_menu.addAction('Save File...') + save_to_file_action.triggered.connect(self._save_current_scan_to_file) - syncToSettingsAction = view.buttonBox.saveMenu.addAction('Sync To Settings') - syncToSettingsAction.triggered.connect(controller._syncCurrentScanToSettings) + sync_to_settings_action = view.button_box.save_menu.addAction('Sync To Settings') + sync_to_settings_action.triggered.connect(self._sync_current_scan_to_settings) - view.copierDialog.setWindowTitle('Copy Scan') - view.copierDialog.sourceComboBox.setModel(tableModel) - view.copierDialog.destinationComboBox.setModel(tableModel) - view.copierDialog.finished.connect(controller._finishCopyingScan) + view.copier_dialog.setWindowTitle('Copy Scan') + view.copier_dialog.source_combo_box.setModel(self._table_model) + view.copier_dialog.destination_combo_box.setModel(self._table_model) + view.copier_dialog.finished.connect(self._finish_copying_scan) - view.buttonBox.editButton.clicked.connect(controller._editCurrentScan) + view.button_box.edit_button.clicked.connect(self._edit_current_scan) - return controller + estimate_transform_action = view.button_box.analyze_menu.addAction('Estimate Transform...') + estimate_transform_action.triggered.connect(self._estimate_transform) + estimate_transform_action.setEnabled(is_developer_mode_enabled) - def _getCurrentItemIndex(self) -> int: - proxyIndex = self._view.tableView.currentIndex() + def _get_current_item_index(self) -> int: + proxy_index = self._view.table_view.currentIndex() - if proxyIndex.isValid(): - modelIndex = self._tableProxyModel.mapToSource(proxyIndex) - return modelIndex.row() + if proxy_index.isValid(): + model_index = self._table_proxy_model.mapToSource(proxy_index) + return model_index.row() logger.warning('No current index!') return -1 - def _loadCurrentScanFromFile(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _load_current_scan_from_file(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( + file_path, name_filter = self._file_dialog_factory.get_open_file_path( self._view, 'Open Scan', - nameFilters=self._api.getOpenFileFilterList(), - selectedNameFilter=self._api.getOpenFileFilter(), + name_filters=[nf for nf in self._api.get_open_file_filters()], + selected_name_filter=self._api.get_open_file_filter(), ) - if filePath: + if file_path: try: - self._api.openScan(itemIndex, filePath, fileType=nameFilter) + self._api.open_scan(item_index, file_path, file_type=name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Reader', err) + ExceptionDialog.show_exception('File Reader', err) - def _copyToCurrentScan(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _copy_to_current_scan(self) -> None: + item_index = self._get_current_item_index() - if itemIndex >= 0: - self._view.copierDialog.destinationComboBox.setCurrentIndex(itemIndex) - self._view.copierDialog.open() + if item_index >= 0: + self._view.copier_dialog.destination_combo_box.setCurrentIndex(item_index) + self._view.copier_dialog.open() - def _finishCopyingScan(self, result: int) -> None: + def _finish_copying_scan(self, result: int) -> None: if result == QDialog.DialogCode.Accepted: - sourceIndex = self._view.copierDialog.sourceComboBox.currentIndex() - destinationIndex = self._view.copierDialog.destinationComboBox.currentIndex() - self._api.copyScan(sourceIndex, destinationIndex) + source_index = self._view.copier_dialog.source_combo_box.currentIndex() + destination_index = self._view.copier_dialog.destination_combo_box.currentIndex() + self._api.copy_scan(source_index, destination_index) - def _editCurrentScan(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _edit_current_scan(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - itemName = self._repository.getName(itemIndex) - item = self._repository[itemIndex] - dialog = self._editorFactory.createEditorDialog(itemName, item, self._view) + item_name = self._repository.get_name(item_index) + item = self._repository[item_index] + dialog = self._editor_factory.create_editor_dialog(item_name, item, self._view) dialog.open() - def _saveCurrentScanToFile(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _save_current_scan_to_file(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: return - filePath, nameFilter = self._fileDialogFactory.getSaveFilePath( + file_path, name_filter = self._file_dialog_factory.get_save_file_path( self._view, 'Save Scan', - nameFilters=self._api.getSaveFileFilterList(), - selectedNameFilter=self._api.getSaveFileFilter(), + name_filters=[nameFilter for nameFilter in self._api.get_save_file_filters()], + selected_name_filter=self._api.get_save_file_filter(), ) - if filePath: + if file_path: try: - self._api.saveScan(itemIndex, filePath, nameFilter) + self._api.save_scan(item_index, file_path, name_filter) except Exception as err: logger.exception(err) - ExceptionDialog.showException('File Writer', err) + ExceptionDialog.show_exception('File Writer', err) - def _syncCurrentScanToSettings(self) -> None: - itemIndex = self._getCurrentItemIndex() + def _sync_current_scan_to_settings(self) -> None: + item_index = self._get_current_item_index() - if itemIndex < 0: + if item_index < 0: logger.warning('No current item!') else: - item = self._repository[itemIndex] - item.syncToSettings() + item = self._repository[item_index] + item.sync_to_settings() + + def _redraw_plot(self) -> None: + self._plot_view.axes.clear() + + for row in range(self._table_proxy_model.rowCount()): + proxy_index = self._table_proxy_model.index(row, 0) + item_index = self._table_proxy_model.mapToSource(proxy_index).row() - def _redrawPlot(self) -> None: - self._plotView.axes.clear() + if self._table_model.is_item_checked(item_index): + item_name = self._repository.get_name(item_index) + scan = self._repository[item_index].get_scan() + x = [point.position_x_m for point in scan] + y = [point.position_y_m for point in scan] + self._plot_view.axes.plot(x, y, '.-', label=item_name, linewidth=1.5) - for row in range(self._tableProxyModel.rowCount()): - proxyIndex = self._tableProxyModel.index(row, 0) - itemIndex = self._tableProxyModel.mapToSource(proxyIndex).row() + self._plot_view.axes.invert_yaxis() + self._plot_view.axes.axis('equal') + self._plot_view.axes.grid(True) + self._plot_view.axes.set_xlabel('X [m]') + self._plot_view.axes.set_ylabel('Y [m]') - if self._tableModel.isItemChecked(itemIndex): - itemName = self._repository.getName(itemIndex) - scan = self._repository[itemIndex].getScan() - x = [point.positionXInMeters for point in scan] - y = [point.positionYInMeters for point in scan] - self._plotView.axes.plot(x, y, '.-', label=itemName, linewidth=1.5) + if len(self._plot_view.axes.lines) > 0: + self._plot_view.axes.legend(loc='best') - self._plotView.axes.invert_yaxis() - self._plotView.axes.axis('equal') - self._plotView.axes.grid(True) - self._plotView.axes.set_xlabel('X [m]') - self._plotView.axes.set_ylabel('Y [m]') + self._plot_view.figure_canvas.draw() - if len(self._plotView.axes.lines) > 0: - self._plotView.axes.legend(loc='best') + def _estimate_transform(self) -> None: # TODO + item_index = self._get_current_item_index() - self._plotView.figureCanvas.draw() + if item_index < 0: + logger.warning('No current item!') + return + + _ = QMessageBox.information( + self._view, + 'Not Implemented', + 'Affine transform estimator is not yet implemented.', + ) - def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: + def _update_view(self, current: QModelIndex, previous: QModelIndex) -> None: enabled = current.isValid() - self._view.buttonBox.loadButton.setEnabled(enabled) - self._view.buttonBox.saveButton.setEnabled(enabled) - self._view.buttonBox.editButton.setEnabled(enabled) - self._view.buttonBox.analyzeButton.setEnabled(enabled) - self._redrawPlot() + self._view.button_box.load_button.setEnabled(enabled) + self._view.button_box.save_button.setEnabled(enabled) + self._view.button_box.edit_button.setEnabled(enabled) + self._view.button_box.analyze_button.setEnabled(enabled) + self._redraw_plot() - def handleItemInserted(self, index: int, item: ScanRepositoryItem) -> None: - self._tableModel.insertItem(index, item) + def handle_item_inserted(self, index: int, item: ScanRepositoryItem) -> None: + self._table_model.insert_item(index, item) - def handleItemChanged(self, index: int, item: ScanRepositoryItem) -> None: - self._tableModel.updateItem(index, item) + def handle_item_changed(self, index: int, item: ScanRepositoryItem) -> None: + self._table_model.update_item(index, item) - if self._tableModel.isItemChecked(index): - self._redrawPlot() + if self._table_model.is_item_checked(index): + self._redraw_plot() - def handleItemRemoved(self, index: int, item: ScanRepositoryItem) -> None: - self._tableModel.removeItem(index, item) + def handle_item_removed(self, index: int, item: ScanRepositoryItem) -> None: + self._table_model.remove_item(index, item) diff --git a/src/ptychodus/controller/scan/editorFactory.py b/src/ptychodus/controller/scan/editorFactory.py deleted file mode 100644 index 39e3b6ef..00000000 --- a/src/ptychodus/controller/scan/editorFactory.py +++ /dev/null @@ -1,245 +0,0 @@ -from __future__ import annotations - -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import ( - QDialog, - QFormLayout, - QGridLayout, - QGroupBox, - QLabel, - QMessageBox, - QWidget, -) - -from ptychodus.api.observer import Observable, Observer - -from ...model.product.scan import ( - CartesianScanBuilder, - ConcentricScanBuilder, - FromFileScanBuilder, - FromMemoryScanBuilder, - LissajousScanBuilder, - ScanPointTransform, - ScanRepositoryItem, - SpiralScanBuilder, -) -from ..parametric import ( - DecimalLineEditParameterViewController, - LengthWidgetParameterViewController, - ParameterViewBuilder, - ParameterViewController, -) -from ...view.widgets import GroupBoxWithPresets - -__all__ = [ - 'ScanEditorViewControllerFactory', -] - - -class ScanTransformViewController(ParameterViewController): - def __init__(self, transform: ScanPointTransform) -> None: - super().__init__() - self._widget = GroupBoxWithPresets('Transformation') - - for index, presetsLabel in enumerate(transform.labelsForPresets()): - action = self._widget.presetsMenu.addAction(presetsLabel) - action.triggered.connect(lambda _, index=index: transform.applyPresets(index)) - - self._labelXP = QLabel('x\u2032 =') - self._labelXP.setAlignment(Qt.AlignRight | Qt.AlignVCenter) - self._affineAXViewController = DecimalLineEditParameterViewController( - transform.affineAX, is_signed=True - ) - self._labelAX = QLabel('x +') - self._affineAYViewController = DecimalLineEditParameterViewController( - transform.affineAY, is_signed=True - ) - self._labelAY = QLabel('y +') - self._affineATViewController = LengthWidgetParameterViewController( - transform.affineATInMeters, is_signed=True - ) - - self._labelYP = QLabel('y\u2032 =') - self._labelYP.setAlignment(Qt.AlignRight | Qt.AlignVCenter) - self._affineBXViewController = DecimalLineEditParameterViewController( - transform.affineBX, is_signed=True - ) - self._labelBX = QLabel('x +') - self._affineBYViewController = DecimalLineEditParameterViewController( - transform.affineBY, is_signed=True - ) - self._labelBY = QLabel('y +') - self._affineBTViewController = LengthWidgetParameterViewController( - transform.affineBTInMeters, is_signed=True - ) - - self._jitterRadiusLabel = QLabel('Jitter Radius:') - self._jitterRadiusViewController = LengthWidgetParameterViewController( - transform.jitterRadiusInMeters, is_signed=False - ) - - layout = QGridLayout() - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._labelXP, 0, 0) - layout.addWidget(self._affineAXViewController.getWidget(), 0, 1) - layout.addWidget(self._labelAX, 0, 2) - layout.addWidget(self._affineAYViewController.getWidget(), 0, 3) - layout.addWidget(self._labelAY, 0, 4) - layout.addWidget(self._affineATViewController.getWidget(), 0, 5) - - layout.addWidget(self._labelYP, 1, 0) - layout.addWidget(self._affineBXViewController.getWidget(), 1, 1) - layout.addWidget(self._labelBX, 1, 2) - layout.addWidget(self._affineBYViewController.getWidget(), 1, 3) - layout.addWidget(self._labelBY, 1, 4) - layout.addWidget(self._affineBTViewController.getWidget(), 1, 5) - - layout.addWidget(self._jitterRadiusLabel, 2, 0) - layout.addWidget(self._jitterRadiusViewController.getWidget(), 2, 1, 1, 5) - self._widget.contents.setLayout(layout) - - def getWidget(self) -> QWidget: - return self._widget - - -class ScanBoundingBoxViewController(ParameterViewController, Observer): - def __init__(self, item: ScanRepositoryItem) -> None: - super().__init__() - self._parameter = item.expandBoundingBox - self._widget = QGroupBox('Expand Bounding Box') - self._widget.setCheckable(True) - - self._minimumXController = LengthWidgetParameterViewController( - item.expandedBoundingBoxMinimumXInMeters, is_signed=True - ) - self._maximumXController = LengthWidgetParameterViewController( - item.expandedBoundingBoxMaximumXInMeters, is_signed=True - ) - self._minimumYController = LengthWidgetParameterViewController( - item.expandedBoundingBoxMinimumYInMeters, is_signed=True - ) - self._maximumYController = LengthWidgetParameterViewController( - item.expandedBoundingBoxMaximumYInMeters, is_signed=True - ) - - layout = QFormLayout() - layout.addRow('Minimum X:', self._minimumXController.getWidget()) - layout.addRow('Maximum X:', self._maximumXController.getWidget()) - layout.addRow('Minimum Y:', self._minimumYController.getWidget()) - layout.addRow('Maximum Y:', self._maximumYController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(self._parameter.setValue) - self._parameter.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._parameter.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._parameter: - self._syncModelToView() - - -class ScanEditorViewControllerFactory: - def _appendCommonControls( - self, dialogBuilder: ParameterViewBuilder, item: ScanRepositoryItem - ) -> None: - dialogBuilder.addViewControllerToBottom(ScanTransformViewController(item.getTransform())) - dialogBuilder.addViewControllerToBottom(ScanBoundingBoxViewController(item)) - - def createEditorDialog( - self, itemName: str, item: ScanRepositoryItem, parent: QWidget - ) -> QDialog: - scanBuilder = item.getBuilder() - builderName = scanBuilder.getName() - baseScanGroup = 'Base Scan' - title = f'{itemName} [{builderName}]' - - if isinstance(scanBuilder, CartesianScanBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addSpinBox( - scanBuilder.numberOfPointsX, 'Number of Points X:', group=baseScanGroup - ) - dialogBuilder.addSpinBox( - scanBuilder.numberOfPointsY, 'Number of Points Y:', group=baseScanGroup - ) - dialogBuilder.addLengthWidget( - scanBuilder.stepSizeXInMeters, 'Step Size X:', group=baseScanGroup - ) - - if not scanBuilder.isEquilateral: - dialogBuilder.addLengthWidget( - scanBuilder.stepSizeYInMeters, 'Step Size Y:', group=baseScanGroup - ) - - self._appendCommonControls(dialogBuilder, item) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(scanBuilder, ConcentricScanBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addSpinBox( - scanBuilder.numberOfShells, 'Number of Shells:', group=baseScanGroup - ) - dialogBuilder.addSpinBox( - scanBuilder.numberOfPointsInFirstShell, - 'Number of Points in First Shell:', - group=baseScanGroup, - ) - dialogBuilder.addLengthWidget( - scanBuilder.radialStepSizeInMeters, - 'Radial Step Size:', - group=baseScanGroup, - ) - self._appendCommonControls(dialogBuilder, item) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(scanBuilder, FromFileScanBuilder): - dialogBuilder = ParameterViewBuilder() - self._appendCommonControls(dialogBuilder, item) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(scanBuilder, FromMemoryScanBuilder): - dialogBuilder = ParameterViewBuilder() - self._appendCommonControls(dialogBuilder, item) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(scanBuilder, SpiralScanBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addSpinBox( - scanBuilder.numberOfPoints, 'Number of Points:', group=baseScanGroup - ) - dialogBuilder.addLengthWidget( - scanBuilder.radiusScalarInMeters, 'Radius Scalar:', group=baseScanGroup - ) - self._appendCommonControls(dialogBuilder, item) - return dialogBuilder.buildDialog(title, parent) - elif isinstance(scanBuilder, LissajousScanBuilder): - dialogBuilder = ParameterViewBuilder() - dialogBuilder.addSpinBox( - scanBuilder.numberOfPoints, 'Number of Points:', group=baseScanGroup - ) - dialogBuilder.addLengthWidget( - scanBuilder.amplitudeXInMeters, 'Amplitude X:', group=baseScanGroup - ) - dialogBuilder.addLengthWidget( - scanBuilder.amplitudeYInMeters, 'Amplitude Y:', group=baseScanGroup - ) - dialogBuilder.addAngleWidget( - scanBuilder.angularStepXInTurns, 'Angular Step X:', group=baseScanGroup - ) - dialogBuilder.addAngleWidget( - scanBuilder.angularStepYInTurns, 'Angular Step Y:', group=baseScanGroup - ) - dialogBuilder.addAngleWidget( - scanBuilder.angularShiftInTurns, 'Angular Shift:', group=baseScanGroup - ) - self._appendCommonControls(dialogBuilder, item) - return dialogBuilder.buildDialog(title, parent) - - return QMessageBox( - QMessageBox.Icon.Information, - title, - f'"{builderName}" has no editable parameters!', - QMessageBox.Ok, - parent, - ) diff --git a/src/ptychodus/controller/scan/editor_factory.py b/src/ptychodus/controller/scan/editor_factory.py new file mode 100644 index 00000000..e59afecb --- /dev/null +++ b/src/ptychodus/controller/scan/editor_factory.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +from PyQt5.QtCore import Qt +from PyQt5.QtWidgets import ( + QDialog, + QFormLayout, + QGridLayout, + QLabel, + QMessageBox, + QWidget, +) + +from ...model.product.scan import ( + CartesianScanBuilder, + ConcentricScanBuilder, + FromFileScanBuilder, + FromMemoryScanBuilder, + LissajousScanBuilder, + ScanPointTransform, + ScanRepositoryItem, + SpiralScanBuilder, +) +from ..parametric import ( + CheckableGroupBoxParameterViewController, + DecimalLineEditParameterViewController, + LengthWidgetParameterViewController, + ParameterViewBuilder, + ParameterViewController, +) +from ...view.widgets import GroupBoxWithPresets + +__all__ = [ + 'ScanEditorViewControllerFactory', +] + + +class ScanTransformViewController(ParameterViewController): + def __init__(self, transform: ScanPointTransform) -> None: + super().__init__() + self._widget = GroupBoxWithPresets('Transformation') + + for index, presets_label in enumerate(transform.labels_for_presets()): + action = self._widget.presets_menu.addAction(presets_label) + action.triggered.connect(lambda _, index=index: transform.apply_presets(index)) + + self._label_ye = QLabel('y\u2032 =') + self._label_ye.setAlignment(Qt.AlignRight | Qt.AlignVCenter) + self._affine00_view_controller = DecimalLineEditParameterViewController( + transform.affine00, is_signed=True + ) + self._label_yp0 = QLabel('y +') + self._affine01_view_controller = DecimalLineEditParameterViewController( + transform.affine01, is_signed=True + ) + self._label_xp0 = QLabel('x +') + self._affine02_view_controller = LengthWidgetParameterViewController( + transform.affine02, is_signed=True + ) + + self._label_xe = QLabel('x\u2032 =') + self._label_xe.setAlignment(Qt.AlignRight | Qt.AlignVCenter) + self._affine10_view_controller = DecimalLineEditParameterViewController( + transform.affine10, is_signed=True + ) + self._label_yp1 = QLabel('y +') + self._affine11_view_controller = DecimalLineEditParameterViewController( + transform.affine11, is_signed=True + ) + self._label_xp1 = QLabel('x +') + self._affine12_view_controller = LengthWidgetParameterViewController( + transform.affine12, is_signed=True + ) + + self._jitter_radius_label = QLabel('Jitter Radius:') + self._jitter_radius_view_controller = LengthWidgetParameterViewController( + transform.jitter_radius_m, is_signed=False + ) + + layout = QGridLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._label_ye, 0, 0) + layout.addWidget(self._affine00_view_controller.get_widget(), 0, 1) + layout.addWidget(self._label_yp0, 0, 2) + layout.addWidget(self._affine01_view_controller.get_widget(), 0, 3) + layout.addWidget(self._label_xp0, 0, 4) + layout.addWidget(self._affine02_view_controller.get_widget(), 0, 5) + + layout.addWidget(self._label_xe, 1, 0) + layout.addWidget(self._affine10_view_controller.get_widget(), 1, 1) + layout.addWidget(self._label_yp1, 1, 2) + layout.addWidget(self._affine11_view_controller.get_widget(), 1, 3) + layout.addWidget(self._label_xp1, 1, 4) + layout.addWidget(self._affine12_view_controller.get_widget(), 1, 5) + + layout.addWidget(self._jitter_radius_label, 2, 0) + layout.addWidget(self._jitter_radius_view_controller.get_widget(), 2, 1, 1, 5) + self._widget.contents.setLayout(layout) + + def get_widget(self) -> QWidget: + return self._widget + + +class ScanBoundingBoxViewController(CheckableGroupBoxParameterViewController): + def __init__(self, item: ScanRepositoryItem) -> None: + super().__init__(item.expand_bbox, 'Expand Bounding Box') + self._xmin_controller = LengthWidgetParameterViewController( + item.expand_bbox_xmin_m, is_signed=True + ) + self._xmax_controller = LengthWidgetParameterViewController( + item.expand_bbox_xmax_m, is_signed=True + ) + self._ymin_controller = LengthWidgetParameterViewController( + item.expand_bbox_ymin_m, is_signed=True + ) + self._ymax_controller = LengthWidgetParameterViewController( + item.expand_bbox_ymax_m, is_signed=True + ) + + layout = QFormLayout() + layout.addRow('Minimum X:', self._xmin_controller.get_widget()) + layout.addRow('Maximum X:', self._xmax_controller.get_widget()) + layout.addRow('Minimum Y:', self._ymin_controller.get_widget()) + layout.addRow('Maximum Y:', self._ymax_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class ScanEditorViewControllerFactory: + def _append_common_controls( + self, dialog_builder: ParameterViewBuilder, item: ScanRepositoryItem + ) -> None: + transform = item.get_transform() + + if transform is not None: + dialog_builder.add_view_controller_to_bottom(ScanTransformViewController(transform)) + + dialog_builder.add_view_controller_to_bottom(ScanBoundingBoxViewController(item)) + + def create_editor_dialog( + self, item_name: str, item: ScanRepositoryItem, parent: QWidget + ) -> QDialog: + scan_builder = item.get_builder() + builder_name = scan_builder.get_name() + base_scan_group = 'Base Scan' + title = f'{item_name} [{builder_name}]' + + if isinstance(scan_builder, CartesianScanBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_spin_box( + scan_builder.num_points_x, 'Number of Points X:', group=base_scan_group + ) + dialog_builder.add_spin_box( + scan_builder.num_points_y, 'Number of Points Y:', group=base_scan_group + ) + dialog_builder.add_length_widget( + scan_builder.step_size_x_m, 'Step Size X:', group=base_scan_group + ) + + if not scan_builder.is_equilateral: + dialog_builder.add_length_widget( + scan_builder.step_size_y_m, 'Step Size Y:', group=base_scan_group + ) + + self._append_common_controls(dialog_builder, item) + return dialog_builder.build_dialog(title, parent) + elif isinstance(scan_builder, ConcentricScanBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_spin_box( + scan_builder.num_shells, 'Number of Shells:', group=base_scan_group + ) + dialog_builder.add_spin_box( + scan_builder.num_points_1st_shell, + 'Number of Points in First Shell:', + group=base_scan_group, + ) + dialog_builder.add_length_widget( + scan_builder.radial_step_size_m, + 'Radial Step Size:', + group=base_scan_group, + ) + self._append_common_controls(dialog_builder, item) + return dialog_builder.build_dialog(title, parent) + elif isinstance(scan_builder, FromFileScanBuilder): + dialog_builder = ParameterViewBuilder() + self._append_common_controls(dialog_builder, item) + return dialog_builder.build_dialog(title, parent) + elif isinstance(scan_builder, FromMemoryScanBuilder): + dialog_builder = ParameterViewBuilder() + self._append_common_controls(dialog_builder, item) + return dialog_builder.build_dialog(title, parent) + elif isinstance(scan_builder, SpiralScanBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_spin_box( + scan_builder.num_points, 'Number of Points:', group=base_scan_group + ) + dialog_builder.add_length_widget( + scan_builder.radius_scalar_m, 'Radius Scalar:', group=base_scan_group + ) + self._append_common_controls(dialog_builder, item) + return dialog_builder.build_dialog(title, parent) + elif isinstance(scan_builder, LissajousScanBuilder): + dialog_builder = ParameterViewBuilder() + dialog_builder.add_spin_box( + scan_builder.num_points, 'Number of Points:', group=base_scan_group + ) + dialog_builder.add_length_widget( + scan_builder.amplitude_x_m, 'Amplitude X:', group=base_scan_group + ) + dialog_builder.add_length_widget( + scan_builder.amplitude_y_m, 'Amplitude Y:', group=base_scan_group + ) + dialog_builder.add_angle_widget( + scan_builder.angular_step_x_turns, 'Angular Step X:', group=base_scan_group + ) + dialog_builder.add_angle_widget( + scan_builder.angular_step_y_turns, 'Angular Step Y:', group=base_scan_group + ) + dialog_builder.add_angle_widget( + scan_builder.angular_shift_turns, 'Angular Shift:', group=base_scan_group + ) + self._append_common_controls(dialog_builder, item) + return dialog_builder.build_dialog(title, parent) + + return QMessageBox( + QMessageBox.Icon.Information, + title, + f'"{builder_name}" has no editable parameters!', + QMessageBox.Ok, + parent, + ) diff --git a/src/ptychodus/controller/scan/tableModel.py b/src/ptychodus/controller/scan/table_model.py similarity index 68% rename from src/ptychodus/controller/scan/tableModel.py rename to src/ptychodus/controller/scan/table_model.py index 4c1d26ca..acd9fdd7 100644 --- a/src/ptychodus/controller/scan/tableModel.py +++ b/src/ptychodus/controller/scan/table_model.py @@ -2,6 +2,8 @@ from PyQt5.QtCore import Qt, QAbstractTableModel, QModelIndex, QObject +from ptychodus.api.units import BYTES_PER_MEGABYTE + from ...model.product import ScanAPI, ScanRepository from ...model.product.scan import ScanRepositoryItem @@ -14,25 +16,25 @@ def __init__( self._repository = repository self._api = api self._header = ['Name', 'Plot', 'Builder', 'Points', 'Length [m]', 'Size [MB]'] - self._checkedItemIndexes: set[int] = set() + self._checked_item_indexes: set[int] = set() - def insertItem(self, index: int, item: ScanRepositoryItem) -> None: + def insert_item(self, index: int, item: ScanRepositoryItem) -> None: self.beginInsertRows(QModelIndex(), index, index) self.endInsertRows() - def updateItem(self, index: int, item: ScanRepositoryItem) -> None: - topLeft = self.index(index, 0) - bottomRight = self.index(index, len(self._header)) - self.dataChanged.emit(topLeft, bottomRight) + def update_item(self, index: int, item: ScanRepositoryItem) -> None: + top_left = self.index(index, 0) + bottom_right = self.index(index, len(self._header)) + self.dataChanged.emit(top_left, bottom_right) - def removeItem(self, index: int, item: ScanRepositoryItem) -> None: + def remove_item(self, index: int, item: ScanRepositoryItem) -> None: self.beginRemoveRows(QModelIndex(), index, index) self.endRemoveRows() - def isItemChecked(self, itemIndex: int) -> bool: - return itemIndex in self._checkedItemIndexes + def is_item_checked(self, item_index: int) -> bool: + return item_index in self._checked_item_indexes - def headerData( + def headerData( # noqa: N802 self, section: int, orientation: Qt.Orientation, @@ -46,26 +48,26 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A return None item = self._repository[index.row()] - scan = item.getScan() + scan = item.get_scan() if role == Qt.ItemDataRole.DisplayRole: if index.column() == 0: - return self._repository.getName(index.row()) + return self._repository.get_name(index.row()) elif index.column() == 1: return None elif index.column() == 2: - return item.getBuilder().getName() + return item.get_builder().get_name() elif index.column() == 3: return len(scan) elif index.column() == 4: - return f'{item.getLengthInMeters():.6f}' + return f'{item.get_length_m():.6f}' elif index.column() == 5: - return f'{scan.sizeInBytes / (1024 * 1024):.2f}' + return f'{scan.nbytes / BYTES_PER_MEGABYTE:.2f}' elif role == Qt.ItemDataRole.CheckStateRole: if index.column() == 1: return ( Qt.CheckState.Checked - if index.row() in self._checkedItemIndexes + if index.row() in self._checked_item_indexes else Qt.CheckState.Unchecked ) @@ -81,23 +83,23 @@ def flags(self, index: QModelIndex) -> Qt.ItemFlags: return value - def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: + def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole) -> bool: # noqa: N802 if not index.isValid(): return False if role == Qt.ItemDataRole.EditRole: if index.column() == 0: - self._repository.setName(index.row(), str(value)) + self._repository.set_name(index.row(), str(value)) return True elif index.column() == 2: - self._api.buildScan(index.row(), str(value)) + self._api.build_scan(index.row(), str(value)) return True elif role == Qt.ItemDataRole.CheckStateRole: if index.column() == 1: if value == Qt.CheckState.Checked: - self._checkedItemIndexes.add(index.row()) + self._checked_item_indexes.add(index.row()) else: - self._checkedItemIndexes.discard(index.row()) + self._checked_item_indexes.discard(index.row()) self.dataChanged.emit(index, index) @@ -105,8 +107,8 @@ def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.Ed return False - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._repository) - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._header) diff --git a/src/ptychodus/controller/settings.py b/src/ptychodus/controller/settings.py index a5c5b36a..62e9e629 100644 --- a/src/ptychodus/controller/settings.py +++ b/src/ptychodus/controller/settings.py @@ -18,13 +18,13 @@ def __init__(self, parent: QObject | None = None) -> None: self._names: Sequence[str] = list() self._values: Sequence[str] = list() - def setNamesAndValues(self, names: Sequence[str], values: Sequence[str]) -> None: + def set_names_and_values(self, names: Sequence[str], values: Sequence[str]) -> None: self.beginResetModel() self._names = names self._values = values self.endResetModel() - def headerData( + def headerData( # noqa: N802 self, section: int, orientation: Qt.Orientation, @@ -43,85 +43,85 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A elif index.column() == 1: return str(self._values[index.row()]) - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._names) - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return 2 class SettingsController(Observer): def __init__( self, - settingsRegistry: SettingsRegistry, + settings_registry: SettingsRegistry, view: SettingsView, - tableView: QTableView, - fileDialogFactory: FileDialogFactory, + table_view: QTableView, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() - self._settingsRegistry = settingsRegistry + self._settings_registry = settings_registry self._view = view - self._tableView = tableView - self._fileDialogFactory = fileDialogFactory + self._table_view = table_view + self._file_dialog_factory = file_dialog_factory - self._listModel = QStringListModel() - self._tableModel = SettingsTableModel() + self._list_model = QStringListModel() + self._table_model = SettingsTableModel() - settingsRegistry.addObserver(self) + settings_registry.add_observer(self) - view.listView.setModel(self._listModel) - view.listView.selectionModel().currentChanged.connect(self._updateView) + view.list_view.setModel(self._list_model) + view.list_view.selectionModel().currentChanged.connect(self._update_view) - self._tableView.setModel(self._tableModel) + self._table_view.setModel(self._table_model) - view.buttonBox.openButton.clicked.connect(self._openSettings) - view.buttonBox.saveButton.clicked.connect(self._saveSettings) + view.button_box.open_button.clicked.connect(self._open_settings) + view.button_box.save_button.clicked.connect(self._save_settings) - self._syncModelToView() + self._sync_model_to_view() - def _openSettings(self) -> None: - filePath, _ = self._fileDialogFactory.getOpenFilePath( + def _open_settings(self) -> None: + file_path, _ = self._file_dialog_factory.get_open_file_path( self._view, 'Open Settings', - nameFilters=self._settingsRegistry.getOpenFileFilterList(), - selectedNameFilter=self._settingsRegistry.getOpenFileFilter(), + name_filters=self._settings_registry.get_open_file_filters(), + selected_name_filter=self._settings_registry.get_open_file_filter(), ) - if filePath: - self._settingsRegistry.openSettings(filePath) + if file_path: + self._settings_registry.open_settings(file_path) - def _saveSettings(self) -> None: - filePath, _ = self._fileDialogFactory.getSaveFilePath( + def _save_settings(self) -> None: + file_path, _ = self._file_dialog_factory.get_save_file_path( self._view, 'Save Settings', - nameFilters=self._settingsRegistry.getSaveFileFilterList(), - selectedNameFilter=self._settingsRegistry.getSaveFileFilter(), + name_filters=self._settings_registry.get_save_file_filters(), + selected_name_filter=self._settings_registry.get_save_file_filter(), ) - if filePath: - self._settingsRegistry.saveSettings(filePath) + if file_path: + self._settings_registry.save_settings(file_path) - def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: + def _update_view(self, current: QModelIndex, previous: QModelIndex) -> None: if not current.isValid(): return - groupName = self._listModel.data(current, Qt.DisplayRole) - group = self._settingsRegistry[groupName] + group_name = self._list_model.data(current, Qt.DisplayRole) + group = self._settings_registry[group_name] names: list[str] = list() values: list[str] = list() - for parameterName, parameter in group.parameters().items(): - names.append(parameterName) - values.append(parameter.getValueAsString()) + for parameter_name, parameter in group.parameters().items(): + names.append(parameter_name) + values.append(parameter.get_value_as_string()) - self._tableModel.setNamesAndValues(names, values) + self._table_model.set_names_and_values(names, values) - def _syncModelToView(self) -> None: - self._listModel.setStringList(sorted(iter(self._settingsRegistry))) + def _sync_model_to_view(self) -> None: + self._list_model.setStringList(sorted(iter(self._settings_registry))) - current = self._view.listView.currentIndex() - self._updateView(current, QModelIndex()) + current = self._view.list_view.currentIndex() + self._update_view(current, QModelIndex()) - def update(self, observable: Observable) -> None: - if observable is self._settingsRegistry: - self._syncModelToView() + def _update(self, observable: Observable) -> None: + if observable is self._settings_registry: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/tike/core.py b/src/ptychodus/controller/tike/core.py index fdaf9cb0..422e0103 100644 --- a/src/ptychodus/controller/tike/core.py +++ b/src/ptychodus/controller/tike/core.py @@ -1,13 +1,8 @@ -from __future__ import annotations - -from PyQt5.QtWidgets import ( - QVBoxLayout, - QWidget, -) +from PyQt5.QtWidgets import QVBoxLayout, QWidget from ...model.tike import TikeReconstructorLibrary from ..reconstructor import ReconstructorViewControllerFactory -from .viewControllers import ( +from .view_controllers import ( TikeMultigridViewController, TikeObjectCorrectionViewController, TikeParametersViewController, @@ -18,29 +13,29 @@ class TikeViewController(QWidget): def __init__( - self, model: TikeReconstructorLibrary, showAlpha: bool, parent: QWidget | None = None + self, model: TikeReconstructorLibrary, show_alpha: bool, parent: QWidget | None = None ) -> None: super().__init__(parent) - self._parametersViewController = TikeParametersViewController( - model.settings, showAlpha=showAlpha + self._parameters_view_controller = TikeParametersViewController( + model.settings, show_alpha=show_alpha ) - self._multigridViewController = TikeMultigridViewController(model.multigridSettings) - self._objectCorrectionViewController = TikeObjectCorrectionViewController( - model.objectCorrectionSettings + self._multigrid_view_controller = TikeMultigridViewController(model.multigrid_settings) + self._object_correction_view_controller = TikeObjectCorrectionViewController( + model.object_correction_settings ) - self._probeCorrectionViewController = TikeProbeCorrectionViewController( - model.probeCorrectionSettings + self._probe_correction_view_controller = TikeProbeCorrectionViewController( + model.probe_correction_settings ) - self._positionCorrectionViewController = TikePositionCorrectionViewController( - model.positionCorrectionSettings + self._position_correction_view_controller = TikePositionCorrectionViewController( + model.position_correction_settings ) layout = QVBoxLayout() - layout.addWidget(self._parametersViewController.getWidget()) - layout.addWidget(self._multigridViewController.getWidget()) - layout.addWidget(self._positionCorrectionViewController.getWidget()) - layout.addWidget(self._probeCorrectionViewController.getWidget()) - layout.addWidget(self._objectCorrectionViewController.getWidget()) + layout.addWidget(self._parameters_view_controller.get_widget()) + layout.addWidget(self._multigrid_view_controller.get_widget()) + layout.addWidget(self._position_correction_view_controller.get_widget()) + layout.addWidget(self._probe_correction_view_controller.get_widget()) + layout.addWidget(self._object_correction_view_controller.get_widget()) layout.addStretch() self.setLayout(layout) @@ -49,17 +44,15 @@ class TikeViewControllerFactory(ReconstructorViewControllerFactory): def __init__(self, model: TikeReconstructorLibrary) -> None: super().__init__() self._model = model - self._controllerList: list[TikeViewController] = list() @property - def backendName(self) -> str: + def backend_name(self) -> str: return 'Tike' - def createViewController(self, reconstructorName: str) -> QWidget: - if reconstructorName == 'rpie': - viewController = TikeViewController(self._model, showAlpha=True) + def create_view_controller(self, reconstructor_name: str) -> QWidget: + if reconstructor_name == 'rpie': + view_controller = TikeViewController(self._model, show_alpha=True) else: - viewController = TikeViewController(self._model, showAlpha=False) + view_controller = TikeViewController(self._model, show_alpha=False) - self._controllerList.append(viewController) - return viewController + return view_controller diff --git a/src/ptychodus/controller/tike/viewControllers.py b/src/ptychodus/controller/tike/viewControllers.py deleted file mode 100644 index fc71fa0a..00000000 --- a/src/ptychodus/controller/tike/viewControllers.py +++ /dev/null @@ -1,347 +0,0 @@ -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.parametric import BooleanParameter, RealParameter - -from PyQt5.QtCore import QRegularExpression -from PyQt5.QtGui import QRegularExpressionValidator -from PyQt5.QtWidgets import QComboBox, QFormLayout, QGroupBox, QWidget - -from ...model.tike import ( - TikeMultigridSettings, - TikeObjectCorrectionSettings, - TikePositionCorrectionSettings, - TikeProbeCorrectionSettings, - TikeSettings, -) -from ..parametric import ( - CheckBoxParameterViewController, - ComboBoxParameterViewController, - DecimalLineEditParameterViewController, - DecimalSliderParameterViewController, - LineEditParameterViewController, - ParameterViewController, - SpinBoxParameterViewController, -) - -__all__ = [ - 'TikeMultigridViewController', - 'TikeObjectCorrectionViewController', - 'TikePositionCorrectionViewController', - 'TikeProbeCorrectionViewController', -] - - -class TikeParametersViewController(ParameterViewController, Observer): - def __init__(self, settings: TikeSettings, *, showAlpha: bool) -> None: - super().__init__() - self._settings = settings - self._numGpusViewController = LineEditParameterViewController( - settings.numGpus, - QRegularExpressionValidator(QRegularExpression('[\\d,]+')), - tool_tip='The number of GPUs to use. If the number of GPUs is less than the requested number, only workers for the available GPUs are allocated.', - ) - self._noiseModelViewController = ComboBoxParameterViewController( - settings.noiseModel, - settings.getNoiseModels(), - tool_tip='The noise model to use for the cost function.', - ) - self._numBatchViewController = SpinBoxParameterViewController( - settings.numBatch, - tool_tip='The dataset is divided into this number of groups where each group is processed sequentially.', - ) - self._batchMethodViewController = ComboBoxParameterViewController( - settings.batchMethod, - settings.getBatchMethods(), - tool_tip='The name of the batch selection method.', - ) - self._numIterViewController = SpinBoxParameterViewController( - settings.numIter, tool_tip='The number of epochs to process before returning.' - ) - self._convergenceWindowViewController = SpinBoxParameterViewController( - settings.convergenceWindow, - tool_tip='The number of epochs to consider for convergence monitoring. Set to any value less than 2 to disable.', - ) - self._alphaViewController = DecimalSliderParameterViewController( - settings.alpha, tool_tip='RPIE becomes EPIE when this parameter is 1.' - ) - self._logLevelComboBox = QComboBox() - - for model in settings.getLogLevels(): - self._logLevelComboBox.addItem(model) - - self._logLevelComboBox.textActivated.connect(settings.setLogLevel) - - self._widget = QGroupBox('Tike Parameters') - - layout = QFormLayout() - layout.addRow('Number of GPUs:', self._numGpusViewController.getWidget()) - layout.addRow('Noise Model:', self._noiseModelViewController.getWidget()) - layout.addRow('Number of Batches:', self._numBatchViewController.getWidget()) - layout.addRow('Batch Method:', self._batchMethodViewController.getWidget()) - layout.addRow('Number of Iterations:', self._numIterViewController.getWidget()) - layout.addRow('Convergence Window:', self._convergenceWindowViewController.getWidget()) - - if showAlpha: - layout.addRow('Alpha:', self._alphaViewController.getWidget()) - - layout.addRow('Log Level:', self._logLevelComboBox) - self._widget.setLayout(layout) - - self._syncModelToView() - self._settings.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._logLevelComboBox.setCurrentText(self._settings.getLogLevel()) - - def update(self, observable: Observable) -> None: - if observable is self._settings: - self._syncModelToView() - - -class TikeMultigridViewController(ParameterViewController, Observer): - def __init__(self, settings: TikeMultigridSettings) -> None: - super().__init__() - self._useMultigrid = settings.useMultigrid - self._numLevelsController = SpinBoxParameterViewController( - settings.numLevels, - tool_tip='The number of times to reduce the problem by a factor of two.', - ) - self._widget = QGroupBox('Multigrid') - self._widget.setCheckable(True) - - layout = QFormLayout() - layout.addRow('Number of Levels:', self._numLevelsController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useMultigrid.setValue) - self._useMultigrid.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useMultigrid.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._useMultigrid: - self._syncModelToView() - - -class TikeAdaptiveMomentViewController(ParameterViewController, Observer): - def __init__( - self, useAdaptiveMoment: BooleanParameter, mdecay: RealParameter, vdecay: RealParameter - ) -> None: - super().__init__() - self._useAdaptiveMoment = useAdaptiveMoment - self._mdecayViewController = DecimalSliderParameterViewController( - mdecay, tool_tip='The proportion of the first moment that is previous first moments.' - ) - self._vdecayViewController = DecimalSliderParameterViewController( - vdecay, tool_tip='The proportion of the second moment that is previous second moments.' - ) - self._widget = QGroupBox('Adaptive Moment') - self._widget.setCheckable(True) - - layout = QFormLayout() - layout.addRow('M Decay:', self._mdecayViewController.getWidget()) - layout.addRow('V Decay:', self._vdecayViewController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(useAdaptiveMoment.setValue) - self._useAdaptiveMoment.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useAdaptiveMoment.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._useAdaptiveMoment: - self._syncModelToView() - - -class TikeObjectCorrectionViewController(ParameterViewController, Observer): - def __init__(self, settings: TikeObjectCorrectionSettings) -> None: - super().__init__() - self._useObjectCorrection = settings.useObjectCorrection - self._positivityConstraintViewController = DecimalSliderParameterViewController( - settings.positivityConstraint - ) - self._smoothnessConstraintViewController = DecimalSliderParameterViewController( - settings.smoothnessConstraint - ) - self._adaptiveMomentViewController = TikeAdaptiveMomentViewController( - settings.useAdaptiveMoment, settings.mdecay, settings.vdecay - ) - self._useMagnitudeClippingViewController = CheckBoxParameterViewController( - settings.useMagnitudeClipping, - 'Magnitude Clipping', - tool_tip='Forces the object magnitude to be <= 1.', - ) - - self._widget = QGroupBox('Object Correction') - self._widget.setCheckable(True) - - layout = QFormLayout() - layout.addRow( - 'Positivity Constraint:', self._positivityConstraintViewController.getWidget() - ) - layout.addRow( - 'Smoothness Constraint:', self._smoothnessConstraintViewController.getWidget() - ) - layout.addRow(self._adaptiveMomentViewController.getWidget()) - layout.addRow(self._useMagnitudeClippingViewController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useObjectCorrection.setValue) - self._useObjectCorrection.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useObjectCorrection.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._useObjectCorrection: - self._syncModelToView() - - -class TikeProbeSupportViewController(ParameterViewController, Observer): - def __init__(self, settings: TikeProbeCorrectionSettings) -> None: - super().__init__() - self._useFiniteProbeSupport = settings.useFiniteProbeSupport - self._weightViewController = DecimalLineEditParameterViewController( - settings.probeSupportWeight, tool_tip='Weight of the finite probe constraint.' - ) - self._radiusViewController = DecimalSliderParameterViewController( - settings.probeSupportRadius, - tool_tip='Radius of probe support as fraction of probe grid.', - ) - self._degreeViewController = DecimalLineEditParameterViewController( - settings.probeSupportDegree, - tool_tip='Degree of the supergaussian defining the probe support.', - ) - self._widget = QGroupBox('Finite Probe Support') - self._widget.setCheckable(True) - - layout = QFormLayout() - layout.addRow('Weight:', self._weightViewController.getWidget()) - layout.addRow('Radius:', self._radiusViewController.getWidget()) - layout.addRow('Degree:', self._degreeViewController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useFiniteProbeSupport.setValue) - self._useFiniteProbeSupport.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useFiniteProbeSupport.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._useFiniteProbeSupport: - self._syncModelToView() - - -class TikeProbeCorrectionViewController(ParameterViewController, Observer): - def __init__(self, settings: TikeProbeCorrectionSettings) -> None: - super().__init__() - self._useProbeCorrection = settings.useProbeCorrection - self._forceSparsityViewController = DecimalSliderParameterViewController( - settings.forceSparsity, tool_tip='Forces this proportion of zero elements.' - ) - self._forceOrthogonalityViewController = CheckBoxParameterViewController( - settings.forceOrthogonality, - 'Force Orthogonality', - tool_tip='Forces probes to be orthogonal each iteration.', - ) - self._forceCenteredIntensityViewController = CheckBoxParameterViewController( - settings.forceCenteredIntensity, - 'Force Centered Intensity', - tool_tip='Forces the probe intensity to be centered.', - ) - self._supportViewController = TikeProbeSupportViewController(settings) - self._adaptiveMomentViewController = TikeAdaptiveMomentViewController( - settings.useAdaptiveMoment, settings.mdecay, settings.vdecay - ) - self._additionalProbePenaltyViewController = DecimalLineEditParameterViewController( - settings.additionalProbePenalty, - tool_tip='Penalty applied to the last probe for existing.', - ) - self._widget = QGroupBox('Probe Correction') - self._widget.setCheckable(True) - - layout = QFormLayout() - layout.addRow('Force Sparsity:', self._forceSparsityViewController.getWidget()) - layout.addRow(self._forceOrthogonalityViewController.getWidget()) - layout.addRow(self._forceCenteredIntensityViewController.getWidget()) - layout.addRow(self._supportViewController.getWidget()) - layout.addRow(self._adaptiveMomentViewController.getWidget()) - layout.addRow( - 'Additional Probe Penalty:', self._additionalProbePenaltyViewController.getWidget() - ) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useProbeCorrection.setValue) - self._useProbeCorrection.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useProbeCorrection.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._useProbeCorrection: - self._syncModelToView() - - -class TikePositionCorrectionViewController(ParameterViewController, Observer): - def __init__(self, settings: TikePositionCorrectionSettings) -> None: - self._usePositionCorrection = settings.usePositionCorrection - self._usePositionRegularizationViewController = CheckBoxParameterViewController( - settings.usePositionRegularization, - 'Use Regularization', - tool_tip='Whether the positions are constrained to fit a random error plus affine error model.', - ) - self._adaptiveMomentViewController = TikeAdaptiveMomentViewController( - settings.useAdaptiveMoment, settings.mdecay, settings.vdecay - ) - self._updateMagnitudeLimitViewController = DecimalLineEditParameterViewController( - settings.updateMagnitudeLimit, - tool_tip='When set to a positive number, x and y update magnitudes are clipped (limited) to this value.', - ) - self._widget = QGroupBox('Position Correction') - self._widget.setCheckable(True) - - layout = QFormLayout() - layout.addRow(self._usePositionRegularizationViewController.getWidget()) - layout.addRow(self._adaptiveMomentViewController.getWidget()) - layout.addRow( - 'Update Magnitude Limit:', self._updateMagnitudeLimitViewController.getWidget() - ) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.usePositionCorrection.setValue) - self._usePositionCorrection.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._usePositionCorrection.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._usePositionCorrection: - self._syncModelToView() diff --git a/src/ptychodus/controller/tike/view_controllers.py b/src/ptychodus/controller/tike/view_controllers.py new file mode 100644 index 00000000..c33a3fe6 --- /dev/null +++ b/src/ptychodus/controller/tike/view_controllers.py @@ -0,0 +1,248 @@ +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import BooleanParameter, RealParameter + +from PyQt5.QtCore import QRegularExpression +from PyQt5.QtGui import QRegularExpressionValidator +from PyQt5.QtWidgets import QComboBox, QFormLayout, QGroupBox, QWidget + +from ...model.tike import ( + TikeMultigridSettings, + TikeObjectCorrectionSettings, + TikePositionCorrectionSettings, + TikeProbeCorrectionSettings, + TikeSettings, +) +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + LineEditParameterViewController, + ParameterViewController, + SpinBoxParameterViewController, +) + +__all__ = [ + 'TikeMultigridViewController', + 'TikeObjectCorrectionViewController', + 'TikePositionCorrectionViewController', + 'TikeProbeCorrectionViewController', +] + + +class TikeParametersViewController(ParameterViewController, Observer): + def __init__(self, settings: TikeSettings, *, show_alpha: bool) -> None: + super().__init__() + self._settings = settings + self._num_gpus_view_controller = LineEditParameterViewController( + settings.num_gpus, + QRegularExpressionValidator(QRegularExpression('[\\d,]+')), + tool_tip='The number of GPUs to use. If the number of GPUs is less than the requested number, only workers for the available GPUs are allocated.', + ) + self._noise_model_view_controller = ComboBoxParameterViewController( + settings.noise_model, + settings.get_noise_models(), + tool_tip='The noise model to use for the cost function.', + ) + self._num_batch_view_controller = SpinBoxParameterViewController( + settings.num_batch, + tool_tip='The dataset is divided into this number of groups where each group is processed sequentially.', + ) + self._batch_method_view_controller = ComboBoxParameterViewController( + settings.batch_method, + settings.get_batch_methods(), + tool_tip='The name of the batch selection method.', + ) + self._num_iter_view_controller = SpinBoxParameterViewController( + settings.num_iter, tool_tip='The number of epochs to process before returning.' + ) + self._convergence_window_view_controller = SpinBoxParameterViewController( + settings.convergence_window, + tool_tip='The number of epochs to consider for convergence monitoring. Set to any value less than 2 to disable.', + ) + self._alpha_view_controller = DecimalSliderParameterViewController( + settings.alpha, tool_tip='RPIE becomes EPIE when this parameter is 1.' + ) + self._log_level_combo_box = QComboBox() + + for model in settings.get_log_levels(): + self._log_level_combo_box.addItem(model) + + self._log_level_combo_box.textActivated.connect(settings.set_log_level) + + self._widget = QGroupBox('Tike Parameters') + + layout = QFormLayout() + layout.addRow('Number of GPUs:', self._num_gpus_view_controller.get_widget()) + layout.addRow('Noise Model:', self._noise_model_view_controller.get_widget()) + layout.addRow('Number of Batches:', self._num_batch_view_controller.get_widget()) + layout.addRow('Batch Method:', self._batch_method_view_controller.get_widget()) + layout.addRow('Number of Iterations:', self._num_iter_view_controller.get_widget()) + layout.addRow('Convergence Window:', self._convergence_window_view_controller.get_widget()) + + if show_alpha: + layout.addRow('Alpha:', self._alpha_view_controller.get_widget()) + + layout.addRow('Log Level:', self._log_level_combo_box) + self._widget.setLayout(layout) + + self._sync_model_to_view() + self._settings.add_observer(self) + + def get_widget(self) -> QWidget: + return self._widget + + def _sync_model_to_view(self) -> None: + self._log_level_combo_box.setCurrentText(self._settings.get_log_level()) + + def _update(self, observable: Observable) -> None: + if observable is self._settings: + self._sync_model_to_view() + + +class TikeMultigridViewController(CheckableGroupBoxParameterViewController): + def __init__(self, settings: TikeMultigridSettings) -> None: + super().__init__(settings.use_multigrid, 'Multigrid') + self._use_multigrid = settings.use_multigrid + self._num_levels_controller = SpinBoxParameterViewController( + settings.num_levels, + tool_tip='The number of times to reduce the problem by a factor of two.', + ) + + layout = QFormLayout() + layout.addRow('Number of Levels:', self._num_levels_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class TikeAdaptiveMomentViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, use_adaptive_moment: BooleanParameter, mdecay: RealParameter, vdecay: RealParameter + ) -> None: + super().__init__(use_adaptive_moment, 'Adaptive Moment') + self._use_adaptive_moment = use_adaptive_moment + self._mdecay_view_controller = DecimalSliderParameterViewController( + mdecay, tool_tip='The proportion of the first moment that is previous first moments.' + ) + self._vdecay_view_controller = DecimalSliderParameterViewController( + vdecay, tool_tip='The proportion of the second moment that is previous second moments.' + ) + + layout = QFormLayout() + layout.addRow('M Decay:', self._mdecay_view_controller.get_widget()) + layout.addRow('V Decay:', self._vdecay_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class TikeObjectCorrectionViewController(CheckableGroupBoxParameterViewController): + def __init__(self, settings: TikeObjectCorrectionSettings) -> None: + super().__init__(settings.use_object_correction, 'Object Correction') + self._positivity_constraint_view_controller = DecimalSliderParameterViewController( + settings.positivity_constraint + ) + self._smoothness_constraint_view_controller = DecimalSliderParameterViewController( + settings.smoothness_constraint + ) + self._adaptive_moment_view_controller = TikeAdaptiveMomentViewController( + settings.use_adaptive_moment, settings.mdecay, settings.vdecay + ) + self._use_magnitude_clipping_view_controller = CheckBoxParameterViewController( + settings.use_magnitude_clipping, + 'Magnitude Clipping', + tool_tip='Forces the object magnitude to be <= 1.', + ) + + layout = QFormLayout() + layout.addRow( + 'Positivity Constraint:', self._positivity_constraint_view_controller.get_widget() + ) + layout.addRow( + 'Smoothness Constraint:', self._smoothness_constraint_view_controller.get_widget() + ) + layout.addRow(self._adaptive_moment_view_controller.get_widget()) + layout.addRow(self._use_magnitude_clipping_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class TikeProbeSupportViewController(CheckableGroupBoxParameterViewController): + def __init__(self, settings: TikeProbeCorrectionSettings) -> None: + super().__init__(settings.use_finite_probe_support, 'Finite Probe Support') + self._weight_view_controller = DecimalLineEditParameterViewController( + settings.probe_support_weight, tool_tip='Weight of the finite probe constraint.' + ) + self._radius_view_controller = DecimalSliderParameterViewController( + settings.probe_support_radius, + tool_tip='Radius of probe support as fraction of probe grid.', + ) + self._degree_view_controller = DecimalLineEditParameterViewController( + settings.probe_support_degree, + tool_tip='Degree of the supergaussian defining the probe support.', + ) + + layout = QFormLayout() + layout.addRow('Weight:', self._weight_view_controller.get_widget()) + layout.addRow('Radius:', self._radius_view_controller.get_widget()) + layout.addRow('Degree:', self._degree_view_controller.get_widget()) + self.get_widget().setLayout(layout) + + +class TikeProbeCorrectionViewController(CheckableGroupBoxParameterViewController): + def __init__(self, settings: TikeProbeCorrectionSettings) -> None: + super().__init__(settings.use_probe_correction, 'Probe Correction') + self._force_sparsity_view_controller = DecimalSliderParameterViewController( + settings.force_sparsity, tool_tip='Forces this proportion of zero elements.' + ) + self._force_orthogonality_view_controller = CheckBoxParameterViewController( + settings.force_orthogonality, + 'Force Orthogonality', + tool_tip='Forces probes to be orthogonal each iteration.', + ) + self._force_centered_intensity_view_controller = CheckBoxParameterViewController( + settings.force_centered_intensity, + 'Force Centered Intensity', + tool_tip='Forces the probe intensity to be centered.', + ) + self._support_view_controller = TikeProbeSupportViewController(settings) + self._adaptive_moment_view_controller = TikeAdaptiveMomentViewController( + settings.use_adaptive_moment, settings.mdecay, settings.vdecay + ) + self._additional_probe_penalty_view_controller = DecimalLineEditParameterViewController( + settings.additional_probe_penalty, + tool_tip='Penalty applied to the last probe for existing.', + ) + + layout = QFormLayout() + layout.addRow('Force Sparsity:', self._force_sparsity_view_controller.get_widget()) + layout.addRow(self._force_orthogonality_view_controller.get_widget()) + layout.addRow(self._force_centered_intensity_view_controller.get_widget()) + layout.addRow(self._support_view_controller.get_widget()) + layout.addRow(self._adaptive_moment_view_controller.get_widget()) + layout.addRow( + 'Additional Probe Penalty:', self._additional_probe_penalty_view_controller.get_widget() + ) + self.get_widget().setLayout(layout) + + +class TikePositionCorrectionViewController(CheckableGroupBoxParameterViewController): + def __init__(self, settings: TikePositionCorrectionSettings) -> None: + super().__init__(settings.use_position_correction, 'Position Correction') + self._use_position_regularization_view_controller = CheckBoxParameterViewController( + settings.use_position_regularization, + 'Use Regularization', + tool_tip='Whether the positions are constrained to fit a random error plus affine error model.', + ) + self._adaptive_moment_view_controller = TikeAdaptiveMomentViewController( + settings.use_adaptive_moment, settings.mdecay, settings.vdecay + ) + self._update_magnitude_limit_view_controller = DecimalLineEditParameterViewController( + settings.update_magnitude_limit, + tool_tip='When set to a positive number, x and y update magnitudes are clipped (limited) to this value.', + ) + + layout = QFormLayout() + layout.addRow(self._use_position_regularization_view_controller.get_widget()) + layout.addRow(self._adaptive_moment_view_controller.get_widget()) + layout.addRow( + 'Update Magnitude Limit:', self._update_magnitude_limit_view_controller.get_widget() + ) + self.get_widget().setLayout(layout) diff --git a/src/ptychodus/controller/visualization/controller.py b/src/ptychodus/controller/visualization/controller.py index dbc7be15..3aa5c556 100644 --- a/src/ptychodus/controller/visualization/controller.py +++ b/src/ptychodus/controller/visualization/controller.py @@ -8,7 +8,7 @@ from ptychodus.api.geometry import Box2D, Line2D, PixelGeometry, Point2D from ptychodus.api.observer import Observable, Observer -from ptychodus.api.visualization import NumberArrayType +from ptychodus.api.typing import NumberArrayType from ...model.visualization import VisualizationEngine from ...view.visualization import ( @@ -33,33 +33,33 @@ def __init__( engine: VisualizationEngine, view: VisualizationView, item: ImageItem, - statusBar: QStatusBar, - fileDialogFactory: FileDialogFactory, + status_bar: QStatusBar, + file_dialog_factory: FileDialogFactory, ) -> None: super().__init__() self._engine = engine self._view = view self._item = item - self._statusBar = statusBar - self._fileDialogFactory = fileDialogFactory - self._lineCutDialog = LineCutDialog.createInstance(view) - self._histogramDialog = HistogramDialog.createInstance(view) + self._status_bar = status_bar + self._file_dialog_factory = file_dialog_factory + self._line_cut_dialog = LineCutDialog.create_instance(view) + self._histogram_dialog = HistogramDialog.create_instance(view) @classmethod - def createInstance( + def create_instance( cls, engine: VisualizationEngine, view: VisualizationView, - statusBar: QStatusBar, - fileDialogFactory: FileDialogFactory, + status_bar: QStatusBar, + file_dialog_factory: FileDialogFactory, ) -> VisualizationController: - itemEvents = ImageItemEvents() - item = ImageItem(itemEvents, statusBar) - controller = cls(engine, view, item, statusBar, fileDialogFactory) - engine.addObserver(controller) + item_events = ImageItemEvents() + item = ImageItem(item_events, status_bar) + controller = cls(engine, view, item, status_bar, file_dialog_factory) + engine.add_observer(controller) - itemEvents.lineCutFinished.connect(controller._analyzeLineCut) - itemEvents.rectangleFinished.connect(controller._analyzeRegion) + item_events.line_cut_finished.connect(controller._analyze_line_cut) + item_events.rectangle_finished.connect(controller._analyze_region) scene = QGraphicsScene() scene.addItem(item) @@ -70,62 +70,66 @@ def createInstance( return controller - def setArray( + def set_array( self, array: NumberArrayType, - pixelGeometry: PixelGeometry, + pixel_geometry: PixelGeometry, *, - autoscaleColorAxis: bool = False, + autoscale_color_axis: bool = False, ) -> None: - try: - product = self._engine.render( - array, pixelGeometry, autoscaleColorAxis=autoscaleColorAxis - ) - except ValueError as err: - logger.exception(err) - ExceptionDialog.showException('Renderer', err) + if numpy.all(numpy.isfinite(array)): + try: + product = self._engine.render( + array, pixel_geometry, autoscale_color_axis=autoscale_color_axis + ) + except ValueError as err: + logger.exception(err) + ExceptionDialog.show_exception('Renderer', err) + else: + self._item.set_product(product) else: - self._item.setProduct(product) + logger.warning('Array contains infinite or NaN values!') + self._item.clear_product() - def clearArray(self) -> None: - self._item.clearProduct() + def clear_array(self) -> None: + self._item.clear_product() - def setMouseTool(self, mouseTool: ImageMouseTool) -> None: - self._item.setMouseTool(mouseTool) + def set_mouse_tool(self, mouse_tool: ImageMouseTool) -> None: + self._item.set_mouse_tool(mouse_tool) - def saveImage(self) -> None: - filePath, _ = self._fileDialogFactory.getSaveFilePath( - self._view, 'Save Image', mimeTypeFilters=VisualizationController.MIME_TYPES + def save_image(self) -> None: + file_path, _ = self._file_dialog_factory.get_save_file_path( + self._view, 'Save Image', mime_type_filters=VisualizationController.MIME_TYPES ) - if filePath: + if file_path: pixmap = self._item.pixmap() - pixmap.save(str(filePath)) + pixmap.save(str(file_path)) - def _analyzeLineCut(self, line: QLineF) -> None: + def _analyze_line_cut(self, line: QLineF) -> None: p1 = Point2D(line.x1(), line.y1()) p2 = Point2D(line.x2(), line.y2()) - line2D = Line2D(p1, p2) + line2d = Line2D(p1, p2) - product = self._item.getProduct() + product = self._item.get_product() if product is None: logger.warning('No visualization product!') return - valueLabel = product.getValueLabel() - lineCut = product.getLineCut(line2D) + value_label = product.get_value_label() + line_cut = product.get_line_cut(line2d) - ax = self._lineCutDialog.axes + ax = self._line_cut_dialog.axes ax.clear() - ax.plot(lineCut.distanceInMeters, lineCut.value, '.-', linewidth=1.5) + ax.plot(line_cut.distance_m, line_cut.value, '.-', linewidth=1.5) ax.set_xlabel('Distance [m]') - ax.set_ylabel(valueLabel) + ax.set_ylabel(value_label) ax.grid(True) - self._lineCutDialog.figureCanvas.draw() - self._lineCutDialog.open() + self._line_cut_dialog.figure_canvas.draw() + self._line_cut_dialog.open() - def _analyzeRegion(self, rect: QRectF) -> None: + def _analyze_region(self, rect: QRectF) -> None: if rect.isEmpty(): logger.debug('QRectF is empty!') return @@ -137,51 +141,51 @@ def _analyzeRegion(self, rect: QRectF) -> None: height=rect.height(), ) - product = self._item.getProduct() + product = self._item.get_product() if product is None: logger.warning('No visualization product!') return - valueLabel = product.getValueLabel() - kde = product.estimateKernelDensity(box) - values = numpy.linspace(kde.valueLower, kde.valueUpper, 1000) + value_label = product.get_value_label() + kde = product.estimate_kernel_density(box) + values = numpy.linspace(kde.value_lower, kde.value_upper, 1000) - ax = self._histogramDialog.axes + ax = self._histogram_dialog.axes ax.clear() ax.plot(values, kde.kde(values), '.-', linewidth=1.5) - ax.set_xlabel(valueLabel) + ax.set_xlabel(value_label) ax.set_ylabel('Density') ax.grid(True) - self._histogramDialog.figureCanvas.draw() + self._histogram_dialog.figure_canvas.draw() - rectangleView = self._histogramDialog.rectangleView - rectCenter = rect.center() - rectangleView.centerXLineEdit.setText(f'{rectCenter.x():.1f}') - rectangleView.centerYLineEdit.setText(f'{rectCenter.y():.1f}') - rectangleView.widthLineEdit.setText(f'{rect.width():.1f}') - rectangleView.heightLineEdit.setText(f'{rect.height():.1f}') + rectangle_view = self._histogram_dialog.rectangle_view + rect_center = rect.center() + rectangle_view.center_x_line_edit.setText(f'{rect_center.x():.1f}') + rectangle_view.center_y_line_edit.setText(f'{rect_center.y():.1f}') + rectangle_view.width_line_edit.setText(f'{rect.width():.1f}') + rectangle_view.height_line_edit.setText(f'{rect.height():.1f}') # TODO use rect for crop - self._histogramDialog.open() + self._histogram_dialog.open() - def zoomToFit(self) -> None: + def zoom_to_fit(self) -> None: self._item.setPos(0, 0) scene = self._view.scene() - boundingRect = scene.itemsBoundingRect() - scene.setSceneRect(boundingRect) + bounding_rect = scene.itemsBoundingRect() + scene.setSceneRect(bounding_rect) self._view.fitInView(scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio) - def rerenderImage(self, *, autoscaleColorAxis: bool = False) -> None: - product = self._item.getProduct() + def rerender_image(self, *, autoscale_color_axis: bool = False) -> None: + product = self._item.get_product() if product is not None: - self.setArray( - product.getValues(), - product.getPixelGeometry(), - autoscaleColorAxis=autoscaleColorAxis, + self.set_array( + product.get_values(), + product.get_pixel_geometry(), + autoscale_color_axis=autoscale_color_axis, ) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._engine: - self.rerenderImage() + self.rerender_image() diff --git a/src/ptychodus/controller/visualization/parameters.py b/src/ptychodus/controller/visualization/parameters.py index 3e585111..6c7a5234 100644 --- a/src/ptychodus/controller/visualization/parameters.py +++ b/src/ptychodus/controller/visualization/parameters.py @@ -14,58 +14,58 @@ def __init__(self, engine: VisualizationEngine, view: VisualizationParametersVie super().__init__() self._engine = engine self._view = view - self._rendererModel = QStringListModel() - self._transformationModel = QStringListModel() - self._variantModel = QStringListModel() + self._renderer_model = QStringListModel() + self._transformation_model = QStringListModel() + self._variant_model = QStringListModel() @classmethod - def createInstance( + def create_instance( cls, engine: VisualizationEngine, view: VisualizationParametersView ) -> VisualizationParametersController: controller = cls(engine, view) - view.rendererComboBox.setModel(controller._rendererModel) - view.transformationComboBox.setModel(controller._transformationModel) - view.variantComboBox.setModel(controller._variantModel) + view.renderer_combo_box.setModel(controller._renderer_model) + view.transformation_combo_box.setModel(controller._transformation_model) + view.variant_combo_box.setModel(controller._variant_model) - view.minDisplayValueLineEdit.valueChanged.connect( - lambda value: engine.setMinDisplayValue(float(value)) + view.min_display_value_line_edit.value_changed.connect( + lambda value: engine.set_min_display_value(float(value)) ) - view.maxDisplayValueLineEdit.valueChanged.connect( - lambda value: engine.setMaxDisplayValue(float(value)) + view.max_display_value_line_edit.value_changed.connect( + lambda value: engine.set_max_display_value(float(value)) ) - controller._syncModelToView() - engine.addObserver(controller) + controller._sync_model_to_view() + engine.add_observer(controller) - view.rendererComboBox.textActivated.connect(engine.setRenderer) - view.transformationComboBox.textActivated.connect(engine.setTransformation) - view.variantComboBox.textActivated.connect(engine.setVariant) + view.renderer_combo_box.textActivated.connect(engine.set_renderer) + view.transformation_combo_box.textActivated.connect(engine.set_transformation) + view.variant_combo_box.textActivated.connect(engine.set_variant) return controller - def _syncModelToView(self) -> None: - self._view.rendererComboBox.blockSignals(True) - self._rendererModel.setStringList([name for name in self._engine.renderers()]) - self._view.rendererComboBox.setCurrentText(self._engine.getRenderer()) - self._view.rendererComboBox.blockSignals(False) + def _sync_model_to_view(self) -> None: + self._view.renderer_combo_box.blockSignals(True) + self._renderer_model.setStringList([name for name in self._engine.renderers()]) + self._view.renderer_combo_box.setCurrentText(self._engine.get_renderer()) + self._view.renderer_combo_box.blockSignals(False) - self._view.transformationComboBox.blockSignals(True) - self._transformationModel.setStringList([name for name in self._engine.transformations()]) - self._view.transformationComboBox.setCurrentText(self._engine.getTransformation()) - self._view.transformationComboBox.blockSignals(False) + self._view.transformation_combo_box.blockSignals(True) + self._transformation_model.setStringList([name for name in self._engine.transformations()]) + self._view.transformation_combo_box.setCurrentText(self._engine.get_transformation()) + self._view.transformation_combo_box.blockSignals(False) - self._view.variantComboBox.blockSignals(True) - self._variantModel.setStringList([name for name in self._engine.variants()]) - self._view.variantComboBox.setCurrentText(self._engine.getVariant()) - self._view.variantComboBox.blockSignals(False) + self._view.variant_combo_box.blockSignals(True) + self._variant_model.setStringList([name for name in self._engine.variants()]) + self._view.variant_combo_box.setCurrentText(self._engine.get_variant()) + self._view.variant_combo_box.blockSignals(False) - self._view.minDisplayValueLineEdit.setValue( - Decimal(repr(self._engine.getMinDisplayValue())) + self._view.min_display_value_line_edit.set_value( + Decimal(repr(self._engine.get_min_display_value())) ) - self._view.maxDisplayValueLineEdit.setValue( - Decimal(repr(self._engine.getMaxDisplayValue())) + self._view.max_display_value_line_edit.set_value( + Decimal(repr(self._engine.get_max_display_value())) ) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._engine: - self._syncModelToView() + self._sync_model_to_view() diff --git a/src/ptychodus/controller/visualization/widget.py b/src/ptychodus/controller/visualization/widget.py index 4ca92d7c..7a7ab40a 100644 --- a/src/ptychodus/controller/visualization/widget.py +++ b/src/ptychodus/controller/visualization/widget.py @@ -3,7 +3,7 @@ from PyQt5.QtWidgets import QStatusBar from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.visualization import NumberArrayType +from ptychodus.api.typing import NumberArrayType from ...model.visualization import VisualizationEngine from ...view.visualization import VisualizationWidget @@ -16,23 +16,23 @@ def __init__( self, engine: VisualizationEngine, widget: VisualizationWidget, - statusBar: QStatusBar, - fileDialogFactory: FileDialogFactory, + status_bar: QStatusBar, + file_dialog_factory: FileDialogFactory, ) -> None: self._widget = widget - self._controller = VisualizationController.createInstance( - engine, widget.visualizationView, statusBar, fileDialogFactory + self._controller = VisualizationController.create_instance( + engine, widget.visualization_view, status_bar, file_dialog_factory ) - self._widget.homeAction.triggered.connect(self._controller.zoomToFit) - self._widget.saveAction.triggered.connect(self._controller.saveImage) - self._widget.autoscaleAction.triggered.connect(self._autoDisplayRange) + self._widget.home_action.triggered.connect(self._controller.zoom_to_fit) + self._widget.save_action.triggered.connect(self._controller.save_image) + self._widget.autoscale_action.triggered.connect(self._auto_display_range) - def _autoDisplayRange(self) -> None: - self._controller.rerenderImage(autoscaleColorAxis=True) + def _auto_display_range(self) -> None: + self._controller.rerender_image(autoscale_color_axis=True) - def setArray(self, array: NumberArrayType, pixelGeometry: PixelGeometry) -> None: - self._controller.setArray(array, pixelGeometry) + def set_array(self, array: NumberArrayType, pixel_geometry: PixelGeometry) -> None: + self._controller.set_array(array, pixel_geometry) - def clearArray(self) -> None: - self._controller.clearArray() + def clear_array(self) -> None: + self._controller.clear_array() diff --git a/src/ptychodus/controller/workflow/authorization.py b/src/ptychodus/controller/workflow/authorization.py index 074c8b26..883663be 100644 --- a/src/ptychodus/controller/workflow/authorization.py +++ b/src/ptychodus/controller/workflow/authorization.py @@ -17,38 +17,38 @@ def __init__( self._dialog = dialog @classmethod - def createInstance( + def create_instance( cls, presenter: WorkflowAuthorizationPresenter, dialog: WorkflowAuthorizationDialog, ) -> WorkflowAuthorizationController: controller = cls(presenter, dialog) - dialog.finished.connect(controller._finishAuthorization) - dialog.lineEdit.textChanged.connect(controller._setDialogButtonsEnabled) - controller._setDialogButtonsEnabled() + dialog.finished.connect(controller._finish_authorization) + dialog.line_edit.textChanged.connect(controller._set_dialog_buttons_enabled) + controller._set_dialog_buttons_enabled() return controller - def _setDialogButtonsEnabled(self) -> None: - text = self._dialog.lineEdit.text() - self._dialog.okButton.setEnabled(len(text) > 0) + def _set_dialog_buttons_enabled(self) -> None: + text = self._dialog.line_edit.text() + self._dialog.ok_button.setEnabled(len(text) > 0) - def startAuthorizationIfNeeded(self) -> None: - if not (self._presenter.isAuthorized or self._dialog.isVisible()): - self._startAuthorization() + def start_authorization_if_needed(self) -> None: + if not (self._presenter.is_authorized or self._dialog.isVisible()): + self._start_authorization() - def _startAuthorization(self) -> None: - authorizeURL = self._presenter.getAuthorizeURL() - text = f'Input the Globus authorization code from this link:' + def _start_authorization(self) -> None: + authorize_url = self._presenter.get_authorize_url() + text = f'Input the Globus authorization code from this link:' self._dialog.label.setText(text) - self._dialog.lineEdit.clear() + self._dialog.line_edit.clear() self._dialog.open() - def _finishAuthorization(self, result: int) -> None: + def _finish_authorization(self, result: int) -> None: if result != QDialog.DialogCode.Accepted: return - authCode = self._dialog.lineEdit.text() - self._presenter.setCodeFromAuthorizeURL(authCode) + auth_code = self._dialog.line_edit.text() + self._presenter.set_code_from_authorize_url(auth_code) diff --git a/src/ptychodus/controller/workflow/compute.py b/src/ptychodus/controller/workflow/compute.py index 888d675f..dd3554c7 100644 --- a/src/ptychodus/controller/workflow/compute.py +++ b/src/ptychodus/controller/workflow/compute.py @@ -15,45 +15,57 @@ def __init__(self, presenter: WorkflowParametersPresenter, view: WorkflowCompute self._view = view @classmethod - def createInstance( + def create_instance( cls, presenter: WorkflowParametersPresenter, view: WorkflowComputeView ) -> WorkflowComputeController: controller = cls(presenter, view) - presenter.addObserver(controller) + presenter.add_observer(controller) - view.computeEndpointIDLineEdit.editingFinished.connect( - controller._syncComputeEndpointIDToModel + view.compute_endpoint_id_line_edit.editingFinished.connect( + controller._sync_compute_endpoint_id_to_model ) - view.dataEndpointIDLineEdit.editingFinished.connect(controller._syncDataEndpointIDToModel) - view.dataGlobusPathLineEdit.editingFinished.connect(controller._syncGlobusPathToModel) - view.dataPosixPathLineEdit.editingFinished.connect(controller._syncPosixPathToModel) + view.data_endpoint_id_line_edit.editingFinished.connect( + controller._sync_data_endpoint_id_to_model + ) + view.data_globus_path_line_edit.editingFinished.connect( + controller._sync_globus_path_to_model + ) + view.data_posix_path_line_edit.editingFinished.connect(controller._sync_posix_path_to_model) - controller._syncModelToView() + controller._sync_model_to_view() return controller - def _syncComputeEndpointIDToModel(self) -> None: - endpointID = UUID(self._view.computeEndpointIDLineEdit.text()) - self._presenter.setComputeEndpointID(endpointID) + def _sync_compute_endpoint_id_to_model(self) -> None: + endpoint_id = UUID(self._view.compute_endpoint_id_line_edit.text()) + self._presenter.set_compute_endpoint_id(endpoint_id) - def _syncDataEndpointIDToModel(self) -> None: - endpointID = UUID(self._view.dataEndpointIDLineEdit.text()) - self._presenter.setComputeDataEndpointID(endpointID) + def _sync_data_endpoint_id_to_model(self) -> None: + endpoint_id = UUID(self._view.data_endpoint_id_line_edit.text()) + self._presenter.set_compute_data_endpoint_id(endpoint_id) - def _syncGlobusPathToModel(self) -> None: - dataPath = self._view.dataGlobusPathLineEdit.text() - self._presenter.setComputeDataGlobusPath(dataPath) + def _sync_globus_path_to_model(self) -> None: + data_path = self._view.data_globus_path_line_edit.text() + self._presenter.set_compute_data_globus_path(data_path) - def _syncPosixPathToModel(self) -> None: - dataPath = Path(self._view.dataPosixPathLineEdit.text()) - self._presenter.setComputeDataPosixPath(dataPath) + def _sync_posix_path_to_model(self) -> None: + data_path = Path(self._view.data_posix_path_line_edit.text()) + self._presenter.set_compute_data_posix_path(data_path) - def _syncModelToView(self) -> None: - self._view.computeEndpointIDLineEdit.setText(str(self._presenter.getComputeEndpointID())) - self._view.dataEndpointIDLineEdit.setText(str(self._presenter.getComputeDataEndpointID())) - self._view.dataGlobusPathLineEdit.setText(str(self._presenter.getComputeDataGlobusPath())) - self._view.dataPosixPathLineEdit.setText(str(self._presenter.getComputeDataPosixPath())) + def _sync_model_to_view(self) -> None: + self._view.compute_endpoint_id_line_edit.setText( + str(self._presenter.get_compute_endpoint_id()) + ) + self._view.data_endpoint_id_line_edit.setText( + str(self._presenter.get_compute_data_endpoint_id()) + ) + self._view.data_globus_path_line_edit.setText( + str(self._presenter.get_compute_data_globus_path()) + ) + self._view.data_posix_path_line_edit.setText( + str(self._presenter.get_compute_data_posix_path()) + ) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._presenter: - self._syncModelToView() + self._sync_model_to_view() diff --git a/src/ptychodus/controller/workflow/controller.py b/src/ptychodus/controller/workflow/controller.py index 6f757127..692a2a48 100644 --- a/src/ptychodus/controller/workflow/controller.py +++ b/src/ptychodus/controller/workflow/controller.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from PyQt5.QtCore import QAbstractItemModel, QTimer from PyQt5.QtWidgets import QTableView @@ -18,58 +16,34 @@ class WorkflowController: def __init__( self, - parametersPresenter: WorkflowParametersPresenter, - authorizationPresenter: WorkflowAuthorizationPresenter, - statusPresenter: WorkflowStatusPresenter, - executionPresenter: WorkflowExecutionPresenter, - parametersView: WorkflowParametersView, - tableView: QTableView, - productItemModel: QAbstractItemModel, + parameters_presenter: WorkflowParametersPresenter, + authorization_presenter: WorkflowAuthorizationPresenter, + status_presenter: WorkflowStatusPresenter, + execution_presenter: WorkflowExecutionPresenter, + parameters_view: WorkflowParametersView, + table_view: QTableView, + product_item_model: QAbstractItemModel, ) -> None: - self._parametersPresenter = parametersPresenter - self._authorizationPresenter = authorizationPresenter - self._executionPresenter = executionPresenter - self._parametersView = parametersView - self._authorizationController = WorkflowAuthorizationController.createInstance( - authorizationPresenter, parametersView.authorizationDialog + self._parameters_presenter = parameters_presenter + self._authorization_presenter = authorization_presenter + self._execution_presenter = execution_presenter + self._parameters_view = parameters_view + self._authorization_controller = WorkflowAuthorizationController.create_instance( + authorization_presenter, parameters_view.authorization_dialog ) - self._statusController = WorkflowStatusController.createInstance( - statusPresenter, parametersView.statusView, tableView + self._status_controller = WorkflowStatusController( + status_presenter, parameters_view.status_view, table_view ) - self._executionController = WorkflowExecutionController.createInstance( - parametersPresenter, - executionPresenter, - parametersView.executionView, - productItemModel, + self._execution_controller = WorkflowExecutionController.create_instance( + parameters_presenter, + execution_presenter, + parameters_view.execution_view, + product_item_model, ) self._timer = QTimer() + self._timer.timeout.connect(self._process_events) + self._timer.start(5 * 1000) # TODO customize - @classmethod - def createInstance( - cls, - parametersPresenter: WorkflowParametersPresenter, - authorizationPresenter: WorkflowAuthorizationPresenter, - statusPresenter: WorkflowStatusPresenter, - executionPresenter: WorkflowExecutionPresenter, - parametersView: WorkflowParametersView, - tableView: QTableView, - productItemModel: QAbstractItemModel, - ) -> WorkflowController: - controller = cls( - parametersPresenter, - authorizationPresenter, - statusPresenter, - executionPresenter, - parametersView, - tableView, - productItemModel, - ) - - controller._timer.timeout.connect(controller._processEvents) - controller._timer.start(1000) # TODO customize - - return controller - - def _processEvents(self) -> None: - self._authorizationController.startAuthorizationIfNeeded() - self._statusController.refreshTableView() + def _process_events(self) -> None: + self._authorization_controller.start_authorization_if_needed() + self._status_controller.refresh_table_view() diff --git a/src/ptychodus/controller/workflow/execution.py b/src/ptychodus/controller/workflow/execution.py index 6df874a0..27219f8b 100644 --- a/src/ptychodus/controller/workflow/execution.py +++ b/src/ptychodus/controller/workflow/execution.py @@ -7,8 +7,8 @@ from ...view.widgets import ExceptionDialog from ...view.workflow import WorkflowExecutionView from .compute import WorkflowComputeController -from .inputData import WorkflowInputDataController -from .outputData import WorkflowOutputDataController +from .input_data import WorkflowInputDataController +from .output_data import WorkflowOutputDataController logger = logging.getLogger(__name__) @@ -16,44 +16,44 @@ class WorkflowExecutionController: def __init__( self, - parametersPresenter: WorkflowParametersPresenter, - executionPresenter: WorkflowExecutionPresenter, + parameters_presenter: WorkflowParametersPresenter, + execution_presenter: WorkflowExecutionPresenter, view: WorkflowExecutionView, ) -> None: - self._executionPresenter = executionPresenter + self._execution_presenter = execution_presenter self._view = view - self._inputDataController = WorkflowInputDataController.createInstance( - parametersPresenter, view.inputDataView + self._input_data_controller = WorkflowInputDataController.create_instance( + parameters_presenter, view.input_data_view ) - self._computeController = WorkflowComputeController.createInstance( - parametersPresenter, view.computeView + self._compute_controller = WorkflowComputeController.create_instance( + parameters_presenter, view.compute_view ) - self._outputDataController = WorkflowOutputDataController.createInstance( - parametersPresenter, view.outputDataView + self._output_data_controller = WorkflowOutputDataController.create_instance( + parameters_presenter, view.output_data_view ) @classmethod - def createInstance( + def create_instance( cls, - parametersPresenter: WorkflowParametersPresenter, - executionPresenter: WorkflowExecutionPresenter, + parameters_presenter: WorkflowParametersPresenter, + execution_presenter: WorkflowExecutionPresenter, view: WorkflowExecutionView, - productItemModel: QAbstractItemModel, + product_item_model: QAbstractItemModel, ) -> WorkflowExecutionController: - controller = cls(parametersPresenter, executionPresenter, view) - view.productComboBox.setModel(productItemModel) - view.executeButton.clicked.connect(controller._execute) + controller = cls(parameters_presenter, execution_presenter, view) + view.product_combo_box.setModel(product_item_model) + view.execute_button.clicked.connect(controller._execute) return controller def _execute(self) -> None: - inputProductIndex = self._view.productComboBox.currentIndex() + input_product_index = self._view.product_combo_box.currentIndex() - if inputProductIndex < 0: + if input_product_index < 0: logger.debug('No current index!') return try: - self._executionPresenter.runFlow(inputProductIndex) + self._execution_presenter.run_flow(input_product_index) except Exception as err: logger.exception(err) - ExceptionDialog.showException('Reconstructor', err) + ExceptionDialog.show_exception('Reconstructor', err) diff --git a/src/ptychodus/controller/workflow/inputData.py b/src/ptychodus/controller/workflow/inputData.py deleted file mode 100644 index 0ce6aebe..00000000 --- a/src/ptychodus/controller/workflow/inputData.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations -from pathlib import Path -from uuid import UUID - -from ptychodus.api.observer import Observable, Observer - -from ...model.workflow import WorkflowParametersPresenter -from ...view.workflow import WorkflowInputDataView - - -class WorkflowInputDataController(Observer): - def __init__(self, presenter: WorkflowParametersPresenter, view: WorkflowInputDataView) -> None: - super().__init__() - self._presenter = presenter - self._view = view - - @classmethod - def createInstance( - cls, presenter: WorkflowParametersPresenter, view: WorkflowInputDataView - ) -> WorkflowInputDataController: - controller = cls(presenter, view) - presenter.addObserver(controller) - - view.endpointIDLineEdit.editingFinished.connect(controller._syncEndpointIDToModel) - view.globusPathLineEdit.editingFinished.connect(controller._syncGlobusPathToModel) - view.posixPathLineEdit.editingFinished.connect(controller._syncPosixPathToModel) - - controller._syncModelToView() - - return controller - - def _syncEndpointIDToModel(self) -> None: - endpointID = UUID(self._view.endpointIDLineEdit.text()) - self._presenter.setInputDataEndpointID(endpointID) - - def _syncGlobusPathToModel(self) -> None: - dataPath = self._view.globusPathLineEdit.text() - self._presenter.setInputDataGlobusPath(dataPath) - - def _syncPosixPathToModel(self) -> None: - dataPath = Path(self._view.posixPathLineEdit.text()).expanduser() - self._presenter.setInputDataPosixPath(dataPath) - - def _syncModelToView(self) -> None: - self._view.endpointIDLineEdit.setText(str(self._presenter.getInputDataEndpointID())) - self._view.globusPathLineEdit.setText(str(self._presenter.getInputDataGlobusPath())) - self._view.posixPathLineEdit.setText(str(self._presenter.getInputDataPosixPath())) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() diff --git a/src/ptychodus/controller/workflow/input_data.py b/src/ptychodus/controller/workflow/input_data.py new file mode 100644 index 00000000..9c6cc172 --- /dev/null +++ b/src/ptychodus/controller/workflow/input_data.py @@ -0,0 +1,51 @@ +from __future__ import annotations +from pathlib import Path +from uuid import UUID + +from ptychodus.api.observer import Observable, Observer + +from ...model.workflow import WorkflowParametersPresenter +from ...view.workflow import WorkflowInputDataView + + +class WorkflowInputDataController(Observer): + def __init__(self, presenter: WorkflowParametersPresenter, view: WorkflowInputDataView) -> None: + super().__init__() + self._presenter = presenter + self._view = view + + @classmethod + def create_instance( + cls, presenter: WorkflowParametersPresenter, view: WorkflowInputDataView + ) -> WorkflowInputDataController: + controller = cls(presenter, view) + presenter.add_observer(controller) + + view.endpoint_id_line_edit.editingFinished.connect(controller._sync_endpoint_id_to_model) + view.globus_path_line_edit.editingFinished.connect(controller._sync_globus_path_to_model) + view.posix_path_line_edit.editingFinished.connect(controller._sync_posix_path_to_model) + + controller._sync_model_to_view() + + return controller + + def _sync_endpoint_id_to_model(self) -> None: + endpoint_id = UUID(self._view.endpoint_id_line_edit.text()) + self._presenter.set_input_data_endpoint_id(endpoint_id) + + def _sync_globus_path_to_model(self) -> None: + data_path = self._view.globus_path_line_edit.text() + self._presenter.set_input_data_globus_path(data_path) + + def _sync_posix_path_to_model(self) -> None: + data_path = Path(self._view.posix_path_line_edit.text()).expanduser() + self._presenter.set_input_data_posix_path(data_path) + + def _sync_model_to_view(self) -> None: + self._view.endpoint_id_line_edit.setText(str(self._presenter.get_input_data_endpoint_id())) + self._view.globus_path_line_edit.setText(str(self._presenter.get_input_data_globus_path())) + self._view.posix_path_line_edit.setText(str(self._presenter.get_input_data_posix_path())) + + def _update(self, observable: Observable) -> None: + if observable is self._presenter: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/workflow/outputData.py b/src/ptychodus/controller/workflow/outputData.py deleted file mode 100644 index 1991d2d8..00000000 --- a/src/ptychodus/controller/workflow/outputData.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations -from pathlib import Path -from uuid import UUID - -from ptychodus.api.observer import Observable, Observer - -from ...model.workflow import WorkflowParametersPresenter -from ...view.workflow import WorkflowOutputDataView - - -class WorkflowOutputDataController(Observer): - def __init__( - self, presenter: WorkflowParametersPresenter, view: WorkflowOutputDataView - ) -> None: - super().__init__() - self._presenter = presenter - self._view = view - - @classmethod - def createInstance( - cls, presenter: WorkflowParametersPresenter, view: WorkflowOutputDataView - ) -> WorkflowOutputDataController: - controller = cls(presenter, view) - presenter.addObserver(controller) - - view.roundTripCheckBox.toggled.connect(presenter.setRoundTripEnabled) - view.endpointIDLineEdit.editingFinished.connect(controller._syncEndpointIDToModel) - view.globusPathLineEdit.editingFinished.connect(controller._syncGlobusPathToModel) - view.posixPathLineEdit.editingFinished.connect(controller._syncPosixPathToModel) - - controller._syncModelToView() - - return controller - - def _syncEndpointIDToModel(self) -> None: - endpointID = UUID(self._view.endpointIDLineEdit.text()) - self._presenter.setOutputDataEndpointID(endpointID) - - def _syncGlobusPathToModel(self) -> None: - dataPath = self._view.globusPathLineEdit.text() - self._presenter.setOutputDataGlobusPath(dataPath) - - def _syncPosixPathToModel(self) -> None: - dataPath = Path(self._view.posixPathLineEdit.text()) - self._presenter.setOutputDataPosixPath(dataPath) - - def _syncModelToView(self) -> None: - isRoundTripEnabled = self._presenter.isRoundTripEnabled() - self._view.roundTripCheckBox.setChecked(isRoundTripEnabled) - self._view.endpointIDLineEdit.setText(str(self._presenter.getOutputDataEndpointID())) - self._view.endpointIDLineEdit.setEnabled(not isRoundTripEnabled) - self._view.globusPathLineEdit.setText(str(self._presenter.getOutputDataGlobusPath())) - self._view.globusPathLineEdit.setEnabled(not isRoundTripEnabled) - self._view.posixPathLineEdit.setText(str(self._presenter.getOutputDataPosixPath())) - self._view.posixPathLineEdit.setEnabled(not isRoundTripEnabled) - - def update(self, observable: Observable) -> None: - if observable is self._presenter: - self._syncModelToView() diff --git a/src/ptychodus/controller/workflow/output_data.py b/src/ptychodus/controller/workflow/output_data.py new file mode 100644 index 00000000..4b5f5e59 --- /dev/null +++ b/src/ptychodus/controller/workflow/output_data.py @@ -0,0 +1,59 @@ +from __future__ import annotations +from pathlib import Path +from uuid import UUID + +from ptychodus.api.observer import Observable, Observer + +from ...model.workflow import WorkflowParametersPresenter +from ...view.workflow import WorkflowOutputDataView + + +class WorkflowOutputDataController(Observer): + def __init__( + self, presenter: WorkflowParametersPresenter, view: WorkflowOutputDataView + ) -> None: + super().__init__() + self._presenter = presenter + self._view = view + + @classmethod + def create_instance( + cls, presenter: WorkflowParametersPresenter, view: WorkflowOutputDataView + ) -> WorkflowOutputDataController: + controller = cls(presenter, view) + presenter.add_observer(controller) + + view.round_trip_check_box.toggled.connect(presenter.set_round_trip_enabled) + view.endpoint_id_line_edit.editingFinished.connect(controller._sync_endpoint_id_to_model) + view.globus_path_line_edit.editingFinished.connect(controller._sync_globus_path_to_model) + view.posix_path_line_edit.editingFinished.connect(controller._sync_posix_path_to_model) + + controller._sync_model_to_view() + + return controller + + def _sync_endpoint_id_to_model(self) -> None: + endpoint_id = UUID(self._view.endpoint_id_line_edit.text()) + self._presenter.set_output_data_endpoint_id(endpoint_id) + + def _sync_globus_path_to_model(self) -> None: + data_path = self._view.globus_path_line_edit.text() + self._presenter.set_output_data_globus_path(data_path) + + def _sync_posix_path_to_model(self) -> None: + data_path = Path(self._view.posix_path_line_edit.text()) + self._presenter.set_output_data_posix_path(data_path) + + def _sync_model_to_view(self) -> None: + is_round_trip_enabled = self._presenter.is_round_trip_enabled() + self._view.round_trip_check_box.setChecked(is_round_trip_enabled) + self._view.endpoint_id_line_edit.setText(str(self._presenter.get_output_data_endpoint_id())) + self._view.endpoint_id_line_edit.setEnabled(not is_round_trip_enabled) + self._view.globus_path_line_edit.setText(str(self._presenter.get_output_data_globus_path())) + self._view.globus_path_line_edit.setEnabled(not is_round_trip_enabled) + self._view.posix_path_line_edit.setText(str(self._presenter.get_output_data_posix_path())) + self._view.posix_path_line_edit.setEnabled(not is_round_trip_enabled) + + def _update(self, observable: Observable) -> None: + if observable is self._presenter: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/workflow/status.py b/src/ptychodus/controller/workflow/status.py index eb095415..1bd8ff7c 100644 --- a/src/ptychodus/controller/workflow/status.py +++ b/src/ptychodus/controller/workflow/status.py @@ -1,82 +1,78 @@ -from __future__ import annotations import logging from PyQt5.QtCore import Qt, QModelIndex, QSortFilterProxyModel, QTimer from PyQt5.QtWidgets import QAbstractItemView, QTableView from PyQt5.QtGui import QDesktopServices +from ptychodus.api.observer import Observable, Observer + from ...model.workflow import WorkflowStatusPresenter from ...view.workflow import WorkflowStatusView -from .tableModel import WorkflowTableModel +from .table_model import WorkflowTableModel logger = logging.getLogger(__name__) -class WorkflowStatusController: +class WorkflowStatusController(Observer): def __init__( self, presenter: WorkflowStatusPresenter, view: WorkflowStatusView, - tableView: QTableView, + table_view: QTableView, ) -> None: + super().__init__() self._presenter = presenter self._view = view - self._tableView = tableView - self._tableModel = WorkflowTableModel(presenter) - self._proxyModel = QSortFilterProxyModel() + self._table_view = table_view + self._table_model = WorkflowTableModel(presenter) + self._proxy_model = QSortFilterProxyModel() + self._proxy_model.setSourceModel(self._table_model) self._timer = QTimer() + self._timer.timeout.connect(presenter.refresh_status) - @classmethod - def createInstance( - cls, - presenter: WorkflowStatusPresenter, - view: WorkflowStatusView, - tableView: QTableView, - ) -> WorkflowStatusController: - controller = cls(presenter, view, tableView) - - controller._proxyModel.setSourceModel(controller._tableModel) - tableView.setModel(controller._proxyModel) - tableView.setSortingEnabled(True) - tableView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - tableView.clicked.connect(controller._handleTableViewClick) - - controller._timer.timeout.connect(presenter.refreshStatus) - view.autoRefreshCheckBox.toggled.connect(controller._autoRefreshStatus) - view.autoRefreshSpinBox.valueChanged.connect(presenter.setRefreshIntervalInSeconds) - view.refreshButton.clicked.connect(presenter.refreshStatus) + table_view.setModel(self._proxy_model) + table_view.setSortingEnabled(True) + table_view.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + table_view.clicked.connect(self._handle_table_view_clicked) - controller._syncModelToView() + view.auto_refresh_check_box.toggled.connect(self._auto_refresh_status) + view.auto_refresh_spin_box.valueChanged.connect(presenter.set_refresh_interval_s) + view.refresh_button.clicked.connect(presenter.refresh_status) - return controller + self._sync_model_to_view() + presenter.add_observer(self) - def _handleTableViewClick(self, index: QModelIndex) -> None: + def _handle_table_view_clicked(self, index: QModelIndex) -> None: if index.column() == 5: url = index.data(Qt.ItemDataRole.UserRole) - logger.debug(f'Opening URL: "{url.toString()}"') + logger.info(f'Opening URL: "{url.toString()}"') QDesktopServices.openUrl(url) - def _autoRefreshStatus(self) -> None: - if self._view.autoRefreshCheckBox.isChecked(): - self._timer.start(1000 * self._presenter.getRefreshIntervalInSeconds()) - self._view.autoRefreshSpinBox.setEnabled(False) - self._view.refreshButton.setEnabled(False) + def _auto_refresh_status(self) -> None: + if self._view.auto_refresh_check_box.isChecked(): + self._timer.start(1000 * self._presenter.get_refresh_interval_s()) + self._view.auto_refresh_spin_box.setEnabled(False) + self._view.refresh_button.setEnabled(False) else: self._timer.stop() - self._view.autoRefreshSpinBox.setEnabled(True) - self._view.refreshButton.setEnabled(True) + self._view.auto_refresh_spin_box.setEnabled(True) + self._view.refresh_button.setEnabled(True) - def refreshTableView(self) -> None: + def refresh_table_view(self) -> None: # TODO only reset if changed - self._tableModel.beginResetModel() - self._tableModel.endResetModel() + self._table_model.beginResetModel() + self._table_model.endResetModel() - def _syncModelToView(self) -> None: - refreshIntervalLimitsInSeconds = self._presenter.getRefreshIntervalLimitsInSeconds() + def _sync_model_to_view(self) -> None: + refresh_interval_limits_s = self._presenter.get_refresh_interval_limits_s() - self._view.autoRefreshSpinBox.blockSignals(True) - self._view.autoRefreshSpinBox.setRange( - refreshIntervalLimitsInSeconds.lower, refreshIntervalLimitsInSeconds.upper + self._view.auto_refresh_spin_box.blockSignals(True) + self._view.auto_refresh_spin_box.setRange( + refresh_interval_limits_s.lower, refresh_interval_limits_s.upper ) - self._view.autoRefreshSpinBox.setValue(self._presenter.getRefreshIntervalInSeconds()) - self._view.autoRefreshSpinBox.blockSignals(False) + self._view.auto_refresh_spin_box.setValue(self._presenter.get_refresh_interval_s()) + self._view.auto_refresh_spin_box.blockSignals(False) + + def _update(self, observable: Observable) -> None: + if observable is self._presenter: + self._sync_model_to_view() diff --git a/src/ptychodus/controller/workflow/tableModel.py b/src/ptychodus/controller/workflow/table_model.py similarity index 62% rename from src/ptychodus/controller/workflow/tableModel.py rename to src/ptychodus/controller/workflow/table_model.py index 7ca1aa25..c3c7d407 100644 --- a/src/ptychodus/controller/workflow/tableModel.py +++ b/src/ptychodus/controller/workflow/table_model.py @@ -10,7 +10,7 @@ class WorkflowTableModel(QAbstractTableModel): def __init__(self, presenter: WorkflowStatusPresenter, parent: QObject | None = None) -> None: super().__init__(parent) self._presenter = presenter - self._sectionHeaders = [ + self._section_headers = [ 'Label', 'Start Time', 'Completion Time', @@ -18,9 +18,9 @@ def __init__(self, presenter: WorkflowStatusPresenter, parent: QObject | None = 'Action', 'Run ID', ] - self._dtFormat = '%Y-%m-%d %H:%M:%S' + self._dt_format = '%Y-%m-%d %H:%M:%S' - def headerData( + def headerData( # noqa: N802 self, section: int, orientation: Qt.Orientation, @@ -28,31 +28,32 @@ def headerData( ) -> Any: if role == Qt.ItemDataRole.DisplayRole: if orientation == Qt.Orientation.Horizontal: - return self._sectionHeaders[section] + return self._section_headers[section] elif orientation == Qt.Orientation.Vertical: return section def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: if index.isValid(): - flowRun = self._presenter[index.row()] + flow_run = self._presenter[index.row()] if role == Qt.ItemDataRole.DisplayRole: - if index.column() == 0: - return flowRun.label - elif index.column() == 1: - return flowRun.startTime.strftime(self._dtFormat) - elif index.column() == 2: - if flowRun.completionTime is not None: - return flowRun.completionTime.strftime(self._dtFormat) - elif index.column() == 3: - return flowRun.status - elif index.column() == 4: - return flowRun.action - elif index.column() == 5: - return flowRun.runID + match index.column(): + case 0: + return flow_run.label + case 1: + return flow_run.start_time.strftime(self._dt_format) + case 2: + if flow_run.completion_time is not None: + return flow_run.completion_time.strftime(self._dt_format) + case 3: + return flow_run.status + case 4: + return flow_run.action + case 5: + return flow_run.run_id elif index.column() == 5: if role == Qt.ItemDataRole.ToolTipRole: - return flowRun.runURL + return flow_run.run_url elif role == Qt.ItemDataRole.FontRole: font = QFont() font.setUnderline(True) @@ -61,10 +62,10 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A color = QColor(Qt.GlobalColor.blue) return color elif role == Qt.ItemDataRole.UserRole: - return QUrl(flowRun.runURL) + return QUrl(flow_run.run_url) - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 return len(self._presenter) - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: - return len(self._sectionHeaders) + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: # noqa: N802 + return len(self._section_headers) diff --git a/src/ptychodus/model/agent/__init__.py b/src/ptychodus/model/agent/__init__.py new file mode 100644 index 00000000..939e1a1f --- /dev/null +++ b/src/ptychodus/model/agent/__init__.py @@ -0,0 +1,13 @@ +from .chat import ChatMessage, ChatHistory, ChatObserver, ChatRole +from .core import AgentCore, AgentPresenter +from .settings import ArgoSettings + +__all__ = [ + 'AgentCore', + 'AgentPresenter', + 'ArgoSettings', + 'ChatHistory', + 'ChatMessage', + 'ChatObserver', + 'ChatRole', +] diff --git a/src/ptychodus/model/agent/argo.py b/src/ptychodus/model/agent/argo.py new file mode 100644 index 00000000..edaeae63 --- /dev/null +++ b/src/ptychodus/model/agent/argo.py @@ -0,0 +1,73 @@ +from collections.abc import Sequence +import logging +import requests + +from .chat import ChatHistory, ChatMessage, ChatRole, ChatTerminal +from .settings import ArgoSettings + +logger = logging.getLogger(__name__) + + +class ArgoChatTerminal(ChatTerminal): + def __init__(self, settings: ArgoSettings, history: ChatHistory) -> None: + self._settings = settings + self._history = history + + def send_message(self, content: str, stop: Sequence[str] = []) -> None: + if not content: + return + + messages = [ + ChatMessage( + role=ChatRole.SYSTEM, content='You are a large language model with the name Argo.' + ) + ] + + for line in content.splitlines(): + message = ChatMessage(role=ChatRole.USER, content=line) + messages.append(message) + self._history.add_message(message) + + logger.debug(f'{messages=}') + + url = self._settings.chat_endpoint_url.get_value() + payload = { + 'user': self._settings.user.get_value(), + 'model': self._settings.chat_model.get_value(), + 'messages': [m.to_dict() for m in messages], + 'stop': stop, + 'temperature': self._settings.temperature.get_value(), + 'top_p': self._settings.top_p.get_value(), + 'max_tokens': self._settings.max_tokens.get_value(), + 'max_completion_tokens': self._settings.max_completion_tokens.get_value(), + } + headers = {'Content-Type': 'application/json'} + response = requests.post(url, json=payload, headers=headers) + + logger.debug(f'{response=}') + logger.debug(f'Status Code: {response.status_code}') + response_json = response.json() + logger.debug(f'JSON Response: {response_json}') + response.raise_for_status() + + response_message = ChatMessage(role=ChatRole.AGENT, content=response_json['response']) + self._history.add_message(response_message) + + def embed_texts(self, texts: Sequence[str]) -> Sequence[Sequence[float]]: + """Generates embeddings for a list of strings.""" + url = self._settings.embeddings_endpoint_url.get_value() + payload = { + 'user': self._settings.user.get_value(), + 'model': self._settings.embeddings_model.get_value(), + 'prompt': texts, + } + headers = {'Content-Type': 'application/json'} + response = requests.post(url, json=payload, headers=headers) + + logger.debug(response) + logger.debug(f'Status Code: {response.status_code}') + response_json = response.json() + logger.debug(f'JSON Response: {response_json}') + response.raise_for_status() + + return response_json['embedding'] diff --git a/src/ptychodus/model/agent/chat.py b/src/ptychodus/model/agent/chat.py new file mode 100644 index 00000000..ffd415a0 --- /dev/null +++ b/src/ptychodus/model/agent/chat.py @@ -0,0 +1,81 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum, auto +from typing import overload + + +class ChatRole(Enum): + SYSTEM = auto() + USER = auto() + AGENT = auto() + + +@dataclass(frozen=True) +class ChatMessage: + role: ChatRole + content: str + + def to_dict(self) -> dict[str, str]: + return { + 'role': self.role.name.lower(), + 'content': self.content, + } + + def __str__(self) -> str: + return str(self.to_dict()) + + +class ChatTerminal(ABC): + @abstractmethod + def send_message(self, content: str) -> None: + pass + + @abstractmethod + def embed_texts(self, texts: Sequence[str]) -> Sequence[Sequence[float]]: + pass + + +class ChatObserver(ABC): + @abstractmethod + def handle_new_message(self, message: ChatMessage, index: int) -> None: + pass + + @abstractmethod + def handle_chat_cleared(self) -> None: + pass + + +class ChatHistory(Sequence[ChatMessage]): + def __init__(self) -> None: + self._messages: list[ChatMessage] = [] + self._observers: list[ChatObserver] = [] + + @overload + def __getitem__(self, index: int) -> ChatMessage: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[ChatMessage]: ... + + def __getitem__(self, index: int | slice) -> ChatMessage | Sequence[ChatMessage]: + return self._messages[index] + + def __len__(self) -> int: + return len(self._messages) + + def add_observer(self, observer: ChatObserver) -> None: + if observer not in self._observers: + self._observers.append(observer) + + def add_message(self, message: ChatMessage) -> None: + index = len(self._messages) + self._messages.append(message) + + for observer in self._observers: + observer.handle_new_message(message, index) + + def clear(self) -> None: + self._messages.clear() + + for observer in self._observers: + observer.handle_chat_cleared() diff --git a/src/ptychodus/model/agent/core.py b/src/ptychodus/model/agent/core.py new file mode 100644 index 00000000..034505ce --- /dev/null +++ b/src/ptychodus/model/agent/core.py @@ -0,0 +1,49 @@ +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +import logging + +from ptychodus.api.settings import SettingsRegistry + +from .argo import ArgoChatTerminal +from .chat import ChatHistory, ChatTerminal +from .settings import ArgoSettings + +logger = logging.getLogger(__name__) + + +class AgentPresenter: + def __init__(self, terminal: ChatTerminal) -> None: + self._terminal = terminal + + def get_available_chat_models(self) -> Iterator[str]: + for model in [ + 'gpt35', + 'gpt35large', + 'gpt4', + 'gpt4large', + 'gpt4o', + 'gpt4olatestgpt4turbo', + 'gpto1', + 'gpto1mini', + 'gpto3mini', + ]: + yield model + + def send_message(self, content: str) -> None: + if self._terminal is not None: + self._terminal.send_message(content) + + def get_available_embeddings_models(self) -> Iterator[str]: + for model in ['ada002', 'v3large', 'v3small']: + yield model + + def embed_text(self, texts: Sequence[str]) -> Sequence[Sequence[float]]: + return [[]] if self._terminal is None else self._terminal.embed_texts(texts) + + +class AgentCore: + def __init__(self, settings_registry: SettingsRegistry): + self.settings = ArgoSettings(settings_registry) + self.chat_history = ChatHistory() + self._terminal = ArgoChatTerminal(self.settings, self.chat_history) + self.presenter = AgentPresenter(self._terminal) diff --git a/src/ptychodus/model/agent/properties.py b/src/ptychodus/model/agent/properties.py new file mode 100644 index 00000000..3f58b4be --- /dev/null +++ b/src/ptychodus/model/agent/properties.py @@ -0,0 +1,50 @@ +from collections.abc import Sequence +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ChatModelProperties: + name: str + + @property + def is_o_series_model(self) -> bool: + return self.name.startswith('gpto') + + @property + def accepts_system_prompt(self) -> bool: + return not self.is_o_series_model + + @property + def accepts_stop_sequence(self) -> bool: + return not self.is_o_series_model + + @property + def accepts_temperature(self) -> bool: + return not self.is_o_series_model + + @property + def accepts_top_p(self) -> bool: + return not self.is_o_series_model + + @property + def accepts_max_tokens(self) -> bool: + return not self.is_o_series_model + + @property + def accepts_max_completion_tokens(self) -> bool: + return self.is_o_series_model + + +def list_argo_model_properties() -> Sequence[ChatModelProperties]: + return [ + ChatModelProperties(name='gpt35'), + ChatModelProperties(name='gpt35large'), + ChatModelProperties(name='gpt4'), + ChatModelProperties(name='gpt4large'), + ChatModelProperties(name='gpt4o'), + ChatModelProperties(name='gpt4olatest'), + ChatModelProperties(name='gpt4turbo'), + ChatModelProperties(name='gpto1'), + ChatModelProperties(name='gpto1mini'), + ChatModelProperties(name='gpto3mini'), + ] diff --git a/src/ptychodus/model/agent/settings.py b/src/ptychodus/model/agent/settings.py new file mode 100644 index 00000000..42f7eeac --- /dev/null +++ b/src/ptychodus/model/agent/settings.py @@ -0,0 +1,35 @@ +import getpass + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.settings import SettingsRegistry + + +class ArgoSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('Argo') + self._group.add_observer(self) + + self.user = self._group.create_string_parameter('User', getpass.getuser()) + self.chat_endpoint_url = self._group.create_string_parameter( + 'ChatEndpointURL', 'https://apps.inside.anl.gov/argoapi/api/v1/resource/chat/' + ) + self.chat_model = self._group.create_string_parameter('ChatModel', 'gpt35') + self.temperature = self._group.create_real_parameter( + 'Temperature', 0.1, minimum=0.0, maximum=2.0 + ) + self.top_p = self._group.create_real_parameter('TopP', 0.9, minimum=0.0, maximum=1.0) + self.max_tokens = self._group.create_integer_parameter( + 'MaxTokens', 1000, minimum=0, maximum=128000 + ) + self.max_completion_tokens = self._group.create_integer_parameter( + 'MaxCompletionTokens', 1000, minimum=0, maximum=128000 + ) + self.embeddings_endpoint_url = self._group.create_string_parameter( + 'EmbeddingsEndpointURL', 'https://apps.inside.anl.gov/argoapi/api/v1/resource/embed/' + ) + self.embeddings_model = self._group.create_string_parameter('EmbeddingsModel', 'ada002') + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/analysis/__init__.py b/src/ptychodus/model/analysis/__init__.py index 250c79ca..c17674bd 100644 --- a/src/ptychodus/model/analysis/__init__.py +++ b/src/ptychodus/model/analysis/__init__.py @@ -1,21 +1,20 @@ +from .barycentric import BarycentricArrayInterpolator, BarycentricArrayStitcher from .core import AnalysisCore -from .exposure import ExposureAnalyzer, ExposureMap from .frc import FourierRingCorrelator -from .objectInterpolator import ObjectLinearInterpolator -from .objectStitcher import ObjectStitcher +from .illumination import IlluminationMapper, IlluminationMap from .propagator import ProbePropagator from .stxm import STXMSimulator -from .xmcd import XMCDAnalyzer, XMCDResult +from .xmcd import XMCDAnalyzer, XMCDData __all__ = [ 'AnalysisCore', - 'ExposureAnalyzer', - 'ExposureMap', + 'BarycentricArrayInterpolator', + 'BarycentricArrayStitcher', 'FourierRingCorrelator', - 'ObjectLinearInterpolator', - 'ObjectStitcher', + 'IlluminationMap', + 'IlluminationMapper', 'ProbePropagator', 'STXMSimulator', 'XMCDAnalyzer', - 'XMCDResult', + 'XMCDData', ] diff --git a/src/ptychodus/model/analysis/affine.py b/src/ptychodus/model/analysis/affine.py new file mode 100644 index 00000000..bf38b704 --- /dev/null +++ b/src/ptychodus/model/analysis/affine.py @@ -0,0 +1,154 @@ +from collections.abc import Sequence +from dataclasses import dataclass + +import numpy + +from ptychodus.api.geometry import AffineTransform +from ptychodus.api.observer import Observable +from ptychodus.api.typing import RealArrayType + +from ..product import ScanRepository +from .settings import AffineTransformEstimatorSettings + +__all__ = ['AffineTransformEstimator'] + + +@dataclass(frozen=True) +class PreprocessedCoordinates: + coordinates: RealArrayType + centroid_x: float + centroid_y: float + rms_distance: float + + +def estimate_mean_hodges_lehman(values: RealArrayType) -> float: + mean = numpy.median((values[numpy.newaxis, :] + values[:, numpy.newaxis]) / 2) + return float(mean) + + +def estimate_affine_transform( + uncorrected_coordinates: RealArrayType, + corrected_coordinates: RealArrayType, +) -> AffineTransform: + ones_col = numpy.ones((uncorrected_coordinates.shape[0])) + a = numpy.repeat(numpy.column_stack((uncorrected_coordinates, ones_col)), 2, axis=0) + b = corrected_coordinates.flatten() + x, residuals, rank, singular_values = numpy.linalg.lstsq(a, b) + return AffineTransform(x[0], x[1], x[2], x[3], x[4], x[5]) + + +def evaluate_error( + uncorrected_coordinates: RealArrayType, + corrected_coordinates: RealArrayType, + model: AffineTransform, +) -> RealArrayType: + y0 = uncorrected_coordinates[-2] + x0 = uncorrected_coordinates[-1] + + transform = numpy.vectorize(model.__call__) + yt, xt = transform(y0, x0) + + y1 = corrected_coordinates[-2] + x1 = corrected_coordinates[-1] + + return numpy.hypot(xt - x1, yt - y1) + + +class AffineTransformEstimator(Observable): + def __init__( + self, + rng: numpy.random.Generator, + settings: AffineTransformEstimatorSettings, + repository: ScanRepository, + ) -> None: + self._rng = rng + self._settings = settings + self._repository = repository + + def _preprocess_coordinates(self, product_indexes: Sequence[int]) -> PreprocessedCoordinates: + coordinate_list: list[float] = [] + + for product_index in product_indexes: + positions = self._repository[product_index].get_scan() + + for point in positions: + coordinate_list.append(point.position_y_m) + coordinate_list.append(point.position_x_m) + + coordinates = numpy.reshape(coordinate_list, (-1, 2)) + + # robust centroid estimation + centroid_x = estimate_mean_hodges_lehman(coordinates[:, -1]) + centroid_y = estimate_mean_hodges_lehman(coordinates[:, -2]) + coordinates -= numpy.array((centroid_y, centroid_x)) + + # rescale for RMS distance = 1 + distance = numpy.hypot(coordinates[:, -1], coordinates[:, -2]) + rms_distance = numpy.sqrt(numpy.mean(numpy.square(distance))) + coordinates /= rms_distance + + return PreprocessedCoordinates(coordinates, centroid_x, centroid_y, rms_distance) + + def estimate( + self, + measured_product_indexes: Sequence[int], + corrected_product_indexes: Sequence[int], + ) -> AffineTransform: + corrected_set = set(corrected_product_indexes) + measured_set = set(measured_product_indexes) + + if len(corrected_set) != len(corrected_product_indexes): + raise ValueError('One or more duplicated corrected product indexes!') + + if len(measured_set) != len(measured_product_indexes): + raise ValueError('One or more duplicated measured product indexes!') + + if not corrected_set.isdisjoint(measured_set): + raise ValueError('Product index appears in corrected and measured sets!') + + corrected_coordinates = self._preprocess_coordinates(corrected_product_indexes) + measured_coordinates = self._preprocess_coordinates(measured_product_indexes) + indexes = numpy.arange(measured_coordinates.coordinates.shape[0]) + num_shuffles = self._settings.num_shuffles.get_value() + inlier_threshold = self._settings.inlier_threshold.get_value() + min_inliers = self._settings.min_inliers.get_value() + arity = 3 # minimum number of points needed to estimate the model + + best_error = numpy.inf + best_model = AffineTransform(1.0, 0.0, 0.0, 0.0, 1.0, 0.0) + + # RANSAC estimation of affine transform + for it in range(num_shuffles): + self._rng.shuffle(indexes) + + for chunk in range(0, len(indexes), arity): + samples = indexes[chunk : chunk + arity] + + corrected_subset = numpy.take(corrected_coordinates.coordinates, samples, axis=0) + uncorrected_subset = numpy.take(measured_coordinates.coordinates, samples, axis=0) + coarse_model = estimate_affine_transform(uncorrected_subset, corrected_subset) + error = evaluate_error(uncorrected_subset, corrected_subset, coarse_model) + inliers = numpy.where(error < inlier_threshold) + + if len(inliers) > min_inliers: + corrected_subset = numpy.take( + corrected_coordinates.coordinates, inliers, axis=0 + ) + uncorrected_subset = numpy.take( + measured_coordinates.coordinates, inliers, axis=0 + ) + candidate_model = estimate_affine_transform( + uncorrected_subset, corrected_subset + ) + candidate_error = evaluate_error( + uncorrected_subset, corrected_subset, candidate_model + ) + candidate_error_rms = numpy.sqrt(numpy.mean(numpy.square(candidate_error))) + + if candidate_error < best_error: + best_error = candidate_error_rms + best_model = candidate_model + + # TODO broken: unscale best_model + + return best_model diff --git a/src/ptychodus/model/analysis/barycentric.py b/src/ptychodus/model/analysis/barycentric.py new file mode 100644 index 00000000..5d49bc6e --- /dev/null +++ b/src/ptychodus/model/analysis/barycentric.py @@ -0,0 +1,112 @@ +from typing import Generic, TypeVar + +from numpy.typing import NDArray +import numpy + +from ptychodus.api.typing import RealArrayType + +__all__ = [ + 'BarycentricArrayInterpolator', + 'BarycentricArrayStitcher', +] + +InexactDType = TypeVar('InexactDType', bound=numpy.inexact) + + +def calculate_support_frac(x: float, n: int) -> tuple[slice, float]: + lower = x - n / 2 + whole = int(lower) + return slice(whole, whole + n + 1), lower - whole + + +class BarycentricArrayInterpolator(Generic[InexactDType]): + def __init__(self, array: NDArray[InexactDType]) -> None: + super().__init__() + self._array = array + + def get_patch( + self, center_x: float, center_y: float, width: int, height: int + ) -> NDArray[InexactDType]: + x_support, x_frac = calculate_support_frac(center_x, width) + y_support, y_frac = calculate_support_frac(center_y, height) + + # reused quantities + x_frac_c = 1.0 - x_frac + y_frac_c = 1.0 - y_frac + + # barycentric interpolant weights + weight00 = y_frac_c * x_frac_c + weight01 = y_frac_c * x_frac + weight10 = y_frac * x_frac_c + weight11 = y_frac * x_frac + + support = self._array[..., y_support, x_support] + patch = weight00 * support[..., :-1, :-1] + patch = patch + weight01 * support[..., :-1, 1:] + patch = patch + weight10 * support[..., 1:, :-1] + patch = patch + weight11 * support[..., 1:, 1:] + return patch # type: ignore + + +class BarycentricArrayStitcher(Generic[InexactDType]): + def __init__(self, upper: NDArray[InexactDType], lower: RealArrayType | None = None) -> None: + super().__init__() + self._upper = upper + self._lower = lower + + if lower is not None and upper.shape != lower.shape: + raise ValueError(f'Mismatched array shapes! ({upper.shape} != {lower.shape})') + + def add_patch( + self, + center_x: float, + center_y: float, + value: NDArray[InexactDType], + weight: RealArrayType | None = None, + ) -> None: + if numpy.iscomplexobj(self._upper) != numpy.iscomplexobj(value): + raise ValueError(f'Mismatched value dtypes! ({self._upper.dtype} != {value.dtype})') + + if weight is not None: + if self._lower is None: + raise ValueError('Provided weights without a lower array!') + + if value.shape != weight.shape: + raise ValueError(f'Mismatched patch shapes! ({value.shape=} != {weight.shape=})') + + x_support, x_frac = calculate_support_frac(center_x, value.shape[-1]) + y_support, y_frac = calculate_support_frac(center_y, value.shape[-2]) + + # reused quantities + x_frac_c = 1.0 - x_frac + y_frac_c = 1.0 - y_frac + + # barycentric interpolant weights + weight00 = y_frac_c * x_frac_c + weight01 = y_frac_c * x_frac + weight10 = y_frac * x_frac_c + weight11 = y_frac * x_frac + + # add patch update to upper array support + uvalue = value if weight is None else weight * value + usupport = self._upper[..., y_support, x_support] + usupport[..., :-1, :-1] = usupport[..., :-1, :-1] + weight00 * uvalue + usupport[..., :-1, 1:] = usupport[..., :-1, 1:] + weight01 * uvalue + usupport[..., 1:, :-1] = usupport[..., 1:, :-1] + weight10 * uvalue + usupport[..., 1:, 1:] = usupport[..., 1:, 1:] + weight11 * uvalue + + if self._lower is not None and weight is not None: + # add patch update to lower array support + lsupport = self._lower[..., y_support, x_support] + lsupport[..., :-1, :-1] += weight00 * weight + lsupport[..., :-1, 1:] += weight01 * weight + lsupport[..., 1:, :-1] += weight10 * weight + lsupport[..., 1:, 1:] += weight11 * weight + + def stitch(self) -> NDArray[InexactDType]: + if self._lower is None: + return self._upper + + return numpy.divide( + self._upper, self._lower, out=numpy.zeros_like(self._upper), where=(self._lower > 0) + ) diff --git a/src/ptychodus/model/analysis/core.py b/src/ptychodus/model/analysis/core.py index baf8d098..7da7f21e 100644 --- a/src/ptychodus/model/analysis/core.py +++ b/src/ptychodus/model/analysis/core.py @@ -5,7 +5,7 @@ from ..product import ObjectRepository, ProductRepository from ..reconstructor import DiffractionPatternPositionMatcher from ..visualization import VisualizationEngine -from .exposure import ExposureAnalyzer +from .illumination import IlluminationMapper from .frc import FourierRingCorrelator from .propagator import ProbePropagator from .settings import ProbePropagationSettings @@ -18,20 +18,24 @@ class AnalysisCore: def __init__( self, - settingsRegistry: SettingsRegistry, - dataMatcher: DiffractionPatternPositionMatcher, - productRepository: ProductRepository, - objectRepository: ObjectRepository, + settings_registry: SettingsRegistry, + data_matcher: DiffractionPatternPositionMatcher, + product_repository: ProductRepository, + object_repository: ObjectRepository, ) -> None: - self.stxmSimulator = STXMSimulator(dataMatcher) - self.stxmVisualizationEngine = VisualizationEngine(isComplex=False) - - self._probePropagationSettings = ProbePropagationSettings(settingsRegistry) - self.probePropagator = ProbePropagator(self._probePropagationSettings, productRepository) - self.probePropagatorVisualizationEngine = VisualizationEngine(isComplex=False) - self.exposureAnalyzer = ExposureAnalyzer(productRepository) - self.exposureVisualizationEngine = VisualizationEngine(isComplex=False) - self.fourierRingCorrelator = FourierRingCorrelator(objectRepository) - - self.xmcdAnalyzer = XMCDAnalyzer(objectRepository) - self.xmcdVisualizationEngine = VisualizationEngine(isComplex=False) + self.stxm_simulator = STXMSimulator(data_matcher) + self.stxm_visualization_engine = VisualizationEngine(is_complex=False) + + self._probe_propagation_settings = ProbePropagationSettings(settings_registry) + self.probe_propagator = ProbePropagator( + self._probe_propagation_settings, product_repository + ) + self.probe_propagator_visualization_engine = VisualizationEngine(is_complex=False) + + self.exposure_analyzer = IlluminationMapper(product_repository) + self.exposure_visualization_engine = VisualizationEngine(is_complex=False) + + self.fourier_ring_correlator = FourierRingCorrelator(object_repository) + + self.xmcd_analyzer = XMCDAnalyzer(product_repository) + self.xmcd_visualization_engine = VisualizationEngine(is_complex=False) diff --git a/src/ptychodus/model/analysis/exposure.py b/src/ptychodus/model/analysis/exposure.py deleted file mode 100644 index c884c403..00000000 --- a/src/ptychodus/model/analysis/exposure.py +++ /dev/null @@ -1,68 +0,0 @@ -from __future__ import annotations -from collections.abc import Sequence -from dataclasses import dataclass -from pathlib import Path -import logging - -import numpy - -from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.visualization import RealArrayType - -from ..product import ProductRepository - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class ExposureMap: - pixel_width_m: float - pixel_height_m: float - center_x_m: float - center_y_m: float - counts: RealArrayType - - @property - def pixel_geometry(self) -> PixelGeometry: - return PixelGeometry(self.pixel_width_m, self.pixel_height_m) - - -class ExposureAnalyzer: - def __init__(self, repository: ProductRepository) -> None: - self._repository = repository - - def analyze(self, itemIndex: int) -> ExposureMap: - item = self._repository[itemIndex] - objectItem = item.getObject() - object_ = objectItem.getObject() - - counts = numpy.zeros_like(object_.array, dtype=float) # FIXME - - return ExposureMap( - pixel_width_m=object_.pixelWidthInMeters, - pixel_height_m=object_.pixelHeightInMeters, - center_x_m=object_.centerXInMeters, - center_y_m=object_.centerYInMeters, - counts=counts, - ) - - def getSaveFileFilterList(self) -> Sequence[str]: - return [self.getSaveFileFilter()] - - def getSaveFileFilter(self) -> str: - return 'NumPy Zipped Archive (*.npz)' - - def saveResult(self, filePath: Path, result: ExposureMap) -> None: - numpy.savez( - filePath, - 'pixel_height_m', - result.pixel_height_m, - 'pixel_width_m', - result.pixel_width_m, - 'center_x_m', - result.center_x_m, - 'center_y_m', - result.center_y_m, - 'counts', - result.counts, - ) diff --git a/src/ptychodus/model/analysis/frc.py b/src/ptychodus/model/analysis/frc.py index a9735ab2..0c619acf 100644 --- a/src/ptychodus/model/analysis/frc.py +++ b/src/ptychodus/model/analysis/frc.py @@ -7,8 +7,7 @@ import numpy.typing import scipy.fft -from ptychodus.api.object import ObjectArrayType -from ptychodus.api.typing import IntegerArrayType +from ptychodus.api.typing import ComplexArrayType, IntegerArrayType from ptychodus.api.visualization import Plot2D, PlotAxis, PlotSeries from ..product import ObjectRepository @@ -18,24 +17,24 @@ @dataclass(frozen=True) class FourierRingCorrelation: - spatialFrequency_rm: Sequence[float] + spatial_frequency_per_m: Sequence[float] correlation: Sequence[float] - def getResolutionInMeters(self, threshold: float) -> float: + def get_resolution_m(self, threshold: float) -> float: # TODO threshold from bits - for freq_rm, frc in zip(self.spatialFrequency_rm, self.correlation): + for freq_rm, frc in zip(self.spatial_frequency_per_m, self.correlation): if frc < threshold: return 1.0 / freq_rm return numpy.nan - def getPlot(self) -> Plot2D: - freqSeries = PlotSeries('freq', [1.0e-9 * freq for freq in self.spatialFrequency_rm]) - frcSeries = PlotSeries('frc', self.correlation) + def get_plot(self) -> Plot2D: + freq_series = PlotSeries('freq', [1.0e-9 * freq for freq in self.spatial_frequency_per_m]) + frc_series = PlotSeries('frc', self.correlation) return Plot2D( - axisX=PlotAxis('Spatial Frequency [1/nm]', [freqSeries]), - axisY=PlotAxis('Fourier Ring Correlation', [frcSeries]), + axis_x=PlotAxis('Spatial Frequency [1/nm]', [freq_series]), + axis_y=PlotAxis('Fourier Ring Correlation', [frc_series]), ) @@ -44,7 +43,7 @@ def __init__(self, repository: ObjectRepository) -> None: self._repository = repository @staticmethod - def _integrateRings(rings: IntegerArrayType, array: ObjectArrayType) -> ObjectArrayType: + def _integrate_rings(rings: IntegerArrayType, array: ComplexArrayType) -> ComplexArrayType: total = numpy.zeros(numpy.max(rings) + 1, dtype=complex) for index, value in zip(rings.flat, array.flat): @@ -52,7 +51,7 @@ def _integrateRings(rings: IntegerArrayType, array: ObjectArrayType) -> ObjectAr return total - def correlate(self, itemIndex1: int, itemIndex2: int) -> FourierRingCorrelation: + def correlate(self, product_index_1: int, product_index_2: int) -> FourierRingCorrelation: """ See: Joan Vila-Comamala, Ana Diaz, Manuel Guizar-Sicairos, Alexandre Mantion, Cameron M. Kewish, Andreas Menzel, Oliver Bunk, and Christian David, @@ -60,12 +59,12 @@ def correlate(self, itemIndex1: int, itemIndex2: int) -> FourierRingCorrelation: coherent diffractive imaging," Opt. Express 19, 21333-21344 (2011) """ - object1 = self._repository[itemIndex1].getObject() - object2 = self._repository[itemIndex2].getObject() + object1 = self._repository[product_index_1].get_object() + object2 = self._repository[product_index_2].get_object() # TODO support multilayer objects - array1 = object1.getLayer(0) - array2 = object2.getLayer(0) + array1 = object1.get_layer(0) + array2 = object2.get_layer(0) if numpy.ndim(array1) != 2 or numpy.ndim(array2) != 2: raise ValueError('Arrays must be 2D!') @@ -74,33 +73,33 @@ def correlate(self, itemIndex1: int, itemIndex2: int) -> FourierRingCorrelation: raise ValueError('Arrays must have same shape!') # TODO verify compatible pixel geometry - pixelGeometry = object2.getPixelGeometry() + pixel_geometry = object2.get_pixel_geometry() # TODO subpixel image registration: skimage.registration.phase_cross_correlation # TODO remove phase offset and ramp # TODO apply soft-edged mask # TODO stats: SSNR, area under FRC curve, average SNR, etc. - x_rm = scipy.fft.fftfreq(array1.shape[-1], d=pixelGeometry.widthInMeters) - y_rm = scipy.fft.fftfreq(array1.shape[-2], d=pixelGeometry.heightInMeters) - radialBinSize_rm = max(x_rm[1], y_rm[1]) + x_rm = scipy.fft.fftfreq(array1.shape[-1], d=pixel_geometry.width_m) + y_rm = scipy.fft.fftfreq(array1.shape[-2], d=pixel_geometry.height_m) + radial_bin_size_per_m = max(x_rm[1], y_rm[1]) xx_rm, yy_rm = numpy.meshgrid(x_rm, y_rm) rr_rm = numpy.hypot(xx_rm, yy_rm) - rings = numpy.divide(rr_rm, radialBinSize_rm).astype(int) - spatialFrequency_rm = numpy.arange(numpy.max(rings) + 1) * radialBinSize_rm + rings = numpy.divide(rr_rm, radial_bin_size_per_m).astype(int) + spatial_frequency_per_m = numpy.arange(numpy.max(rings) + 1) * radial_bin_size_per_m sf1 = scipy.fft.fft2(array1) sf2 = scipy.fft.fft2(array2) - c11 = self._integrateRings(rings, numpy.multiply(sf1, numpy.conj(sf1))) - c12 = self._integrateRings(rings, numpy.multiply(sf1, numpy.conj(sf2))) - c22 = self._integrateRings(rings, numpy.multiply(sf2, numpy.conj(sf2))) + c11 = self._integrate_rings(rings, numpy.multiply(sf1, numpy.conj(sf1))) + c12 = self._integrate_rings(rings, numpy.multiply(sf1, numpy.conj(sf2))) + c22 = self._integrate_rings(rings, numpy.multiply(sf2, numpy.conj(sf2))) correlation = numpy.absolute(c12) / numpy.sqrt(numpy.absolute(numpy.multiply(c11, c22))) # TODO replace NaNs with interpolated values rnyquist = numpy.min(array1.shape) // 2 + 1 - return FourierRingCorrelation(spatialFrequency_rm[:rnyquist], correlation[:rnyquist]) + return FourierRingCorrelation(spatial_frequency_per_m[:rnyquist], correlation[:rnyquist]) diff --git a/src/ptychodus/model/analysis/illumination.py b/src/ptychodus/model/analysis/illumination.py new file mode 100644 index 00000000..5f8f40bc --- /dev/null +++ b/src/ptychodus/model/analysis/illumination.py @@ -0,0 +1,138 @@ +from __future__ import annotations +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any +import logging + +import numpy + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import ObjectCenter +from ptychodus.api.observer import Observable +from ptychodus.api.typing import RealArrayType + +from ..product import ProductRepository +from .barycentric import BarycentricArrayStitcher + + +__all__ = [ + 'IlluminationMapper', +] + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class IlluminationMap: + photon_number: RealArrayType + photon_energy_J: float # noqa: N815 + exposure_time_s: float + mass_attenuation_m2_kg: float + pixel_geometry: PixelGeometry + center: ObjectCenter + + @property + def photon_fluence_1_m2(self) -> RealArrayType: + return self.photon_number / self.pixel_geometry.area_m2 + + @property + def photon_fluence_rate_Hz_m2(self) -> RealArrayType: # noqa: N802 + return self.photon_fluence_1_m2 / self.exposure_time_s + + @property + def energy_fluence_J_m2(self) -> RealArrayType: # noqa: N802 + return self.photon_fluence_1_m2 * self.photon_energy_J + + @property + def energy_fluence_rate_W_m2(self) -> RealArrayType: # noqa: N802 + return self.photon_fluence_rate_Hz_m2 * self.photon_energy_J + + @property + def dose_Gy(self) -> RealArrayType: # noqa: N802 + return self.energy_fluence_J_m2 * self.mass_attenuation_m2_kg + + @property + def dose_rate_Gy_s(self) -> RealArrayType: # noqa: N802 + return self.energy_fluence_rate_W_m2 * self.mass_attenuation_m2_kg + + @property + def intensity_W_m2(self) -> RealArrayType: # noqa: N802 + return self.energy_fluence_rate_W_m2 + + +class IlluminationMapper(Observable): + def __init__(self, repository: ProductRepository) -> None: + super().__init__() + self._repository = repository + + self._product_index = -1 + self._product_data: IlluminationMap | None = None + + def set_product(self, product_index: int) -> None: + if self._product_index != product_index: + self._product_index = product_index + self._product_data = None + self.notify_observers() + + def get_product_name(self) -> str: + product = self._repository[self._product_index] + return product.get_name() + + def map(self) -> None: + product = self._repository[self._product_index].get_product() + object_geometry = product.object_.get_geometry() + + stitcher = BarycentricArrayStitcher[numpy.double]( + numpy.zeros((object_geometry.height_px, object_geometry.width_px)) + ) + + for scan_point, probe in zip(product.positions, product.probes): + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + stitcher.add_patch( + object_point.position_x_px, + object_point.position_y_px, + probe.get_intensity(), + ) + + self._product_data = IlluminationMap( + photon_number=stitcher.stitch(), + photon_energy_J=product.metadata.probe_energy_J, + exposure_time_s=product.metadata.exposure_time_s, + mass_attenuation_m2_kg=product.metadata.mass_attenuation_m2_kg, + pixel_geometry=object_geometry.get_pixel_geometry(), + center=object_geometry.get_center(), + ) + self.notify_observers() + + def get_data(self) -> IlluminationMap: + if self._product_data is None: + raise ValueError('No analyzed data!') + + return self._product_data + + def get_save_file_filters(self) -> Sequence[str]: + return [self.get_save_file_filter()] + + def get_save_file_filter(self) -> str: + return 'NumPy Zipped Archive (*.npz)' + + def save_data(self, file_path: Path) -> None: + if self._product_data is None: + raise ValueError('No analyzed data!') + + contents: dict[str, Any] = { + 'photon_number': self._product_data.photon_number, + 'photon_fluence_1_m2': self._product_data.photon_fluence_1_m2, + 'photon_fluence_rate_Hz_m2': self._product_data.photon_fluence_rate_Hz_m2, + 'energy_fluence_J_m2': self._product_data.energy_fluence_J_m2, + 'energy_fluence_rate_W_m2': self._product_data.energy_fluence_rate_W_m2, + 'dose_Gy': self._product_data.dose_Gy, + 'dose_rate_Gy_s': self._product_data.dose_rate_Gy_s, + 'pixel_height_m': self._product_data.pixel_geometry.height_m, + 'pixel_width_m': self._product_data.pixel_geometry.width_m, + 'center_x_m': self._product_data.center.position_x_m, + 'center_y_m': self._product_data.center.position_y_m, + } + + numpy.savez(file_path, **contents) diff --git a/src/ptychodus/model/analysis/objectInterpolator.py b/src/ptychodus/model/analysis/objectInterpolator.py deleted file mode 100644 index 544f2734..00000000 --- a/src/ptychodus/model/analysis/objectInterpolator.py +++ /dev/null @@ -1,57 +0,0 @@ -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.object import Object, ObjectInterpolator -from ptychodus.api.scan import ScanPoint - - -class ObjectLinearInterpolator(ObjectInterpolator): - def __init__(self, object_: Object) -> None: - self._object = object_ - - def getPatch(self, patchCenter: ScanPoint, patchExtent: ImageExtent) -> Object: - geometry = self._object.getGeometry() - - patchWidth = patchExtent.widthInPixels - patchRadiusXInMeters = geometry.pixelWidthInMeters * patchWidth / 2 - patchMinimumXInMeters = patchCenter.positionXInMeters - patchRadiusXInMeters - ixBeginF, xi = divmod( - patchMinimumXInMeters - geometry.minimumXInMeters, - geometry.pixelWidthInMeters, - ) - ixBegin = int(ixBeginF) - ixEnd = ixBegin + patchWidth - ixSlice0 = slice(ixBegin, ixEnd) - ixSlice1 = slice(ixBegin + 1, ixEnd + 1) - - patchHeight = patchExtent.heightInPixels - patchRadiusYInMeters = geometry.pixelHeightInMeters * patchHeight / 2 - patchMinimumYInMeters = patchCenter.positionYInMeters - patchRadiusYInMeters - iyBeginF, eta = divmod( - patchMinimumYInMeters - geometry.minimumYInMeters, - geometry.pixelHeightInMeters, - ) - iyBegin = int(iyBeginF) - iyEnd = iyBegin + patchHeight - iySlice0 = slice(iyBegin, iyEnd) - iySlice1 = slice(iyBegin + 1, iyEnd + 1) - - xiC = 1.0 - xi - etaC = 1.0 - eta - - w00 = xiC * etaC - w01 = xi * etaC - w10 = xiC * eta - w11 = xi * eta - - patch = w00 * self._object.array[:, iySlice0, ixSlice0] - patch += w01 * self._object.array[:, iySlice0, ixSlice1] - patch += w10 * self._object.array[:, iySlice1, ixSlice0] - patch += w11 * self._object.array[:, iySlice1, ixSlice1] - - return Object( - array=patch, - layerDistanceInMeters=self._object.layerDistanceInMeters, - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, - centerXInMeters=geometry.centerXInMeters, - centerYInMeters=geometry.centerYInMeters, - ) diff --git a/src/ptychodus/model/analysis/objectStitcher.py b/src/ptychodus/model/analysis/objectStitcher.py deleted file mode 100644 index 61c1bb12..00000000 --- a/src/ptychodus/model/analysis/objectStitcher.py +++ /dev/null @@ -1,62 +0,0 @@ -import numpy - -from ptychodus.api.object import Object, ObjectArrayType, ObjectGeometry -from ptychodus.api.scan import ScanPoint - - -class ObjectStitcher: - def __init__(self, geometry: ObjectGeometry) -> None: - self._geometry = geometry - self._weights = numpy.zeros((geometry.heightInPixels, geometry.widthInPixels)) - self._array: ObjectArrayType = numpy.zeros_like(self._weights, dtype=complex) - - def _addPatchPart( - self, ixSlice: slice, iySlice: slice, weight: float, patchArray: ObjectArrayType - ) -> None: - idx = numpy.s_[iySlice, ixSlice] - self._weights[idx] += weight - self._array[idx] += (patchArray - self._array[idx]) * weight / self._weights[idx] - - def addPatch(self, patchCenter: ScanPoint, patchArray: ObjectArrayType) -> None: - geometry = self._geometry - - patchWidth = patchArray.shape[-1] - patchRadiusXInMeters = geometry.pixelWidthInMeters * patchWidth / 2 - patchMinimumXInMeters = patchCenter.positionXInMeters - patchRadiusXInMeters - ixBeginF, xi = divmod( - patchMinimumXInMeters - geometry.minimumXInMeters, - geometry.pixelWidthInMeters, - ) - ixBegin = int(ixBeginF) - ixEnd = ixBegin + patchWidth - ixSlice0 = slice(ixBegin, ixEnd) - ixSlice1 = slice(ixBegin + 1, ixEnd + 1) - - patchHeight = patchArray.shape[-2] - patchRadiusYInMeters = geometry.pixelHeightInMeters * patchHeight / 2 - patchMinimumYInMeters = patchCenter.positionYInMeters - patchRadiusYInMeters - iyBeginF, eta = divmod( - patchMinimumYInMeters - geometry.minimumYInMeters, - geometry.pixelHeightInMeters, - ) - iyBegin = int(iyBeginF) - iyEnd = iyBegin + patchHeight - iySlice0 = slice(iyBegin, iyEnd) - iySlice1 = slice(iyBegin + 1, iyEnd + 1) - - xiC = 1.0 - xi - etaC = 1.0 - eta - - self._addPatchPart(ixSlice0, iySlice0, xiC * etaC, patchArray) - self._addPatchPart(ixSlice1, iySlice0, xi * etaC, patchArray) - self._addPatchPart(ixSlice0, iySlice1, xiC * eta, patchArray) - self._addPatchPart(ixSlice1, iySlice1, xi * eta, patchArray) - - def build(self) -> Object: - return Object( - array=self._array, - pixelWidthInMeters=self._geometry.pixelWidthInMeters, - pixelHeightInMeters=self._geometry.pixelHeightInMeters, - centerXInMeters=self._geometry.centerXInMeters, - centerYInMeters=self._geometry.centerYInMeters, - ) diff --git a/src/ptychodus/model/analysis/propagator.py b/src/ptychodus/model/analysis/propagator.py index a7d7d2de..6f40c2b5 100644 --- a/src/ptychodus/model/analysis/propagator.py +++ b/src/ptychodus/model/analysis/propagator.py @@ -1,17 +1,18 @@ from __future__ import annotations from collections.abc import Sequence from pathlib import Path +from typing import Any import logging import numpy from ptychodus.api.geometry import PixelGeometry from ptychodus.api.observer import Observable -from ptychodus.api.probe import Probe +from ptychodus.api.probe import ProbeSequence from ptychodus.api.propagator import ( AngularSpectrumPropagator, PropagatorParameters, - WavefieldArrayType, + ComplexArrayType, intensity, ) from ptychodus.api.typing import RealArrayType @@ -28,130 +29,130 @@ def __init__(self, settings: ProbePropagationSettings, repository: ProductReposi self._settings = settings self._repository = repository - self._productIndex = -1 - self._propagatedWavefield: WavefieldArrayType | None = None - self._propagatedIntensity: RealArrayType | None = None + self._product_index = -1 + self._propagated_wavefield: ComplexArrayType | None = None + self._propagated_intensity: RealArrayType | None = None - def setProduct(self, productIndex: int) -> None: - if self._productIndex != productIndex: - self._productIndex = productIndex - self._propagatedWavefield = None - self._propagatedIntensity = None - self.notifyObservers() + def set_product(self, product_index: int) -> None: + if self._product_index != product_index: + self._product_index = product_index + self._propagated_wavefield = None + self._propagated_intensity = None + self.notify_observers() - def getProductName(self) -> str: - item = self._repository[self._productIndex] - return item.getName() + def get_product_name(self) -> str: + item = self._repository[self._product_index] + return item.get_name() def propagate( self, *, - beginCoordinateInMeters: float, - endCoordinateInMeters: float, - numberOfSteps: int, + begin_coordinate_m: float, + end_coordinate_m: float, + num_steps: int, ) -> None: - item = self._repository[self._productIndex] - probe = item.getProbe().getProbe() - wavelengthInMeters = item.getGeometry().probeWavelengthInMeters - propagatedWavefield = numpy.zeros( - (numberOfSteps, *probe.array.shape), - dtype=probe.array.dtype, + item = self._repository[self._product_index] + probe = item.get_probe_item().get_probes().get_probe_no_opr() # TODO OPR + wavelength_m = item.get_geometry().probe_wavelength_m + propagated_wavefield = numpy.zeros( + (num_steps, probe.num_incoherent_modes, probe.height_px, probe.width_px), + dtype=probe.dtype, ) - propagatedIntensity = numpy.zeros((numberOfSteps, *probe.array.shape[-2:])) - distanceInMeters = numpy.linspace( - beginCoordinateInMeters, endCoordinateInMeters, numberOfSteps - ) - pixelGeometry = probe.getPixelGeometry() - - for idx, zInMeters in enumerate(distanceInMeters): - propagatorParameters = PropagatorParameters( - wavelength_m=wavelengthInMeters, - width_px=probe.array.shape[-1], - height_px=probe.array.shape[-2], - pixel_width_m=pixelGeometry.widthInMeters, - pixel_height_m=pixelGeometry.heightInMeters, - propagation_distance_m=zInMeters, + propagated_intensity = numpy.zeros((num_steps, probe.height_px, probe.width_px)) + distance_m = numpy.linspace(begin_coordinate_m, end_coordinate_m, num_steps) + pixel_geometry = probe.get_pixel_geometry() + + for idx, z_m in enumerate(distance_m): + propagator_parameters = PropagatorParameters( + wavelength_m=wavelength_m, + width_px=probe.width_px, + height_px=probe.height_px, + pixel_width_m=pixel_geometry.width_m, + pixel_height_m=pixel_geometry.height_m, + propagation_distance_m=float(z_m), ) - propagator = AngularSpectrumPropagator(propagatorParameters) + propagator = AngularSpectrumPropagator(propagator_parameters) - for mode in range(probe.array.shape[-3]): - wf = propagator.propagate(probe.array[mode]) - propagatedWavefield[idx, mode, :, :] = wf - propagatedIntensity[idx, :, :] += intensity(wf) + for mode in range(probe.num_incoherent_modes): + wf = propagator.propagate(probe.get_incoherent_mode(mode)) + propagated_wavefield[idx, mode, :, :] = wf + propagated_intensity[idx, :, :] += intensity(wf) - self._settings.beginCoordinateInMeters.setValue(beginCoordinateInMeters) - self._settings.endCoordinateInMeters.setValue(endCoordinateInMeters) - self._propagatedWavefield = propagatedWavefield - self._propagatedIntensity = propagatedIntensity - self.notifyObservers() + self._settings.begin_coordinate_m.set_value(begin_coordinate_m) + self._settings.end_coordinate_m.set_value(end_coordinate_m) + self._propagated_wavefield = propagated_wavefield + self._propagated_intensity = propagated_intensity + self.notify_observers() - def getBeginCoordinateInMeters(self) -> float: - return self._settings.beginCoordinateInMeters.getValue() + def get_begin_coordinate_m(self) -> float: + return self._settings.begin_coordinate_m.get_value() - def getEndCoordinateInMeters(self) -> float: - return self._settings.endCoordinateInMeters.getValue() + def get_end_coordinate_m(self) -> float: + return self._settings.end_coordinate_m.get_value() - def _getProbe(self) -> Probe: - item = self._repository[self._productIndex] - return item.getProbe().getProbe() + def _get_probe(self) -> ProbeSequence: + item = self._repository[self._product_index] + return item.get_probe_item().get_probes() - def getPixelGeometry(self) -> PixelGeometry: - probe = self._getProbe() - return probe.getPixelGeometry() + def get_pixel_geometry(self) -> PixelGeometry | None: + try: + probe = self._get_probe() + except IndexError: + return None + else: + return probe.get_pixel_geometry() - def getNumberOfSteps(self) -> int: - if self._propagatedIntensity is None: - return self._settings.numberOfSteps.getValue() + def get_num_steps(self) -> int: + if self._propagated_intensity is None: + return self._settings.num_steps.get_value() - return self._propagatedIntensity.shape[0] + return self._propagated_intensity.shape[0] - def getXYProjection(self, step: int) -> RealArrayType: - if self._propagatedIntensity is None: + def get_xy_projection(self, step: int) -> RealArrayType: + if self._propagated_intensity is None: raise ValueError('No propagated wavefield!') - return self._propagatedIntensity[step] + return self._propagated_intensity[step] - def getZXProjection(self) -> RealArrayType: - if self._propagatedIntensity is None: + def get_zx_projection(self) -> RealArrayType: + if self._propagated_intensity is None: raise ValueError('No propagated wavefield!') - sz = self._propagatedIntensity.shape[-2] - cutPlaneL = self._propagatedIntensity[:, (sz - 1) // 2, :] - cutPlaneR = self._propagatedIntensity[:, sz // 2, :] - return numpy.transpose(numpy.add(cutPlaneL, cutPlaneR) / 2) + sz = self._propagated_intensity.shape[-2] + cut_plane_l = self._propagated_intensity[:, (sz - 1) // 2, :] + cut_plane_r = self._propagated_intensity[:, sz // 2, :] + return numpy.transpose(numpy.add(cut_plane_l, cut_plane_r) / 2) - def getZYProjection(self) -> RealArrayType: - if self._propagatedIntensity is None: + def get_zy_projection(self) -> RealArrayType: + if self._propagated_intensity is None: raise ValueError('No propagated wavefield!') - sz = self._propagatedIntensity.shape[-1] - cutPlaneL = self._propagatedIntensity[:, :, (sz - 1) // 2] - cutPlaneR = self._propagatedIntensity[:, :, sz // 2] - return numpy.transpose(numpy.add(cutPlaneL, cutPlaneR) / 2) + sz = self._propagated_intensity.shape[-1] + cut_plane_l = self._propagated_intensity[:, :, (sz - 1) // 2] + cut_plane_r = self._propagated_intensity[:, :, sz // 2] + return numpy.transpose(numpy.add(cut_plane_l, cut_plane_r) / 2) - def getSaveFileFilterList(self) -> Sequence[str]: - return [self.getSaveFileFilter()] + def get_save_file_filters(self) -> Sequence[str]: + return [self.get_save_file_filter()] - def getSaveFileFilter(self) -> str: + def get_save_file_filter(self) -> str: return 'NumPy Zipped Archive (*.npz)' - def savePropagatedProbe(self, filePath: Path) -> None: - if self._propagatedWavefield is None or self._propagatedIntensity is None: + def save_propagated_probe(self, file_path: Path) -> None: + if self._propagated_wavefield is None or self._propagated_intensity is None: raise ValueError('No propagated wavefield!') - pixelGeometry = self.getPixelGeometry() - numpy.savez( - filePath, - 'begin_coordinate_m', - float(self.getBeginCoordinateInMeters()), - 'end_coordinate_m', - float(self.getEndCoordinateInMeters()), - 'pixel_height_m', - pixelGeometry.heightInMeters, - 'pixel_width_m', - pixelGeometry.widthInMeters, - 'wavefield', - self._propagatedWavefield, - 'intensity', - self._propagatedIntensity, - ) + contents: dict[str, Any] = { + 'begin_coordinate_m': self.get_begin_coordinate_m(), + 'end_coordinate_m': self.get_end_coordinate_m(), + 'wavefield': self._propagated_wavefield, + 'intensity': self._propagated_intensity, + } + + pixel_geometry = self.get_pixel_geometry() + + if pixel_geometry is not None: + contents['pixel_height_m'] = pixel_geometry.height_m + contents['pixel_width_m'] = pixel_geometry.width_m + + numpy.savez(file_path, **contents) diff --git a/src/ptychodus/model/analysis/settings.py b/src/ptychodus/model/analysis/settings.py index ea0fcd44..91d9f7bf 100644 --- a/src/ptychodus/model/analysis/settings.py +++ b/src/ptychodus/model/analysis/settings.py @@ -5,17 +5,32 @@ class ProbePropagationSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('ProbePropagation') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('ProbePropagation') + self._group.add_observer(self) - self.beginCoordinateInMeters = self._settingsGroup.createRealParameter( + self.begin_coordinate_m = self._group.create_real_parameter( 'BeginCoordinateInMeters', -1e-3 ) - self.endCoordinateInMeters = self._settingsGroup.createRealParameter( - 'EndCoordinateInMeters', 1e-3 + self.end_coordinate_m = self._group.create_real_parameter('EndCoordinateInMeters', 1e-3) + self.num_steps = self._group.create_integer_parameter('NumberOfSteps', 100) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class AffineTransformEstimatorSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('AffineTransformEstimator') + self._group.add_observer(self) + + self.num_shuffles = self._group.create_integer_parameter('NumberOfShuffles', 10, minimum=1) + self.inlier_threshold = self._group.create_real_parameter( + 'InlierThreshold', 1e-6, minimum=0.0 ) - self.numberOfSteps = self._settingsGroup.createIntegerParameter('NumberOfSteps', 100) + self.min_inliers = self._group.create_integer_parameter('MinimumInliers', 10, minimum=3) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/analysis/stxm.py b/src/ptychodus/model/analysis/stxm.py index c284d3cc..4b0f56eb 100644 --- a/src/ptychodus/model/analysis/stxm.py +++ b/src/ptychodus/model/analysis/stxm.py @@ -2,17 +2,18 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path +from typing import Any import logging import numpy from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.object import ObjectGeometry +from ptychodus.api.object import ObjectCenter from ptychodus.api.observer import Observable -from ptychodus.api.scan import ScanPoint -from ptychodus.api.visualization import RealArrayType +from ptychodus.api.typing import RealArrayType from ..reconstructor import DiffractionPatternPositionMatcher +from .barycentric import BarycentricArrayStitcher __all__ = [ 'STXMSimulator', @@ -22,148 +23,84 @@ @dataclass(frozen=True) -class STXMImage: +class STXMData: intensity: RealArrayType - pixel_width_m: float - pixel_height_m: float - center_x_m: float - center_y_m: float - - @property - def pixel_geometry(self) -> PixelGeometry: - return PixelGeometry(self.pixel_width_m, self.pixel_height_m) - - -class STXMStitcher: - def __init__(self, geometry: ObjectGeometry) -> None: - self._geometry = geometry - self._weights = numpy.zeros((geometry.heightInPixels, geometry.widthInPixels)) - self._intensity = numpy.zeros_like(self._weights) - - def _addPatchPart( - self, - ixSlice: slice, - iySlice: slice, - intensity: float, - probeProfile: RealArrayType, - ) -> None: - idx = numpy.s_[iySlice, ixSlice] - self._weights[idx] += probeProfile - self._intensity[idx] += intensity * probeProfile - - def addMeasurement( - self, point: ScanPoint, intensity: float, probeProfile: RealArrayType - ) -> None: - geometry = self._geometry - - patchWidth = probeProfile.shape[-1] - patchRadiusXInMeters = geometry.pixelWidthInMeters * patchWidth / 2 - patchMinimumXInMeters = point.positionXInMeters - patchRadiusXInMeters - ixBeginF, xi = divmod( - patchMinimumXInMeters - geometry.minimumXInMeters, - geometry.pixelWidthInMeters, - ) - ixBegin = int(ixBeginF) - ixEnd = ixBegin + patchWidth - ixSlice0 = slice(ixBegin, ixEnd) - ixSlice1 = slice(ixBegin + 1, ixEnd + 1) - - patchHeight = probeProfile.shape[-2] - patchRadiusYInMeters = geometry.pixelHeightInMeters * patchHeight / 2 - patchMinimumYInMeters = point.positionYInMeters - patchRadiusYInMeters - iyBeginF, eta = divmod( - patchMinimumYInMeters - geometry.minimumYInMeters, - geometry.pixelHeightInMeters, - ) - iyBegin = int(iyBeginF) - iyEnd = iyBegin + patchHeight - iySlice0 = slice(iyBegin, iyEnd) - iySlice1 = slice(iyBegin + 1, iyEnd + 1) - - xiC = 1.0 - xi - etaC = 1.0 - eta - - self._addPatchPart(ixSlice0, iySlice0, xiC * etaC, probeProfile) - self._addPatchPart(ixSlice1, iySlice0, xi * etaC, probeProfile) - self._addPatchPart(ixSlice0, iySlice1, xiC * eta, probeProfile) - self._addPatchPart(ixSlice1, iySlice1, xi * eta, probeProfile) - - def build(self) -> STXMImage: - intensity = numpy.divide( - self._intensity, - self._weights, - out=numpy.zeros_like(self._weights), - where=(self._weights > 0), - ) - return STXMImage( - intensity=intensity, - pixel_width_m=self._geometry.pixelWidthInMeters, - pixel_height_m=self._geometry.pixelHeightInMeters, - center_x_m=self._geometry.centerXInMeters, - center_y_m=self._geometry.centerYInMeters, - ) + pixel_geometry: PixelGeometry + center: ObjectCenter class STXMSimulator(Observable): - def __init__(self, dataMatcher: DiffractionPatternPositionMatcher) -> None: + def __init__(self, data_matcher: DiffractionPatternPositionMatcher) -> None: super().__init__() - self._dataMatcher = dataMatcher + self._data_matcher = data_matcher - self._productIndex = -1 - self._image: STXMImage | None = None + self._product_index = -1 + self._product_data: STXMData | None = None - def setProduct(self, productIndex: int) -> None: - if self._productIndex != productIndex: - self._productIndex = productIndex - self._image = None - self.notifyObservers() + def set_product(self, product_index: int) -> None: + if self._product_index != product_index: + self._product_index = product_index + self._product_data = None + self.notify_observers() - def getProductName(self) -> str: - return self._dataMatcher.getProductName(self._productIndex) + def get_product_name(self) -> str: + return self._data_matcher.get_product_item(self._product_index).get_name() def simulate(self) -> None: - reconstructInput = self._dataMatcher.matchDiffractionPatternsWithPositions( - self._productIndex + reconstruct_input = self._data_matcher.match_diffraction_patterns_with_positions( + self._product_index ) - product = reconstructInput.product - stitcher = STXMStitcher(product.object_.getGeometry()) - - probeIntensity = product.probe.getIntensity() - probeProfile = probeIntensity / numpy.sqrt(numpy.sum(numpy.abs(probeIntensity) ** 2)) + product = reconstruct_input.product + object_geometry = product.object_.get_geometry() + object_shape = object_geometry.height_px, object_geometry.width_px - for pattern, scanPoint in zip(reconstructInput.patterns, product.scan): - patternIntensity = pattern.sum() - stitcher.addMeasurement(scanPoint, patternIntensity, probeProfile) + stitcher = BarycentricArrayStitcher[numpy.double]( + upper=numpy.zeros(object_shape), + lower=numpy.zeros(object_shape), + ) - self._image = stitcher.build() - self.notifyObservers() + for pattern, scan_point, probe in zip( + reconstruct_input.patterns, product.positions, product.probes + ): + probe_intensity = probe.get_intensity() + rescaled_probe_intensity = probe_intensity * (pattern.sum() / probe_intensity.sum()) + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + stitcher.add_patch( + object_point.position_x_px, + object_point.position_y_px, + rescaled_probe_intensity, + numpy.ones_like(rescaled_probe_intensity), + ) + + self._product_data = STXMData( + intensity=stitcher.stitch(), + pixel_geometry=object_geometry.get_pixel_geometry(), + center=object_geometry.get_center(), + ) + self.notify_observers() - def getImage(self) -> STXMImage: - if self._image is None: - raise ValueError('No simulated image!') + def get_data(self) -> STXMData: + if self._product_data is None: + raise ValueError('No simulated data!') - return self._image + return self._product_data - def getSaveFileFilterList(self) -> Sequence[str]: - return [self.getSaveFileFilter()] + def get_save_file_filters(self) -> Sequence[str]: + return [self.get_save_file_filter()] - def getSaveFileFilter(self) -> str: + def get_save_file_filter(self) -> str: return 'NumPy Zipped Archive (*.npz)' - def saveImage(self, filePath: Path) -> None: - if self._image is None: - raise ValueError('No simulated image!') - - numpy.savez( - filePath, - 'pixel_height_m', - self._image.pixel_height_m, - 'pixel_width_m', - self._image.pixel_width_m, - 'center_x_m', - self._image.center_x_m, - 'center_y_m', - self._image.center_y_m, - 'intensity', - self._image.intensity, - ) + def save_data(self, file_path: Path) -> None: + if self._product_data is None: + raise ValueError('No simulated data!') + + contents: dict[str, Any] = { + 'intensity': self._product_data.intensity, + 'pixel_height_m': self._product_data.pixel_geometry.height_m, + 'pixel_width_m': self._product_data.pixel_geometry.width_m, + 'center_x_m': self._product_data.center.position_x_m, + 'center_y_m': self._product_data.center.position_y_m, + } + + numpy.savez(file_path, **contents) diff --git a/src/ptychodus/model/analysis/xmcd.py b/src/ptychodus/model/analysis/xmcd.py index 38044e30..f3a858c7 100644 --- a/src/ptychodus/model/analysis/xmcd.py +++ b/src/ptychodus/model/analysis/xmcd.py @@ -2,104 +2,136 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path +from typing import Any import logging import numpy from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.object import ObjectArrayType +from ptychodus.api.object import ObjectCenter +from ptychodus.api.observer import Observable +from ptychodus.api.typing import ComplexArrayType -from ..product import ObjectRepository +from ..product import ProductRepository logger = logging.getLogger(__name__) @dataclass(frozen=True) -class XMCDResult: - pixel_width_m: float - pixel_height_m: float - center_x_m: float - center_y_m: float - polar_difference: ObjectArrayType - polar_sum: ObjectArrayType - polar_ratio: ObjectArrayType +class XMCDData: + polar_difference: ComplexArrayType + polar_sum: ComplexArrayType + polar_ratio: ComplexArrayType + pixel_geometry: PixelGeometry + center: ObjectCenter - @property - def pixel_geometry(self) -> PixelGeometry: - return PixelGeometry(self.pixel_width_m, self.pixel_height_m) - -class XMCDAnalyzer: +class XMCDAnalyzer(Observable): # TODO feature request: want ability to align/add reconstructed slices # of repeat scans for each polarization separately to improve statistics - def __init__(self, repository: ObjectRepository) -> None: + def __init__(self, repository: ProductRepository) -> None: + super().__init__() self._repository = repository - def analyze(self, lcircItemIndex: int, rcircItemIndex: int) -> XMCDResult: - lcircObject = self._repository[lcircItemIndex].getObject() - rcircObject = self._repository[rcircItemIndex].getObject() + self._lcirc_product_index = -1 + self._rcirc_product_index = -1 + self._product_data: XMCDData | None = None + + def set_lcirc_product(self, lcirc_product_index: int) -> None: + if self._lcirc_product_index != lcirc_product_index: + self._lcirc_product_index = lcirc_product_index + self._lcirc_product_data = None + self.notify_observers() + + def get_lcirc_product(self) -> int: + return self._lcirc_product_index + + def get_lcirc_product_name(self) -> str: + lcirc_product = self._repository[self._lcirc_product_index] + return lcirc_product.get_name() + + def set_rcirc_product(self, rcirc_product_index: int) -> None: + if self._rcirc_product_index != rcirc_product_index: + self._rcirc_product_index = rcirc_product_index + self._rcirc_product_data = None + self.notify_observers() + + def get_rcirc_product(self) -> int: + return self._rcirc_product_index + + def get_rcirc_product_name(self) -> str: + rcirc_product = self._repository[self._rcirc_product_index] + return rcirc_product.get_name() - if lcircObject.widthInPixels != rcircObject.widthInPixels: + def analyze(self) -> None: + lcirc_object = self._repository[self._lcirc_product_index].get_object_item().get_object() + rcirc_object = self._repository[self._rcirc_product_index].get_object_item().get_object() + + lcirc_object_geometry = lcirc_object.get_geometry() + rcirc_object_geometry = rcirc_object.get_geometry() + + if lcirc_object_geometry.width_px != rcirc_object_geometry.width_px: raise ValueError('Object width mismatch!') - if lcircObject.heightInPixels != rcircObject.heightInPixels: + if lcirc_object_geometry.height_px != rcirc_object_geometry.height_px: raise ValueError('Object height mismatch!') - if lcircObject.pixelWidthInMeters != rcircObject.pixelWidthInMeters: + if lcirc_object_geometry.pixel_width_m != rcirc_object_geometry.pixel_width_m: raise ValueError('Object pixel width mismatch!') - if lcircObject.pixelHeightInMeters != rcircObject.pixelHeightInMeters: + if lcirc_object_geometry.pixel_height_m != rcirc_object_geometry.pixel_height_m: raise ValueError('Object pixel height mismatch!') - # TODO align lcircArray/rcircArray - - lcircAmp = numpy.absolute(lcircObject.array) - rcircAmp = numpy.absolute(rcircObject.array) + # TODO align lcirc_array/rcirc_array + lcirc_amp = numpy.absolute(lcirc_object.get_layers_flattened()) + rcirc_amp = numpy.absolute(rcirc_object.get_layers_flattened()) - ratio = numpy.divide(lcircAmp, rcircAmp) - product = numpy.multiply(lcircAmp, rcircAmp) + ratio = numpy.divide(lcirc_amp, rcirc_amp) + product = numpy.multiply(lcirc_amp, rcirc_amp) - polar_difference = numpy.log(ratio, out=numpy.zeros_like(ratio), where=(ratio > 0)) - polar_sum = numpy.log(product, out=numpy.zeros_like(product), where=(product > 0)) + polar_difference = numpy.log(ratio, out=numpy.zeros_like(ratio), where=(ratio > 0.0)) + polar_sum = numpy.log(product, out=numpy.zeros_like(product), where=(product > 0.0)) polar_ratio = numpy.divide( polar_difference, polar_sum, out=numpy.zeros_like(polar_sum), - where=(polar_sum > 0), + where=(polar_sum > 0.0), ) - return XMCDResult( - pixel_width_m=rcircObject.pixelWidthInMeters, - pixel_height_m=rcircObject.pixelHeightInMeters, - center_x_m=rcircObject.centerXInMeters, - center_y_m=rcircObject.centerYInMeters, + self._product_data = XMCDData( polar_difference=polar_difference, polar_sum=polar_sum, polar_ratio=polar_ratio, + pixel_geometry=rcirc_object.get_pixel_geometry(), + center=rcirc_object.get_center(), ) + self.notify_observers() - def getSaveFileFilterList(self) -> Sequence[str]: - return [self.getSaveFileFilter()] + def get_data(self) -> XMCDData: + if self._product_data is None: + raise ValueError('No analyzed data!') - def getSaveFileFilter(self) -> str: + return self._product_data + + def get_save_file_filters(self) -> Sequence[str]: + return [self.get_save_file_filter()] + + def get_save_file_filter(self) -> str: return 'NumPy Zipped Archive (*.npz)' - def saveResult(self, filePath: Path, result: XMCDResult) -> None: - numpy.savez( - filePath, - 'pixel_height_m', - result.pixel_height_m, - 'pixel_width_m', - result.pixel_width_m, - 'center_x_m', - result.center_x_m, - 'center_y_m', - result.center_y_m, - 'polar_difference', - result.polar_difference, - 'polar_sum', - result.polar_sum, - 'polar_ratio', - result.polar_ratio, - ) + def save_data(self, file_path: Path) -> None: + if self._product_data is None: + raise ValueError('No analyzed data!') + + contents: dict[str, Any] = { + 'polar_difference': self._product_data.polar_difference, + 'polar_sum': self._product_data.polar_sum, + 'polar_ratio': self._product_data.polar_ratio, + 'pixel_height_m': self._product_data.pixel_geometry.height_m, + 'pixel_width_m': self._product_data.pixel_geometry.width_m, + 'center_x_m': self._product_data.center.position_x_m, + 'center_y_m': self._product_data.center.position_y_m, + } + + numpy.savez(file_path, **contents) diff --git a/src/ptychodus/model/automation/buffer.py b/src/ptychodus/model/automation/buffer.py index 0d3de2df..aa31f841 100644 --- a/src/ptychodus/model/automation/buffer.py +++ b/src/ptychodus/model/automation/buffer.py @@ -21,51 +21,51 @@ def __init__( self._settings = settings self._repository = repository self._processor = processor - self._eventTimes: OrderedDict[Path, float] = OrderedDict() - self._eventTimesLock = threading.Lock() - self._stopWorkEvent = threading.Event() + self._event_times: OrderedDict[Path, float] = OrderedDict() + self._event_times_lock = threading.Lock() + self._stop_work_event = threading.Event() self._worker = threading.Thread() - def put(self, filePath: Path) -> None: - with self._eventTimesLock: - self._eventTimes[filePath] = time() - self._eventTimes.move_to_end(filePath) + def put(self, file_path: Path) -> None: + with self._event_times_lock: + self._event_times[file_path] = time() + self._event_times.move_to_end(file_path) - self._repository.put(filePath, AutomationDatasetState.EXISTS) + self._repository.put(file_path, AutomationDatasetState.EXISTS) def _process(self) -> None: - while not self._stopWorkEvent.is_set(): - isFileReadyForProcessing = False + while not self._stop_work_event.is_set(): + is_file_ready_for_processing = False - with self._eventTimesLock: + with self._event_times_lock: try: - filePath, eventTime = next(iter(self._eventTimes.items())) + file_path, event_time = next(iter(self._event_times.items())) except StopIteration: pass else: - delayTime = self._settings.watchdogDelayInSeconds.getValue() - isFileReadyForProcessing = eventTime + delayTime < time() + delay_time = self._settings.watchdog_delay_s.get_value() + is_file_ready_for_processing = event_time + delay_time < time() - if isFileReadyForProcessing: - self._eventTimes.popitem(last=False) + if is_file_ready_for_processing: + self._event_times.popitem(last=False) - if isFileReadyForProcessing: - self._processor.put(filePath) + if is_file_ready_for_processing: + self._processor.put(file_path) else: - self._stopWorkEvent.wait(timeout=5.0) # TODO make configurable + self._stop_work_event.wait(timeout=5.0) # TODO make configurable def start(self) -> None: if self._worker.is_alive(): self.stop() logger.info('Starting automation thread...') - self._stopWorkEvent.clear() + self._stop_work_event.clear() self._worker = threading.Thread(target=self._process) self._worker.start() logger.info('Automation thread started.') def stop(self) -> None: logger.info('Stopping automation thread...') - self._stopWorkEvent.set() + self._stop_work_event.set() self._worker.join() logger.info('Automation thread stopped.') diff --git a/src/ptychodus/model/automation/core.py b/src/ptychodus/model/automation/core.py index f49b59a0..4bcc0b88 100644 --- a/src/ptychodus/model/automation/core.py +++ b/src/ptychodus/model/automation/core.py @@ -1,5 +1,5 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterator from pathlib import Path import queue @@ -23,86 +23,86 @@ def __init__( settings: AutomationSettings, workflow: CurrentFileBasedWorkflow, watcher: DataDirectoryWatcher, - datasetBuffer: AutomationDatasetBuffer, - datasetRepository: AutomationDatasetRepository, + dataset_buffer: AutomationDatasetBuffer, + dataset_repository: AutomationDatasetRepository, ) -> None: super().__init__() self._settings = settings self._workflow = workflow self._watcher = watcher - self._datasetBuffer = datasetBuffer - self._datasetRepository = datasetRepository + self._dataset_buffer = dataset_buffer + self._dataset_repository = dataset_repository - settings.addObserver(self) - watcher.addObserver(self) + settings.add_observer(self) + watcher.add_observer(self) - def getStrategyList(self) -> Sequence[str]: - return self._workflow.getAvailableWorkflows() + def get_strategies(self) -> Iterator[str]: + return self._workflow.get_available_workflows() - def getStrategy(self) -> str: - return self._workflow.getWorkflow() + def get_strategy(self) -> str: + return self._workflow.get_workflow() - def setStrategy(self, strategy: str) -> None: - self._workflow.setWorkflow(strategy) + def set_strategy(self, strategy: str) -> None: + self._workflow.set_workflow(strategy) - def getDataDirectory(self) -> Path: - return self._settings.dataDirectory.getValue() + def get_data_directory(self) -> Path: + return self._settings.data_directory.get_value() - def setDataDirectory(self, directory: Path) -> None: - self._settings.dataDirectory.setValue(directory) + def set_data_directory(self, directory: Path) -> None: + self._settings.data_directory.set_value(directory) - def getProcessingIntervalLimitsInSeconds(self) -> Interval[int]: + def get_processing_interval_limits_s(self) -> Interval[int]: return Interval[int](0, 600) - def getProcessingIntervalInSeconds(self) -> int: - limits = self.getProcessingIntervalLimitsInSeconds() - return limits.clamp(self._settings.processingIntervalInSeconds.getValue()) + def get_processing_interval_s(self) -> int: + limits = self.get_processing_interval_limits_s() + return limits.clamp(self._settings.processing_interval_s.get_value()) - def setProcessingIntervalInSeconds(self, value: int) -> None: - self._settings.processingIntervalInSeconds.setValue(value) + def set_processing_interval_s(self, value: int) -> None: + self._settings.processing_interval_s.set_value(value) - def loadExistingDatasetsToRepository(self) -> None: - dataDirectory = self.getDataDirectory() - pattern = '**/' if self._workflow.isWatchRecursive else '' - pattern += self._workflow.getWatchFilePattern() - scanFileList = sorted(scanFile for scanFile in dataDirectory.glob(pattern)) + def load_existing_datasets_to_repository(self) -> None: + data_directory = self.get_data_directory() + pattern = '**/' if self._workflow.is_watch_recursive else '' + pattern += self._workflow.get_watch_file_pattern() + scan_file_list = sorted(scanFile for scanFile in data_directory.glob(pattern)) - for scanFile in scanFileList: - self._datasetBuffer.put(scanFile) + for scan_file in scan_file_list: + self._dataset_buffer.put(scan_file) - def clearDatasetRepository(self) -> None: - self._datasetRepository.clear() + def clear_dataset_repository(self) -> None: + self._dataset_repository.clear() - def isWatchdogEnabled(self) -> bool: - return self._watcher.isAlive + def is_watchdog_enabled(self) -> bool: + return self._watcher.is_alive - def setWatchdogEnabled(self, enable: bool) -> None: + def set_watchdog_enabled(self, enable: bool) -> None: if enable: self._watcher.start() else: self._watcher.stop() - def getWatchdogDelayLimitsInSeconds(self) -> Interval[int]: + def get_watchdog_delay_limits_s(self) -> Interval[int]: return Interval[int](0, 600) - def getWatchdogDelayInSeconds(self) -> int: - limits = self.getWatchdogDelayLimitsInSeconds() - return limits.clamp(self._settings.watchdogDelayInSeconds.getValue()) + def get_watchdog_delay_s(self) -> int: + limits = self.get_watchdog_delay_limits_s() + return limits.clamp(self._settings.watchdog_delay_s.get_value()) - def setWatchdogDelayInSeconds(self, value: int) -> None: - self._settings.watchdogDelayInSeconds.setValue(value) + def set_watchdog_delay_s(self, value: int) -> None: + self._settings.watchdog_delay_s.set_value(value) - def setWatchdogPollingObserverEnabled(self, enable: bool) -> None: - self._settings.useWatchdogPollingObserver.setValue(enable) + def set_watchdog_polling_observer_enabled(self, enable: bool) -> None: + self._settings.use_watchdog_polling_observer.set_value(enable) - def isWatchdogPollingObserverEnabled(self) -> bool: - return self._settings.useWatchdogPollingObserver.getValue() + def is_watchdog_polling_observer_enabled(self) -> bool: + return self._settings.use_watchdog_polling_observer.get_value() - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._settings: - self.notifyObservers() + self.notify_observers() elif observable is self._watcher: - self.notifyObservers() + self.notify_observers() class AutomationProcessingPresenter(Observable, Observer): @@ -117,77 +117,77 @@ def __init__( self._repository = repository self._processor = processor - settings.addObserver(self) - repository.addObserver(self) + settings.add_observer(self) + repository.add_observer(self) - def getDatasetLabel(self, index: int) -> str: - return self._repository.getLabel(index) + def get_dataset_label(self, index: int) -> str: + return self._repository.get_label(index) - def getDatasetState(self, index: int) -> AutomationDatasetState: - return self._repository.getState(index) + def get_dataset_state(self, index: int) -> AutomationDatasetState: + return self._repository.get_state(index) - def getNumberOfDatasets(self) -> int: + def get_num_datasets(self) -> int: return len(self._repository) - def isProcessingEnabled(self) -> bool: - return self._processor.isAlive + def is_processing_enabled(self) -> bool: + return self._processor.is_alive - def setProcessingEnabled(self, enable: bool) -> None: + def set_processing_enabled(self, enable: bool) -> None: if enable: self._processor.start() else: self._processor.stop() - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._settings: - self.notifyObservers() + self.notify_observers() elif observable is self._repository: - self.notifyObservers() + self.notify_observers() class AutomationCore: def __init__( self, - settingsRegistry: SettingsRegistry, - workflowAPI: WorkflowAPI, - workflowChooser: PluginChooser[FileBasedWorkflow], + settings_registry: SettingsRegistry, + workflow_api: WorkflowAPI, + workflow_chooser: PluginChooser[FileBasedWorkflow], ) -> None: - self._settings = AutomationSettings(settingsRegistry) + self._settings = AutomationSettings(settings_registry) self.repository = AutomationDatasetRepository(self._settings) - self._workflow = CurrentFileBasedWorkflow(self._settings, workflowChooser) - self._processingQueue: queue.Queue[Path] = queue.Queue() + self._workflow = CurrentFileBasedWorkflow(self._settings, workflow_chooser) + self._processing_queue: queue.Queue[Path] = queue.Queue() self._processor = AutomationDatasetProcessor( self._settings, self.repository, self._workflow, - workflowAPI, - self._processingQueue, + workflow_api, + self._processing_queue, ) - self._datasetBuffer = AutomationDatasetBuffer( + self._dataset_buffer = AutomationDatasetBuffer( self._settings, self.repository, self._processor ) - self._watcher = DataDirectoryWatcher(self._settings, self._workflow, self._datasetBuffer) + self._watcher = DataDirectoryWatcher(self._settings, self._workflow, self._dataset_buffer) self.presenter = AutomationPresenter( self._settings, self._workflow, self._watcher, - self._datasetBuffer, + self._dataset_buffer, self.repository, ) - self.processingPresenter = AutomationProcessingPresenter( + self.processing_presenter = AutomationProcessingPresenter( self._settings, self.repository, self._processor ) def start(self) -> None: - self._datasetBuffer.start() + self._dataset_buffer.start() - def refreshDatasetRepository(self) -> None: - self.repository.notifyObserversIfRepositoryChanged() + def refresh_dataset_repository(self) -> None: + self.repository.notify_observers_if_repository_changed() - def executeWaitingTasks(self) -> None: - self._processor.runOnce() + def execute_waiting_tasks(self) -> None: + self._processor.run_once() def stop(self) -> None: self._processor.stop() self._watcher.stop() - self._datasetBuffer.stop() + self._dataset_buffer.stop() diff --git a/src/ptychodus/model/automation/processor.py b/src/ptychodus/model/automation/processor.py index 6414f6b4..eb5c0bff 100644 --- a/src/ptychodus/model/automation/processor.py +++ b/src/ptychodus/model/automation/processor.py @@ -18,74 +18,74 @@ def __init__( settings: AutomationSettings, repository: AutomationDatasetRepository, workflow: FileBasedWorkflow, - workflowAPI: WorkflowAPI, - processingQueue: queue.Queue[Path], + workflow_api: WorkflowAPI, + processing_queue: queue.Queue[Path], ) -> None: self._settings = settings self._repository = repository self._workflow = workflow - self._workflowAPI = workflowAPI - self._processingQueue = processingQueue - self._stopWorkEvent = threading.Event() + self._workflow_api = workflow_api + self._processing_queue = processing_queue + self._stop_work_event = threading.Event() self._worker = threading.Thread() - self._nextJobTime = time() + self._next_job_time = time() @property - def isAlive(self) -> bool: + def is_alive(self) -> bool: return self._worker.is_alive() - def put(self, filePath: Path) -> None: - self._repository.put(filePath, AutomationDatasetState.WAITING) - self._processingQueue.put(filePath) + def put(self, file_path: Path) -> None: + self._repository.put(file_path, AutomationDatasetState.WAITING) + self._processing_queue.put(file_path) - def runOnce(self) -> None: + def run_once(self) -> None: try: - filePath = self._processingQueue.get(block=False) + file_path = self._processing_queue.get(block=False) try: - self._repository.put(filePath, AutomationDatasetState.PROCESSING) - self._workflow.execute(self._workflowAPI, filePath) - self._repository.put(filePath, AutomationDatasetState.COMPLETE) + self._repository.put(file_path, AutomationDatasetState.PROCESSING) + self._workflow.execute(self._workflow_api, file_path) + self._repository.put(file_path, AutomationDatasetState.COMPLETE) except Exception: logger.exception('Error while processing dataset!') finally: - self._processingQueue.task_done() + self._processing_queue.task_done() except queue.Empty: pass def _run(self) -> None: - while not self._stopWorkEvent.is_set(): + while not self._stop_work_event.is_set(): try: - filePath = self._processingQueue.get(block=True, timeout=1) + file_path = self._processing_queue.get(block=True, timeout=1) except queue.Empty: continue - delayInSeconds = self._nextJobTime - time() + delay_s = self._next_job_time - time() - if delayInSeconds > 0.0 and self._stopWorkEvent.wait(timeout=delayInSeconds): + if delay_s > 0.0 and self._stop_work_event.wait(timeout=delay_s): break try: - self._repository.put(filePath, AutomationDatasetState.PROCESSING) - self._workflow.execute(self._workflowAPI, filePath) - self._repository.put(filePath, AutomationDatasetState.COMPLETE) + self._repository.put(file_path, AutomationDatasetState.PROCESSING) + self._workflow.execute(self._workflow_api, file_path) + self._repository.put(file_path, AutomationDatasetState.COMPLETE) except Exception: logger.exception('Error while processing dataset!') finally: - self._processingQueue.task_done() - self._nextJobTime = self._settings.processingIntervalInSeconds.getValue() + time() + self._processing_queue.task_done() + self._next_job_time = self._settings.processing_interval_s.get_value() + time() def start(self) -> None: self.stop() logger.info('Starting automation processor thread...') - self._stopWorkEvent.clear() + self._stop_work_event.clear() self._worker = threading.Thread(target=self._run) self._worker.start() logger.info('Automation processor thread started.') def stop(self) -> None: - if self.isAlive: + if self.is_alive: logger.info('Stopping automation processor thread...') - self._stopWorkEvent.set() + self._stop_work_event.set() self._worker.join() logger.info('Automation processor thread stopped.') diff --git a/src/ptychodus/model/automation/repository.py b/src/ptychodus/model/automation/repository.py index dbbd2de0..f9081d55 100644 --- a/src/ptychodus/model/automation/repository.py +++ b/src/ptychodus/model/automation/repository.py @@ -21,49 +21,49 @@ class AutomationDatasetRepository(Observable): def __init__(self, settings: AutomationSettings) -> None: super().__init__() self._settings = settings - self._fileList: list[Path] = list() - self._fileState: dict[Path, AutomationDatasetState] = dict() + self._file_list: list[Path] = list() + self._file_state: dict[Path, AutomationDatasetState] = dict() self._lock = threading.Lock() - self._changedEvent = threading.Event() + self._changed_event = threading.Event() - def put(self, filePath: Path, state: AutomationDatasetState) -> None: + def put(self, file_path: Path, state: AutomationDatasetState) -> None: with self._lock: try: - priorState = self._fileState[filePath] + prior_state = self._file_state[file_path] except KeyError: if state == AutomationDatasetState.EXISTS: - self._fileList.append(filePath) - self._fileState[filePath] = state + self._file_list.append(file_path) + self._file_state[file_path] = state else: - logger.error(f'{filePath}: UNKNOWN -> {state}') + logger.error(f'{file_path}: UNKNOWN -> {state}') else: - logger.debug(f'{filePath}: {priorState} -> {state}') - self._fileState[filePath] = state + logger.debug(f'{file_path}: {prior_state} -> {state}') + self._file_state[file_path] = state - self._changedEvent.set() + self._changed_event.set() def clear(self) -> None: with self._lock: - self._fileList.clear() - self._fileState.clear() + self._file_list.clear() + self._file_state.clear() - self._changedEvent.set() + self._changed_event.set() - def getLabel(self, index: int) -> str: + def get_label(self, index: int) -> str: with self._lock: - filePath = self._fileList[index] - return str(filePath.relative_to(self._settings.dataDirectory.getValue())) + file_path = self._file_list[index] + return str(file_path.relative_to(self._settings.data_directory.get_value())) - def getState(self, index: int) -> AutomationDatasetState: + def get_state(self, index: int) -> AutomationDatasetState: with self._lock: - filePath = self._fileList[index] - return self._fileState[filePath] + file_path = self._file_list[index] + return self._file_state[file_path] def __len__(self) -> int: with self._lock: - return len(self._fileList) + return len(self._file_list) - def notifyObserversIfRepositoryChanged(self) -> None: - if self._changedEvent.is_set(): - self._changedEvent.clear() - self.notifyObservers() + def notify_observers_if_repository_changed(self) -> None: + if self._changed_event.is_set(): + self._changed_event.clear() + self.notify_observers() diff --git a/src/ptychodus/model/automation/settings.py b/src/ptychodus/model/automation/settings.py index ec0e866e..7bdc4ea8 100644 --- a/src/ptychodus/model/automation/settings.py +++ b/src/ptychodus/model/automation/settings.py @@ -8,23 +8,21 @@ class AutomationSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Automation') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Automation') + self._group.add_observer(self) - self.strategy = self._settingsGroup.createStringParameter('Strategy', 'APS2ID') - self.dataDirectory = self._settingsGroup.createPathParameter( + self.strategy = self._group.create_string_parameter('Strategy', 'Autoload_Product') + self.data_directory = self._group.create_path_parameter( 'DataDirectory', Path('/path/to/data') ) - self.processingIntervalInSeconds = self._settingsGroup.createIntegerParameter( + self.processing_interval_s = self._group.create_integer_parameter( 'ProcessingIntervalInSeconds', 0 ) - self.useWatchdogPollingObserver = self._settingsGroup.createBooleanParameter( + self.use_watchdog_polling_observer = self._group.create_boolean_parameter( 'UseWatchdogPollingObserver', False ) - self.watchdogDelayInSeconds = self._settingsGroup.createIntegerParameter( - 'WatchdogDelayInSeconds', 15 - ) + self.watchdog_delay_s = self._group.create_integer_parameter('WatchdogDelayInSeconds', 15) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/automation/watcher.py b/src/ptychodus/model/automation/watcher.py index 74473186..661cdce2 100644 --- a/src/ptychodus/model/automation/watcher.py +++ b/src/ptychodus/model/automation/watcher.py @@ -17,17 +17,19 @@ class DataDirectoryEventHandler(watchdog.events.FileSystemEventHandler): - def __init__(self, workflow: FileBasedWorkflow, datasetBuffer: AutomationDatasetBuffer) -> None: + def __init__( + self, workflow: FileBasedWorkflow, dataset_buffer: AutomationDatasetBuffer + ) -> None: super().__init__() self._workflow = workflow - self._datasetBuffer = datasetBuffer + self._dataset_buffer = dataset_buffer def on_created_or_modified(self, event: watchdog.events.FileSystemEvent) -> None: if not event.is_directory: - srcPath = Path(str(event.src_path)) + src_path = Path(str(event.src_path)) - if srcPath.match(self._workflow.getWatchFilePattern()): - self._datasetBuffer.put(srcPath) + if src_path.match(self._workflow.get_watch_file_pattern()): + self._dataset_buffer.put(src_path) def on_created(self, event: watchdog.events.FileSystemEvent) -> None: self.on_created_or_modified(event) @@ -41,58 +43,58 @@ def __init__( self, settings: AutomationSettings, workflow: CurrentFileBasedWorkflow, - datasetBuffer: AutomationDatasetBuffer, + dataset_buffer: AutomationDatasetBuffer, ) -> None: super().__init__() self._settings = settings self._workflow = workflow - self._datasetBuffer = datasetBuffer + self._dataset_buffer = dataset_buffer self._observer: watchdog.observers.api.BaseObserver = watchdog.observers.Observer() - settings.addObserver(self) - workflow.addObserver(self) + settings.add_observer(self) + workflow.add_observer(self) @property - def isAlive(self) -> bool: + def is_alive(self) -> bool: return self._observer.is_alive() - def _updateWatch(self) -> None: + def _update_watch(self) -> None: self._observer.unschedule_all() - dataDirectory = self._settings.dataDirectory.getValue() + data_directory = self._settings.data_directory.get_value() - if dataDirectory.exists(): - observedWatch = self._observer.schedule( - event_handler=DataDirectoryEventHandler(self._workflow, self._datasetBuffer), - path=str(dataDirectory), - recursive=self._workflow.isWatchRecursive, + if data_directory.exists(): + observed_watch = self._observer.schedule( + event_handler=DataDirectoryEventHandler(self._workflow, self._dataset_buffer), + path=str(data_directory), + recursive=self._workflow.is_watch_recursive, ) - logger.debug(observedWatch) + logger.debug(observed_watch) else: - logger.warning(f'Data directory "{dataDirectory}" does not exist!') + logger.warning(f'Data directory "{data_directory}" does not exist!') def start(self) -> None: - if self.isAlive: + if self.is_alive: logger.error('Automation watchdog thread already started!') else: logger.info('Starting automation watchdog thread...') self._observer = ( PollingObserver() - if self._settings.useWatchdogPollingObserver.getValue() + if self._settings.use_watchdog_polling_observer.get_value() else watchdog.observers.Observer() ) self._observer.start() - self._updateWatch() + self._update_watch() logger.debug('Automation watchdog thread started.') def stop(self) -> None: - if self.isAlive: + if self.is_alive: logger.info('Stopping automation watchdog thread...') self._observer.stop() self._observer.join() logger.debug('Automation watchdog thread stopped.') - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._settings: - self._updateWatch() + self._update_watch() elif observable is self._workflow: - self._updateWatch() + self._update_watch() diff --git a/src/ptychodus/model/automation/workflow.py b/src/ptychodus/model/automation/workflow.py index 09cdcda4..73d3aad2 100644 --- a/src/ptychodus/model/automation/workflow.py +++ b/src/ptychodus/model/automation/workflow.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Iterator from pathlib import Path from ptychodus.api.observer import Observable, Observer @@ -12,40 +12,37 @@ class CurrentFileBasedWorkflow(FileBasedWorkflow, Observable, Observer): def __init__( self, settings: AutomationSettings, - workflowChooser: PluginChooser[FileBasedWorkflow], + workflow_chooser: PluginChooser[FileBasedWorkflow], ) -> None: super().__init__() - self._settings = settings - self._workflowChooser = workflowChooser + self._workflow_chooser = workflow_chooser - settings.addObserver(self) - workflowChooser.addObserver(self) + workflow_chooser.synchronize_with_parameter(settings.strategy) + workflow_chooser.add_observer(self) - def getAvailableWorkflows(self) -> Sequence[str]: - return self._workflowChooser.getDisplayNameList() + def get_available_workflows(self) -> Iterator[str]: + for plugin in self._workflow_chooser: + yield plugin.display_name - def getWorkflow(self) -> str: - return self._workflowChooser.currentPlugin.displayName + def get_workflow(self) -> str: + return self._workflow_chooser.get_current_plugin().display_name - def setWorkflow(self, name: str) -> None: - self._workflowChooser.setCurrentPluginByName(name) - self._settings.strategy.setValue(self._workflowChooser.currentPlugin.simpleName) + def set_workflow(self, name: str) -> None: + self._workflow_chooser.set_current_plugin(name) @property - def isWatchRecursive(self) -> bool: - workflow = self._workflowChooser.currentPlugin.strategy - return workflow.isWatchRecursive - - def getWatchFilePattern(self) -> str: - workflow = self._workflowChooser.currentPlugin.strategy - return workflow.getWatchFilePattern() - - def execute(self, api: WorkflowAPI, filePath: Path) -> None: - workflow = self._workflowChooser.currentPlugin.strategy - workflow.execute(api, filePath) - - def update(self, observable: Observable) -> None: - if observable is self._settings: - self.setWorkflow(self._settings.strategy.getValue()) - if observable is self._workflowChooser: - self.notifyObservers() + def is_watch_recursive(self) -> bool: + workflow = self._workflow_chooser.get_current_plugin().strategy + return workflow.is_watch_recursive + + def get_watch_file_pattern(self) -> str: + workflow = self._workflow_chooser.get_current_plugin().strategy + return workflow.get_watch_file_pattern() + + def execute(self, api: WorkflowAPI, file_path: Path) -> None: + workflow = self._workflow_chooser.get_current_plugin().strategy + workflow.execute(api, file_path) + + def _update(self, observable: Observable) -> None: + if observable is self._workflow_chooser: + self.notify_observers() diff --git a/src/ptychodus/model/core.py b/src/ptychodus/model/core.py index e8502692..f4b33714 100644 --- a/src/ptychodus/model/core.py +++ b/src/ptychodus/model/core.py @@ -21,64 +21,33 @@ from ptychodus.api.settings import SettingsRegistry from ptychodus.api.workflow import WorkflowAPI -from .analysis import ( - AnalysisCore, - ExposureAnalyzer, - FourierRingCorrelator, - ProbePropagator, - STXMSimulator, - XMCDAnalyzer, -) -from .automation import ( - AutomationCore, - AutomationPresenter, - AutomationProcessingPresenter, -) -from .fluorescence import FluorescenceCore, FluorescenceEnhancer +from .agent import AgentCore +from .analysis import AnalysisCore +from .automation import AutomationCore +from .fluorescence import FluorescenceCore from .memory import MemoryPresenter -from .patterns import ( - Detector, - DiffractionDatasetInputOutputPresenter, - DiffractionDatasetPresenter, - DiffractionMetadataPresenter, - DiffractionPatternPresenter, - PatternsCore, -) -from .product import ( - ObjectAPI, - ObjectRepository, - ProbeAPI, - ProbeRepository, - ProductAPI, - ProductCore, - ProductRepository, - ScanAPI, - ScanRepository, -) +from .metadata import MetadataPresenter +from .patterns import PatternsCore, PatternsStreamingContext +from .product import PositionsStreamingContext, ProductCore +from .ptychi import PtyChiReconstructorLibrary from .ptychonn import PtychoNNReconstructorLibrary -from .reconstructor import ReconstructorCore, ReconstructorPresenter +from .ptychopinn import PtychoPINNReconstructorLibrary +from .reconstructor import ReconstructorCore from .tike import TikeReconstructorLibrary from .visualization import VisualizationEngine -from .workflow import ( - WorkflowAuthorizationPresenter, - WorkflowCore, - WorkflowExecutionPresenter, - WorkflowParametersPresenter, - WorkflowStatusPresenter, -) +from .workflow import WorkflowCore logger = logging.getLogger(__name__) -def configureLogger(isDeveloperModeEnabled: bool) -> None: +def configure_logger(is_developer_mode_enabled: bool) -> None: logging.basicConfig( format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', stream=sys.stdout, encoding='utf-8', - level=logging.DEBUG if isDeveloperModeEnabled else logging.INFO, + level=logging.DEBUG if is_developer_mode_enabled else logging.INFO, ) logging.getLogger('matplotlib').setLevel(logging.WARNING) - logging.getLogger('tike').setLevel(logging.WARNING) logger.info(f'Ptychodus {version("ptychodus")}') logger.info(f'NumPy {version("numpy")}') @@ -88,96 +57,141 @@ def configureLogger(isDeveloperModeEnabled: bool) -> None: logger.info(f'HDF5 {h5py.version.hdf5_version}') +class PtychodusStreamingContext: + def __init__( + self, + positions_context: PositionsStreamingContext, + patterns_context: PatternsStreamingContext, + ) -> None: + self._positions_context = positions_context + self._patterns_context = patterns_context + + def start(self) -> None: + self._positions_context.start() + self._patterns_context.start() + + def append_positions_x(self, values_m: Sequence[float], trigger_counts: Sequence[int]) -> None: + self._positions_context.append_positions_x(values_m, trigger_counts) + + def append_positions_y(self, values_m: Sequence[float], trigger_counts: Sequence[int]) -> None: + self._positions_context.append_positions_y(values_m, trigger_counts) + + def append_array(self, array: DiffractionPatternArray) -> None: + self._patterns_context.append_array(array) + + def get_queue_size(self) -> int: + return self._patterns_context.get_queue_size() + + def stop(self) -> None: + self._patterns_context.stop() + self._positions_context.stop() + + class ModelCore: def __init__( - self, settingsFile: Path | None = None, *, isDeveloperModeEnabled: bool = False + self, settings_file: Path | None = None, *, is_developer_mode_enabled: bool = False ) -> None: - configureLogger(isDeveloperModeEnabled) + configure_logger(is_developer_mode_enabled) self.rng = numpy.random.default_rng() - self._pluginRegistry = PluginRegistry.loadPlugins() + self._plugin_registry = PluginRegistry.load_plugins() - self.memoryPresenter = MemoryPresenter() - self.settingsRegistry = SettingsRegistry() + self.memory_presenter = MemoryPresenter() + self.settings_registry = SettingsRegistry() - self._patternsCore = PatternsCore( - self.settingsRegistry, - self._pluginRegistry.diffractionFileReaders, - self._pluginRegistry.diffractionFileWriters, + self.patterns = PatternsCore( + self.settings_registry, + self._plugin_registry.diffraction_file_readers, + self._plugin_registry.diffraction_file_writers, + self.settings_registry, ) - self._productCore = ProductCore( + self.product = ProductCore( self.rng, - self.settingsRegistry, - self._patternsCore.detector, - self._patternsCore.productSettings, - self._patternsCore.patternSizer, - self._patternsCore.dataset, - self._pluginRegistry.scanFileReaders, - self._pluginRegistry.scanFileWriters, - self._pluginRegistry.fresnelZonePlates, - self._pluginRegistry.probeFileReaders, - self._pluginRegistry.probeFileWriters, - self._pluginRegistry.objectFileReaders, - self._pluginRegistry.objectFileWriters, - self._pluginRegistry.productFileReaders, - self._pluginRegistry.productFileWriters, - self.settingsRegistry, + self.settings_registry, + self.patterns.pattern_sizer, + self.patterns.dataset, + self._plugin_registry.position_file_readers, + self._plugin_registry.position_file_writers, + self._plugin_registry.fresnel_zone_plates, + self._plugin_registry.probe_file_readers, + self._plugin_registry.probe_file_writers, + self._plugin_registry.object_file_readers, + self._plugin_registry.object_file_writers, + self._plugin_registry.product_file_readers, + self._plugin_registry.product_file_writers, + self.settings_registry, + ) + self.metadata_presenter = MetadataPresenter( + self.patterns.detector_settings, + self.patterns.pattern_settings, + self.patterns.dataset, + self.product.settings, ) - self.patternVisualizationEngine = VisualizationEngine(isComplex=False) - self.probeVisualizationEngine = VisualizationEngine(isComplex=True) - self.objectVisualizationEngine = VisualizationEngine(isComplex=True) + self.pattern_visualization_engine = VisualizationEngine(is_complex=False) + self.probe_visualization_engine = VisualizationEngine(is_complex=True) + self.object_visualization_engine = VisualizationEngine(is_complex=True) - self.tikeReconstructorLibrary = TikeReconstructorLibrary.createInstance( - self.settingsRegistry, isDeveloperModeEnabled + self.ptychi_reconstructor_library = PtyChiReconstructorLibrary( + self.settings_registry, self.patterns.pattern_sizer, is_developer_mode_enabled ) - self.ptychonnReconstructorLibrary = PtychoNNReconstructorLibrary.createInstance( - self.settingsRegistry, isDeveloperModeEnabled + self.tike_reconstructor_library = TikeReconstructorLibrary.create_instance( + self.settings_registry, is_developer_mode_enabled ) - self._reconstructorCore = ReconstructorCore( - self.settingsRegistry, - self._patternsCore.dataset, - self._productCore.productRepository, + self.ptychonn_reconstructor_library = PtychoNNReconstructorLibrary.create_instance( + self.settings_registry, is_developer_mode_enabled + ) + self.ptychopinn_reconstructor_library = PtychoPINNReconstructorLibrary( + self.settings_registry, is_developer_mode_enabled + ) + self.reconstructor = ReconstructorCore( + self.settings_registry, + self.patterns.dataset, + self.product.product_api, [ - self.tikeReconstructorLibrary, - self.ptychonnReconstructorLibrary, + self.ptychi_reconstructor_library, + self.tike_reconstructor_library, + self.ptychonn_reconstructor_library, + self.ptychopinn_reconstructor_library, ], ) - self._fluorescenceCore = FluorescenceCore( - self.settingsRegistry, - self._productCore.productRepository, - self._pluginRegistry.upscalingStrategies, - self._pluginRegistry.deconvolutionStrategies, - self._pluginRegistry.fluorescenceFileReaders, - self._pluginRegistry.fluorescenceFileWriters, + self.fluorescence_core = FluorescenceCore( + self.settings_registry, + self.product.product_repository, + self._plugin_registry.upscaling_strategies, + self._plugin_registry.deconvolution_strategies, + self._plugin_registry.fluorescence_file_readers, + self._plugin_registry.fluorescence_file_writers, ) - self._analysisCore = AnalysisCore( - self.settingsRegistry, - self._reconstructorCore.dataMatcher, - self._productCore.productRepository, - self._productCore.objectRepository, + self.analysis = AnalysisCore( + self.settings_registry, + self.reconstructor.data_matcher, + self.product.product_repository, + self.product.object_repository, ) - self._workflowCore = WorkflowCore( - self.settingsRegistry, - self._patternsCore.patternsAPI, - self._productCore.productAPI, - self._productCore.scanAPI, - self._productCore.probeAPI, - self._productCore.objectAPI, - self._reconstructorCore.reconstructorAPI, + self.workflow = WorkflowCore( + self.settings_registry, + self.patterns.patterns_api, + self.product.product_api, + self.product.scan_api, + self.product.probe_api, + self.product.object_api, + self.reconstructor.reconstructor_api, ) - self._automationCore = AutomationCore( - self.settingsRegistry, - self._workflowCore.workflowAPI, - self._pluginRegistry.fileBasedWorkflows, + self.automation = AutomationCore( + self.settings_registry, + self.workflow.workflow_api, + self._plugin_registry.file_based_workflows, ) + self.agent = AgentCore(self.settings_registry) - if settingsFile: - self.settingsRegistry.openSettings(settingsFile) + if settings_file: + self.settings_registry.open_settings(settings_file) def __enter__(self) -> ModelCore: - self._patternsCore.start() - self._workflowCore.start() - self._automationCore.start() + self.patterns.start() + self.reconstructor.start() + self.workflow.start() + self.automation.start() return self @overload @@ -197,130 +211,69 @@ def __exit__( exception_value: BaseException | None, traceback: TracebackType | None, ) -> None: - self._automationCore.stop() - self._workflowCore.stop() - self._patternsCore.stop() - - @property - def diffractionDatasetInputOutputPresenter( - self, - ) -> DiffractionDatasetInputOutputPresenter: - return self._patternsCore.datasetInputOutputPresenter - - @property - def diffractionMetadataPresenter(self) -> DiffractionMetadataPresenter: - return self._patternsCore.metadataPresenter - - @property - def diffractionDatasetPresenter(self) -> DiffractionDatasetPresenter: - return self._patternsCore.datasetPresenter - - @property - def patternPresenter(self) -> DiffractionPatternPresenter: - return self._patternsCore.patternPresenter - - @property - def productRepository(self) -> ProductRepository: - return self._productCore.productRepository - - @property - def productAPI(self) -> ProductAPI: - return self._productCore.productAPI - - @property - def scanRepository(self) -> ScanRepository: - return self._productCore.scanRepository - - @property - def scanAPI(self) -> ScanAPI: - return self._productCore.scanAPI - - @property - def probeRepository(self) -> ProbeRepository: - return self._productCore.probeRepository - - @property - def probeAPI(self) -> ProbeAPI: - return self._productCore.probeAPI - - @property - def objectRepository(self) -> ObjectRepository: - return self._productCore.objectRepository - - @property - def objectAPI(self) -> ObjectAPI: - return self._productCore.objectAPI - - def initializeStreamingWorkflow(self, metadata: DiffractionMetadata) -> None: - self._patternsCore.patternsAPI.initializeStreaming(metadata) - self._patternsCore.patternsAPI.startAssemblingDiffractionPatterns() - self._productCore.scanAPI.initializeStreamingScan() # FIXME - - def assembleDiffractionPattern(self, array: DiffractionPatternArray, timeStamp: float) -> None: - self._patternsCore.patternsAPI.assemble(array) - self._productCore.scanAPI.insertArrayTimeStamp(array.getIndex(), timeStamp) # FIXME - - def assembleScanPositionsX( - self, valuesInMeters: Sequence[float], timeStamps: Sequence[float] - ) -> None: - self._productCore.scanAPI.assembleScanPositionsX(valuesInMeters, timeStamps) # FIXME - - def assembleScanPositionsY( - self, valuesInMeters: Sequence[float], timeStamps: Sequence[float] - ) -> None: - self._productCore.scanAPI.assembleScanPositionsY(valuesInMeters, timeStamps) # FIXME - - def finalizeStreamingWorkflow(self) -> None: - self._productCore.scanAPI.finalizeStreamingScan() # FIXME - self._patternsCore.patternsAPI.stopAssemblingDiffractionPatterns(finishAssembling=True) - - def getDiffractionPatternAssemblyQueueSize(self) -> int: - return self._patternsCore.patternsAPI.getAssemblyQueueSize() + self.automation.stop() + self.workflow.stop() + self.reconstructor.stop() + self.patterns.stop() + + def create_streaming_context(self, metadata: DiffractionMetadata) -> PtychodusStreamingContext: + return PtychodusStreamingContext( + self.product.scan_api.create_streaming_context(), + self.patterns.patterns_api.create_streaming_context(metadata), + ) - def refreshActiveDataset(self) -> None: - self._patternsCore.dataset.notifyObserversIfDatasetChanged() + def refresh_active_dataset(self) -> None: + self.patterns.dataset.assemble_patterns() - def batchModeExecute( + def batch_mode_execute( self, action: str, - inputFilePath: Path, - outputFilePath: Path, + input_path: Path, + output_path: Path, *, - fluorescenceInputFilePath: Path | None = None, - fluorescenceOutputFilePath: Path | None = None, + product_file_type: str = 'NPZ', + fluorescence_input_file_path: Path | None = None, + fluorescence_output_file_path: Path | None = None, ) -> int: # TODO add enum for actions; implement using workflow API - inputProductIndex = self._productCore.productAPI.openProduct(inputFilePath, fileType='NPZ') + if action.lower() == 'train': + output = self.reconstructor.reconstructor_api.train(input_path) + self.reconstructor.reconstructor_api.save_model(output_path) + return output.result + + input_product_index = self.product.product_api.open_product( + input_path, file_type=product_file_type + ) - if inputProductIndex < 0: - logger.error(f'Failed to open product "{inputFilePath}"') + if input_product_index < 0: + logger.error(f'Failed to open product "{input_path}"!') return -1 if action.lower() == 'reconstruct': - outputProductName = self._productCore.productAPI.getItemName(inputProductIndex) - outputProductIndex = self._reconstructorCore.reconstructorAPI.reconstruct( - inputProductIndex, outputProductName + logger.info('Reconstructing...') + output_product_index = self.reconstructor.reconstructor_api.reconstruct( + input_product_index ) + self.reconstructor.reconstructor_api.process_results(block=True) + logger.info('Reconstruction complete.') - if outputProductIndex < 0: - logger.error(f'Failed to reconstruct product index="{inputProductIndex}"') - return -1 - - self._productCore.productAPI.saveProduct( - outputProductIndex, outputFilePath, fileType='NPZ' + self.product.product_api.save_product( + output_product_index, output_path, file_type=product_file_type ) - if fluorescenceInputFilePath is not None and fluorescenceOutputFilePath is not None: - self._fluorescenceCore.enhanceFluorescence( - outputProductIndex, - fluorescenceInputFilePath, - fluorescenceOutputFilePath, + if ( + fluorescence_input_file_path is not None + and fluorescence_output_file_path is not None + ): + self.fluorescence_core.enhance_fluorescence( + output_product_index, + fluorescence_input_file_path, + fluorescence_output_file_path, ) - - elif action.lower() == 'train': - self._reconstructorCore.reconstructorAPI.ingestTrainingData(inputProductIndex) - _ = self._reconstructorCore.reconstructorAPI.train() - self._reconstructorCore.reconstructorAPI.saveModel(outputFilePath) + elif action.lower() == 'prepare_training_data': + self.reconstructor.reconstructor_api.export_training_data( + output_path, input_product_index + ) else: logger.error(f'Unknown batch mode action "{action}"!') return -1 @@ -328,85 +281,5 @@ def batchModeExecute( return 0 @property - def detector(self) -> Detector: - return self._patternsCore.detector - - @property - def reconstructorPresenter(self) -> ReconstructorPresenter: - return self._reconstructorCore.presenter - - @property - def stxmSimulator(self) -> STXMSimulator: - return self._analysisCore.stxmSimulator - - @property - def stxmVisualizationEngine(self) -> VisualizationEngine: - return self._analysisCore.stxmVisualizationEngine - - @property - def probePropagator(self) -> ProbePropagator: - return self._analysisCore.probePropagator - - @property - def probePropagatorVisualizationEngine(self) -> VisualizationEngine: - return self._analysisCore.probePropagatorVisualizationEngine - - @property - def exposureAnalyzer(self) -> ExposureAnalyzer: - return self._analysisCore.exposureAnalyzer - - @property - def exposureVisualizationEngine(self) -> VisualizationEngine: - return self._analysisCore.exposureVisualizationEngine - - @property - def fourierRingCorrelator(self) -> FourierRingCorrelator: - return self._analysisCore.fourierRingCorrelator - - @property - def fluorescenceEnhancer(self) -> FluorescenceEnhancer: - return self._fluorescenceCore.enhancer - - @property - def fluorescenceVisualizationEngine(self) -> VisualizationEngine: - return self._fluorescenceCore.visualizationEngine - - @property - def xmcdAnalyzer(self) -> XMCDAnalyzer: - return self._analysisCore.xmcdAnalyzer - - @property - def xmcdVisualizationEngine(self) -> VisualizationEngine: - return self._analysisCore.xmcdVisualizationEngine - - @property - def areWorkflowsSupported(self) -> bool: - return self._workflowCore.areWorkflowsSupported - - @property - def workflowAuthorizationPresenter(self) -> WorkflowAuthorizationPresenter: - return self._workflowCore.authorizationPresenter - - @property - def workflowStatusPresenter(self) -> WorkflowStatusPresenter: - return self._workflowCore.statusPresenter - - @property - def workflowExecutionPresenter(self) -> WorkflowExecutionPresenter: - return self._workflowCore.executionPresenter - - @property - def workflowParametersPresenter(self) -> WorkflowParametersPresenter: - return self._workflowCore.parametersPresenter - - @property - def workflowAPI(self) -> WorkflowAPI: - return self._workflowCore.workflowAPI - - @property - def automationPresenter(self) -> AutomationPresenter: - return self._automationCore.presenter - - @property - def automationProcessingPresenter(self) -> AutomationProcessingPresenter: - return self._automationCore.processingPresenter + def workflow_api(self) -> WorkflowAPI: + return self.workflow.workflow_api diff --git a/src/ptychodus/model/fluorescence/core.py b/src/ptychodus/model/fluorescence/core.py index 02c396b6..4fd2994a 100644 --- a/src/ptychodus/model/fluorescence/core.py +++ b/src/ptychodus/model/fluorescence/core.py @@ -1,5 +1,5 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterator from pathlib import Path import logging @@ -31,191 +31,188 @@ class FluorescenceEnhancer(Observable, Observer): def __init__( self, settings: FluorescenceSettings, - productRepository: ProductRepository, - twoStepEnhancingAlgorithm: TwoStepFluorescenceEnhancingAlgorithm, - vspiEnhancingAlgorithm: VSPIFluorescenceEnhancingAlgorithm, - fileReaderChooser: PluginChooser[FluorescenceFileReader], - fileWriterChooser: PluginChooser[FluorescenceFileWriter], - reinitObservable: Observable, + product_repository: ProductRepository, + two_step_enhancing_algorithm: TwoStepFluorescenceEnhancingAlgorithm, + vspi_enhancing_algorithm: VSPIFluorescenceEnhancingAlgorithm, + file_reader_chooser: PluginChooser[FluorescenceFileReader], + file_writer_chooser: PluginChooser[FluorescenceFileWriter], + reinit_observable: Observable, ) -> None: super().__init__() self._settings = settings - self._productRepository = productRepository - self.twoStepEnhancingAlgorithm = twoStepEnhancingAlgorithm - self.vspiEnhancingAlgorithm = vspiEnhancingAlgorithm - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser - self._reinitObservable = reinitObservable - - self._algorithmChooser = PluginChooser[FluorescenceEnhancingAlgorithm]() - self._algorithmChooser.registerPlugin( - twoStepEnhancingAlgorithm, - simpleName=TwoStepFluorescenceEnhancingAlgorithm.SIMPLE_NAME, - displayName=TwoStepFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + self._product_repository = product_repository + self.two_step_enhancing_algorithm = two_step_enhancing_algorithm + self.vspi_enhancing_algorithm = vspi_enhancing_algorithm + self._file_reader_chooser = file_reader_chooser + self._file_writer_chooser = file_writer_chooser + self._reinit_observable = reinit_observable + + self._algorithm_chooser = PluginChooser[FluorescenceEnhancingAlgorithm]() + self._algorithm_chooser.register_plugin( + two_step_enhancing_algorithm, + simple_name=TwoStepFluorescenceEnhancingAlgorithm.SIMPLE_NAME, + display_name=TwoStepFluorescenceEnhancingAlgorithm.DISPLAY_NAME, ) - self._algorithmChooser.registerPlugin( - vspiEnhancingAlgorithm, - simpleName=VSPIFluorescenceEnhancingAlgorithm.SIMPLE_NAME, - displayName=VSPIFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + self._algorithm_chooser.register_plugin( + vspi_enhancing_algorithm, + simple_name=VSPIFluorescenceEnhancingAlgorithm.SIMPLE_NAME, + display_name=VSPIFluorescenceEnhancingAlgorithm.DISPLAY_NAME, ) - self._syncAlgorithmFromSettings() - self._algorithmChooser.addObserver(self) + self._algorithm_chooser.synchronize_with_parameter(settings.algorithm) + self._algorithm_chooser.add_observer(self) - self._productIndex = -1 + self._product_index = -1 self._measured: FluorescenceDataset | None = None self._enhanced: FluorescenceDataset | None = None - fileReaderChooser.setCurrentPluginByName(settings.fileType.getValue()) - fileWriterChooser.setCurrentPluginByName(settings.fileType.getValue()) - reinitObservable.addObserver(self) + file_reader_chooser.synchronize_with_parameter(settings.file_type) + file_writer_chooser.set_current_plugin(settings.file_type.get_value()) + reinit_observable.add_observer(self) @property def _product(self) -> ProductRepositoryItem: - return self._productRepository[self._productIndex] + return self._product_repository[self._product_index] - def setProduct(self, productIndex: int) -> None: - if self._productIndex != productIndex: - self._productIndex = productIndex + def set_product(self, product_index: int) -> None: + if self._product_index != product_index: + self._product_index = product_index self._enhanced = None - self.notifyObservers() + self.notify_observers() - def getProductName(self) -> str: - return self._product.getName() + def get_product_name(self) -> str: + return self._product.get_name() - def getPixelGeometry(self) -> PixelGeometry: - return self._product.getGeometry().getPixelGeometry() + def get_pixel_geometry(self) -> PixelGeometry: + return self._product.get_geometry().get_object_plane_pixel_geometry() - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() + def get_open_file_filters(self) -> Iterator[str]: + for plugin in self._file_reader_chooser: + yield plugin.display_name - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName + def get_open_file_filter(self) -> str: + return self._file_reader_chooser.get_current_plugin().display_name - def openMeasuredDataset(self, filePath: Path, fileFilter: str) -> None: - if filePath.is_file(): - self._fileReaderChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileReaderChooser.currentPlugin.simpleName - logger.debug(f'Reading "{filePath}" as "{fileType}"') - fileReader = self._fileReaderChooser.currentPlugin.strategy + def open_measured_dataset(self, file_path: Path, file_filter: str) -> None: + if file_path.is_file(): + self._file_reader_chooser.set_current_plugin(file_filter) + file_type = self._file_reader_chooser.get_current_plugin().simple_name + logger.debug(f'Reading "{file_path}" as "{file_type}"') + file_reader = self._file_reader_chooser.get_current_plugin().strategy try: - measured = fileReader.read(filePath) + measured = file_reader.read(file_path) except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc + raise RuntimeError(f'Failed to read "{file_path}"') from exc else: self._measured = measured self._enhanced = None - self._settings.filePath.setValue(filePath) - self._settings.fileType.setValue(fileType) + self._settings.file_path.set_value(file_path) - self.notifyObservers() + self.notify_observers() else: - logger.warning(f'Refusing to load dataset from invalid file path "{filePath}"') + logger.warning(f'Refusing to load dataset from invalid file path "{file_path}"') - def getNumberOfChannels(self) -> int: + def get_num_channels(self) -> int: return 0 if self._measured is None else len(self._measured.element_maps) - def getMeasuredElementMap(self, channelIndex: int) -> ElementMap: + def get_measured_element_map(self, channel_index: int) -> ElementMap: if self._measured is None: raise ValueError('Fluorescence dataset not loaded!') - return self._measured.element_maps[channelIndex] + return self._measured.element_maps[channel_index] - def getAlgorithmList(self) -> Sequence[str]: - return self._algorithmChooser.getDisplayNameList() + def algorithms(self) -> Iterator[str]: + for plugin in self._algorithm_chooser: + yield plugin.display_name - def getAlgorithm(self) -> str: - return self._algorithmChooser.currentPlugin.displayName + def get_algorithm(self) -> str: + return self._algorithm_chooser.get_current_plugin().display_name - def setAlgorithm(self, name: str) -> None: - self._algorithmChooser.setCurrentPluginByName(name) - self._settings.algorithm.setValue(self._algorithmChooser.currentPlugin.simpleName) + def set_algorithm(self, name: str) -> None: + self._algorithm_chooser.set_current_plugin(name) - def _syncAlgorithmFromSettings(self) -> None: - self.setAlgorithm(self._settings.algorithm.getValue()) - - def enhanceFluorescence(self) -> None: + def enhance_fluorescence(self) -> None: if self._measured is None: raise ValueError('Fluorescence dataset not loaded!') else: - algorithm = self._algorithmChooser.currentPlugin.strategy - product = self._product.getProduct() + algorithm = self._algorithm_chooser.get_current_plugin().strategy + product = self._product.get_product() self._enhanced = algorithm.enhance(self._measured, product) - self.notifyObservers() + self.notify_observers() - def getEnhancedElementMap(self, channelIndex: int) -> ElementMap: + def get_enhanced_element_map(self, channel_index: int) -> ElementMap: if self._enhanced is None: - return self.getMeasuredElementMap(channelIndex) + return self.get_measured_element_map(channel_index) - return self._enhanced.element_maps[channelIndex] + return self._enhanced.element_maps[channel_index] - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() + def get_save_file_filters(self) -> Iterator[str]: + for plugin in self._file_writer_chooser: + yield plugin.display_name - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName + def get_save_file_filter(self) -> str: + return self._file_writer_chooser.get_current_plugin().display_name - def saveEnhancedDataset(self, filePath: Path, fileFilter: str) -> None: + def save_enhanced_dataset(self, file_path: Path, file_filter: str) -> None: if self._enhanced is None: raise ValueError('Fluorescence dataset not enhanced!') - self._fileWriterChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - writer = self._fileWriterChooser.currentPlugin.strategy - writer.write(filePath, self._enhanced) + self._file_writer_chooser.set_current_plugin(file_filter) + file_type = self._file_writer_chooser.get_current_plugin().simple_name + logger.debug(f'Writing "{file_path}" as "{file_type}"') + writer = self._file_writer_chooser.get_current_plugin().strategy + writer.write(file_path, self._enhanced) - def _openFluorescenceFileFromSettings(self) -> None: - self.openMeasuredDataset( - self._settings.filePath.getValue(), self._settings.fileType.getValue() + def _open_fluorescence_file_from_settings(self) -> None: + self.open_measured_dataset( + self._settings.file_path.get_value(), self._settings.file_type.get_value() ) - def update(self, observable: Observable) -> None: - if observable is self._algorithmChooser: - self.notifyObservers() - elif observable is self._reinitObservable: - self._syncAlgorithmFromSettings() - self._openFluorescenceFileFromSettings() + def _update(self, observable: Observable) -> None: + if observable is self._algorithm_chooser: + self.notify_observers() + elif observable is self._reinit_observable: + self._open_fluorescence_file_from_settings() class FluorescenceCore: def __init__( self, - settingsRegistry: SettingsRegistry, - productRepository: ProductRepository, - upscalingStrategyChooser: PluginChooser[UpscalingStrategy], - deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], - fileReaderChooser: PluginChooser[FluorescenceFileReader], - fileWriterChooser: PluginChooser[FluorescenceFileWriter], + settings_registry: SettingsRegistry, + product_repository: ProductRepository, + upscaling_strategy_chooser: PluginChooser[UpscalingStrategy], + deconvolution_strategy_chooser: PluginChooser[DeconvolutionStrategy], + file_reader_chooser: PluginChooser[FluorescenceFileReader], + file_writer_chooser: PluginChooser[FluorescenceFileWriter], ) -> None: - self._settings = FluorescenceSettings(settingsRegistry) - self._twoStepEnhancingAlgorithm = TwoStepFluorescenceEnhancingAlgorithm( - self._settings, upscalingStrategyChooser, deconvolutionStrategyChooser, settingsRegistry + self._settings = FluorescenceSettings(settings_registry) + self._two_step_enhancing_algorithm = TwoStepFluorescenceEnhancingAlgorithm( + self._settings, upscaling_strategy_chooser, deconvolution_strategy_chooser ) - self._vspiEnhancingAlgorithm = VSPIFluorescenceEnhancingAlgorithm(self._settings) + self._vspi_enhancing_algorithm = VSPIFluorescenceEnhancingAlgorithm(self._settings) self.enhancer = FluorescenceEnhancer( self._settings, - productRepository, - self._twoStepEnhancingAlgorithm, - self._vspiEnhancingAlgorithm, - fileReaderChooser, - fileWriterChooser, - settingsRegistry, + product_repository, + self._two_step_enhancing_algorithm, + self._vspi_enhancing_algorithm, + file_reader_chooser, + file_writer_chooser, + settings_registry, ) - self.visualizationEngine = VisualizationEngine(isComplex=False) + self.visualization_engine = VisualizationEngine(is_complex=False) - def enhanceFluorescence( - self, productIndex: int, inputFilePath: Path, outputFilePath: Path + def enhance_fluorescence( + self, product_index: int, input_file_path: Path, output_file_path: Path ) -> int: - fileType = 'XRF-Maps' + file_type = 'XRF-Maps' try: - self.enhancer.setProduct(productIndex) - self.enhancer.openMeasuredDataset(inputFilePath, fileType) - self.enhancer.enhanceFluorescence() - self.enhancer.saveEnhancedDataset(outputFilePath, fileType) + self.enhancer.set_product(product_index) + self.enhancer.open_measured_dataset(input_file_path, file_type) + self.enhancer.enhance_fluorescence() + self.enhancer.save_enhanced_dataset(output_file_path, file_type) except Exception as exc: logger.exception(exc) return -1 diff --git a/src/ptychodus/model/fluorescence/settings.py b/src/ptychodus/model/fluorescence/settings.py index 71540f6b..ae285a1b 100644 --- a/src/ptychodus/model/fluorescence/settings.py +++ b/src/ptychodus/model/fluorescence/settings.py @@ -7,27 +7,23 @@ class FluorescenceSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Fluorescence') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Fluorescence') + self._group.add_observer(self) - self.filePath = self._settingsGroup.createPathParameter( - 'FilePath', Path('/path/to/dataset.h5') - ) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'XRF-Maps') - self.algorithm = self._settingsGroup.createStringParameter('Algorithm', 'VSPI') - self.vspiDampingFactor = self._settingsGroup.createRealParameter( + self.file_path = self._group.create_path_parameter('FilePath', Path('/path/to/dataset.h5')) + self.file_type = self._group.create_string_parameter('FileType', 'XRF-Maps') + self.algorithm = self._group.create_string_parameter('Algorithm', 'VSPI') + self.vspi_damping_factor = self._group.create_real_parameter( 'VSPIDampingFactor', 0.0, minimum=0.0 ) - self.vspiMaxIterations = self._settingsGroup.createIntegerParameter( + self.vspi_max_iterations = self._group.create_integer_parameter( 'VSPIMaxIterations', 100, minimum=1 ) - self.upscalingStrategy = self._settingsGroup.createStringParameter( - 'UpscalingStrategy', 'Linear' - ) - self.deconvolutionStrategy = self._settingsGroup.createStringParameter( + self.upscaling_strategy = self._group.create_string_parameter('UpscalingStrategy', 'Linear') + self.deconvolution_strategy = self._group.create_string_parameter( 'DeconvolutionStrategy', 'Richardson-Lucy' ) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/fluorescence/two_step.py b/src/ptychodus/model/fluorescence/two_step.py index 28c08e98..47e633d8 100644 --- a/src/ptychodus/model/fluorescence/two_step.py +++ b/src/ptychodus/model/fluorescence/two_step.py @@ -1,5 +1,5 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterator from typing import Final import logging import time @@ -15,6 +15,7 @@ from ptychodus.api.plugins import PluginChooser from ptychodus.api.product import Product +from ..analysis import BarycentricArrayInterpolator from .settings import FluorescenceSettings logger = logging.getLogger(__name__) @@ -31,27 +32,22 @@ class TwoStepFluorescenceEnhancingAlgorithm(FluorescenceEnhancingAlgorithm, Obse def __init__( self, settings: FluorescenceSettings, - upscalingStrategyChooser: PluginChooser[UpscalingStrategy], - deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], - reinitObservable: Observable, + upscaling_strategy_chooser: PluginChooser[UpscalingStrategy], + deconvolution_strategy_chooser: PluginChooser[DeconvolutionStrategy], ) -> None: super().__init__() - self._settings = settings - self._upscalingStrategyChooser = upscalingStrategyChooser - self._deconvolutionStrategyChooser = deconvolutionStrategyChooser - self._reinitObservable = reinitObservable + self._upscaling_strategy_chooser = upscaling_strategy_chooser + self._deconvolution_strategy_chooser = deconvolution_strategy_chooser - self._syncUpscalingStrategyFromSettings() - upscalingStrategyChooser.addObserver(self) + upscaling_strategy_chooser.synchronize_with_parameter(settings.upscaling_strategy) + upscaling_strategy_chooser.add_observer(self) - self._syncDeconvolutionStrategyFromSettings() - deconvolutionStrategyChooser.addObserver(self) - - reinitObservable.addObserver(self) + deconvolution_strategy_chooser.synchronize_with_parameter(settings.deconvolution_strategy) + deconvolution_strategy_chooser.add_observer(self) def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: - upscaler = self._upscalingStrategyChooser.currentPlugin.strategy - deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy + upscaler = self._upscaling_strategy_chooser.get_current_plugin().strategy + deconvolver = self._deconvolution_strategy_chooser.get_current_plugin().strategy element_maps: list[ElementMap] = list() for emap in dataset.element_maps: @@ -70,41 +66,28 @@ def enhance(self, dataset: FluorescenceDataset, product: Product) -> Fluorescenc channel_names_path=dataset.channel_names_path, ) - def getUpscalingStrategyList(self) -> Sequence[str]: - return self._upscalingStrategyChooser.getDisplayNameList() - - def getUpscalingStrategy(self) -> str: - return self._upscalingStrategyChooser.currentPlugin.displayName + def get_upscaling_strategies(self) -> Iterator[str]: + for plugin in self._upscaling_strategy_chooser: + yield plugin.display_name - def setUpscalingStrategy(self, name: str) -> None: - self._upscalingStrategyChooser.setCurrentPluginByName(name) - self._settings.upscalingStrategy.setValue( - self._upscalingStrategyChooser.currentPlugin.simpleName - ) + def get_upscaling_strategy(self) -> str: + return self._upscaling_strategy_chooser.get_current_plugin().display_name - def _syncUpscalingStrategyFromSettings(self) -> None: - self.setUpscalingStrategy(self._settings.upscalingStrategy.getValue()) + def set_upscaling_strategy(self, name: str) -> None: + self._upscaling_strategy_chooser.set_current_plugin(name) - def getDeconvolutionStrategyList(self) -> Sequence[str]: - return self._deconvolutionStrategyChooser.getDisplayNameList() + def get_deconvolution_strategies(self) -> Iterator[str]: + for plugin in self._deconvolution_strategy_chooser: + yield plugin.display_name - def getDeconvolutionStrategy(self) -> str: - return self._deconvolutionStrategyChooser.currentPlugin.displayName + def get_deconvolution_strategy(self) -> str: + return self._deconvolution_strategy_chooser.get_current_plugin().display_name - def setDeconvolutionStrategy(self, name: str) -> None: - self._deconvolutionStrategyChooser.setCurrentPluginByName(name) - self._settings.deconvolutionStrategy.setValue( - self._deconvolutionStrategyChooser.currentPlugin.simpleName - ) + def set_deconvolution_strategy(self, name: str) -> None: + self._deconvolution_strategy_chooser.set_current_plugin(name) - def _syncDeconvolutionStrategyFromSettings(self) -> None: - self.setDeconvolutionStrategy(self._settings.deconvolutionStrategy.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._reinitObservable: - self._syncUpscalingStrategyFromSettings() - self._syncDeconvolutionStrategyFromSettings() - elif observable is self._upscalingStrategyChooser: - self.notifyObservers() - elif observable is self._deconvolutionStrategyChooser: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._upscaling_strategy_chooser: + self.notify_observers() + elif observable is self._deconvolution_strategy_chooser: + self.notify_observers() diff --git a/src/ptychodus/model/fluorescence/vspi.py b/src/ptychodus/model/fluorescence/vspi.py index 15247c66..148e025f 100644 --- a/src/ptychodus/model/fluorescence/vspi.py +++ b/src/ptychodus/model/fluorescence/vspi.py @@ -28,8 +28,8 @@ class ArrayPatchInterpolator: def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, ...]) -> None: # top left corner of object support - xmin = point.positionXInPixels - shape[-1] / 2 - ymin = point.positionYInPixels - shape[-2] / 2 + xmin = point.position_x_px - shape[-1] / 2 + ymin = point.position_y_px - shape[-2] / 2 # whole components (pixel indexes) xmin_wh = int(xmin) @@ -81,40 +81,42 @@ def __init__(self, product: Product) -> None: A[M,N] * X[N,P] = B[M,P] """ - object_geometry = product.object_.getGeometry() - M = len(product.scan) - N = object_geometry.heightInPixels * object_geometry.widthInPixels + object_geometry = product.object_.get_geometry() + M = len(product.positions) # noqa: N806 + N = object_geometry.height_px * object_geometry.width_px # noqa: N806 super().__init__(float, (M, N)) self._product = product - def _get_psf(self) -> RealArrayType: - intensity = self._product.probe.getIntensity() - return intensity / intensity.sum() - - def _matvec(self, X: RealArrayType) -> RealArrayType: - object_geometry = self._product.object_.getGeometry() - object_array = X.reshape((object_geometry.heightInPixels, object_geometry.widthInPixels)) - psf = self._get_psf() - AX = numpy.zeros(len(self._product.scan)) - - for index, scan_point in enumerate(self._product.scan): - object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + def _matvec(self, x: RealArrayType) -> RealArrayType: # noqa: N803 + object_geometry = self._product.object_.get_geometry() + object_array = x.reshape((object_geometry.height_px, object_geometry.width_px)) + AX = numpy.zeros(len(self._product.positions)) # noqa: N806 + + for index, (scan_point, probe) in enumerate( + zip(self._product.positions, self._product.probes) + ): + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + probe_intensity = probe.get_intensity() + psf = probe_intensity / probe_intensity.sum() interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) AX[index] = numpy.sum(psf * interpolator.get_patch()) return AX - def _rmatvec(self, X: RealArrayType) -> RealArrayType: - object_geometry = self._product.object_.getGeometry() - object_array = numpy.zeros((object_geometry.heightInPixels, object_geometry.widthInPixels)) - psf = self._get_psf() + def _rmatvec(self, x: RealArrayType) -> RealArrayType: # noqa: N803 + object_geometry = self._product.object_.get_geometry() + object_array = numpy.zeros((object_geometry.height_px, object_geometry.width_px)) - for index, scan_point in enumerate(self._product.scan): - object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + for index, (scan_point, probe) in enumerate( + zip(self._product.positions, self._product.probes) + ): + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + probe_intensity = probe.get_intensity() + psf = probe_intensity / probe_intensity.sum() interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) - interpolator.accumulate_patch(X[index] * psf) + interpolator.accumulate_patch(x[index] * psf) - HX = object_array.flatten() + HX = object_array.flatten() # noqa: N806 return HX @@ -127,14 +129,14 @@ def __init__(self, settings: FluorescenceSettings) -> None: super().__init__() self._settings = settings - settings.vspiDampingFactor.addObserver(self) - settings.vspiMaxIterations.addObserver(self) + settings.vspi_damping_factor.add_observer(self) + settings.vspi_max_iterations.add_observer(self) def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: - object_geometry = product.object_.getGeometry() - e_cps_shape = object_geometry.heightInPixels, object_geometry.widthInPixels + object_geometry = product.object_.get_geometry() + e_cps_shape = object_geometry.height_px, object_geometry.width_px element_maps: list[ElementMap] = list() - A = VSPILinearOperator(product) + A = VSPILinearOperator(product) # noqa: N806 for emap in dataset.element_maps: logger.info(f'Enhancing "{emap.name}"...') @@ -143,8 +145,8 @@ def enhance(self, dataset: FluorescenceDataset, product: Product) -> Fluorescenc result = lsmr( A, m_cps.flatten(), - damp=self._settings.vspiDampingFactor.getValue(), - maxiter=self._settings.vspiMaxIterations.getValue(), + damp=self._settings.vspi_damping_factor.get_value(), + maxiter=self._settings.vspi_max_iterations.get_value(), show=True, ) logger.debug(result) @@ -161,20 +163,20 @@ def enhance(self, dataset: FluorescenceDataset, product: Product) -> Fluorescenc channel_names_path=dataset.channel_names_path, ) - def getDampingFactor(self) -> float: - return self._settings.vspiDampingFactor.getValue() + def get_damping_factor(self) -> float: + return self._settings.vspi_damping_factor.get_value() - def setDampingFactor(self, factor: float) -> None: - self._settings.vspiDampingFactor.setValue(factor) + def set_damping_factor(self, factor: float) -> None: + self._settings.vspi_damping_factor.set_value(factor) - def getMaxIterations(self) -> int: - return self._settings.vspiMaxIterations.getValue() + def get_max_iterations(self) -> int: + return self._settings.vspi_max_iterations.get_value() - def setMaxIterations(self, number: int) -> None: - self._settings.vspiMaxIterations.setValue(number) + def set_max_iterations(self, number: int) -> None: + self._settings.vspi_max_iterations.set_value(number) - def update(self, observable: Observable) -> None: - if observable is self._settings.vspiDampingFactor: - self.notifyObservers() - elif observable is self._settings.vspiMaxIterations: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._settings.vspi_damping_factor: + self.notify_observers() + elif observable is self._settings.vspi_max_iterations: + self.notify_observers() diff --git a/src/ptychodus/model/memory.py b/src/ptychodus/model/memory.py index b58a169d..06246354 100644 --- a/src/ptychodus/model/memory.py +++ b/src/ptychodus/model/memory.py @@ -5,17 +5,16 @@ @dataclass(frozen=True) class MemoryStatistics: - totalMemoryInBytes: int - availableMemoryInBytes: int - memoryUsagePercent: float + total_physical_memory_bytes: int + available_memory_bytes: int + percent_usage: float class MemoryPresenter: - def getStatistics(self) -> MemoryStatistics: + def get_statistics(self) -> MemoryStatistics: mem = psutil.virtual_memory() - stats = MemoryStatistics( - totalMemoryInBytes=mem.total, - availableMemoryInBytes=mem.available, - memoryUsagePercent=mem.percent, + return MemoryStatistics( + total_physical_memory_bytes=mem.total, + available_memory_bytes=mem.available, + percent_usage=mem.percent, ) - return stats diff --git a/src/ptychodus/model/metadata.py b/src/ptychodus/model/metadata.py new file mode 100644 index 00000000..1ffda1ab --- /dev/null +++ b/src/ptychodus/model/metadata.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from ptychodus.api.observer import Observable +from ptychodus.api.patterns import DiffractionMetadata + +from .patterns import ( + DetectorSettings, + DiffractionDatasetObserver, + AssembledDiffractionDataset, + PatternSettings, +) +from .product import ProductSettings + + +class MetadataPresenter(Observable, DiffractionDatasetObserver): + def __init__( + self, + detector_settings: DetectorSettings, + pattern_settings: PatternSettings, + dataset: AssembledDiffractionDataset, + product_settings: ProductSettings, + ) -> None: + super().__init__() + self._detector_settings = detector_settings + self._pattern_settings = pattern_settings + self._dataset = dataset + self._product_settings = product_settings + + dataset.add_observer(self) + + @property + def _metadata(self) -> DiffractionMetadata: + return self._dataset.get_metadata() + + def can_sync_detector_extent(self) -> bool: + return self._metadata.detector_extent is not None + + def sync_detector_extent(self) -> None: + detector_extent = self._metadata.detector_extent + + if detector_extent: + self._detector_settings.width_px.set_value(detector_extent.width_px) + self._detector_settings.height_px.set_value(detector_extent.height_px) + + def can_sync_detector_pixel_size(self) -> bool: + return self._metadata.detector_pixel_geometry is not None + + def sync_detector_pixel_size(self) -> None: + pixel_geometry = self._metadata.detector_pixel_geometry + + if pixel_geometry: + self._detector_settings.pixel_width_m.set_value(pixel_geometry.width_m) + self._detector_settings.pixel_height_m.set_value(pixel_geometry.height_m) + + def can_sync_detector_bit_depth(self) -> bool: + return self._metadata.detector_bit_depth is not None + + def sync_detector_bit_depth(self) -> None: + bit_depth = self._metadata.detector_bit_depth + + if bit_depth: + self._detector_settings.bit_depth.set_value(bit_depth) + + def can_sync_pattern_crop_center(self) -> bool: + return self._metadata.crop_center is not None or self._metadata.detector_extent is not None + + def can_sync_pattern_crop_extent(self) -> bool: + return self._metadata.detector_extent is not None + + def sync_pattern_crop(self, sync_center: bool, sync_extent: bool) -> None: + if sync_center: + crop_center = self._metadata.crop_center + + if crop_center: + self._pattern_settings.crop_center_x_px.set_value(crop_center.position_x_px) + self._pattern_settings.crop_center_y_px.set_value(crop_center.position_y_px) + elif self._metadata.detector_extent: + self._pattern_settings.crop_center_x_px.set_value( + int(self._metadata.detector_extent.width_px) // 2 + ) + self._pattern_settings.crop_center_y_px.set_value( + int(self._metadata.detector_extent.height_px) // 2 + ) + + if sync_extent and self._metadata.detector_extent: + center_x = self._pattern_settings.crop_center_x_px.get_value() + center_y = self._pattern_settings.crop_center_y_px.get_value() + + extent_x = int(self._metadata.detector_extent.width_px) + extent_y = int(self._metadata.detector_extent.height_px) + + max_radius_x = min(center_x, extent_x - center_x) + max_radius_y = min(center_y, extent_y - center_y) + max_radius = min(max_radius_x, max_radius_y) + crop_diameter = 1 + + while crop_diameter < max_radius: + crop_diameter <<= 1 + + self._pattern_settings.crop_width_px.set_value(crop_diameter) + self._pattern_settings.crop_height_px.set_value(crop_diameter) + + def can_sync_probe_photon_count(self) -> bool: + return self._metadata.probe_photon_count is not None + + def sync_probe_photon_count(self) -> None: + photon_count = self._metadata.probe_photon_count + + if photon_count: + self._product_settings.probe_photon_count.set_value(photon_count) + + def can_sync_probe_energy(self) -> bool: + return self._metadata.probe_energy_eV is not None + + def sync_probe_energy(self) -> None: + energy_eV = self._metadata.probe_energy_eV # noqa: N806 + + if energy_eV: + self._product_settings.probe_energy_eV.set_value(energy_eV) + + def can_sync_detector_distance(self) -> bool: + return self._metadata.detector_distance_m is not None + + def sync_detector_distance(self) -> None: + distance_m = self._metadata.detector_distance_m + + if distance_m: + self._product_settings.detector_distance_m.set_value(distance_m) + + def handle_array_inserted(self, index: int) -> None: + pass + + def handle_array_changed(self, index: int) -> None: + pass + + def handle_dataset_reloaded(self) -> None: + self.notify_observers() diff --git a/src/ptychodus/model/patterns/__init__.py b/src/ptychodus/model/patterns/__init__.py index 6b5fcd2c..bc87bb8d 100644 --- a/src/ptychodus/model/patterns/__init__.py +++ b/src/ptychodus/model/patterns/__init__.py @@ -1,28 +1,21 @@ -from .active import ActiveDiffractionDataset -from .api import PatternsAPI -from .core import ( - DiffractionDatasetPresenter, - DiffractionPatternArrayPresenter, - PatternsCore, +from .api import PatternsAPI, PatternsStreamingContext +from .core import PatternsCore +from .dataset import ( + AssembledDiffractionDataset, + AssembledDiffractionPatternArray, + DiffractionDatasetObserver, ) -from .detector import Detector -from .io import DiffractionDatasetInputOutputPresenter -from .metadata import DiffractionMetadataPresenter -from .patterns import DiffractionPatternPresenter -from .settings import PatternSettings, ProductSettings +from .settings import DetectorSettings, PatternSettings from .sizer import PatternSizer __all__ = [ - 'ActiveDiffractionDataset', - 'Detector', - 'DiffractionDatasetInputOutputPresenter', - 'DiffractionDatasetPresenter', - 'DiffractionMetadataPresenter', - 'DiffractionPatternArrayPresenter', - 'DiffractionPatternPresenter', + 'AssembledDiffractionDataset', + 'AssembledDiffractionPatternArray', + 'DetectorSettings', + 'DiffractionDatasetObserver', 'PatternSettings', 'PatternSizer', 'PatternsAPI', 'PatternsCore', - 'ProductSettings', + 'PatternsStreamingContext', ] diff --git a/src/ptychodus/model/patterns/active.py b/src/ptychodus/model/patterns/active.py deleted file mode 100644 index b43cdaa2..00000000 --- a/src/ptychodus/model/patterns/active.py +++ /dev/null @@ -1,203 +0,0 @@ -from __future__ import annotations -from collections.abc import Sequence -from typing import overload -import logging -import tempfile -import threading - -import numpy -import numpy.typing - -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.patterns import ( - BooleanArrayType, - DiffractionDataset, - DiffractionMetadata, - DiffractionPatternArray, - DiffractionPatternArrayType, - DiffractionPatternIndexes, - DiffractionPatternState, - SimpleDiffractionPatternArray, -) -from ptychodus.api.tree import SimpleTreeNode - -from .settings import PatternSettings -from .sizer import PatternSizer - -__all__ = [ - 'ActiveDiffractionDataset', -] - -logger = logging.getLogger(__name__) - - -class ActiveDiffractionDataset(DiffractionDataset): - def __init__(self, settings: PatternSettings, diffractionPatternSizer: PatternSizer) -> None: - super().__init__() - self._settings = settings - self._diffractionPatternSizer = diffractionPatternSizer - - self._metadata = DiffractionMetadata.createNullInstance() - self._contentsTree = SimpleTreeNode.createRoot(list()) - self._arrayListLock = threading.RLock() - self._arrayList: list[DiffractionPatternArray] = list() - self._arrayData: DiffractionPatternArrayType = numpy.zeros((0, 0, 0), dtype=numpy.uint16) - self._changedEvent = threading.Event() - - def getMetadata(self) -> DiffractionMetadata: - return self._metadata - - def getContentsTree(self) -> SimpleTreeNode: - return self._contentsTree - - def getInfoText(self) -> str: - filePath = self._metadata.filePath - label = filePath.stem if filePath else 'None' - number, height, width = self._arrayData.shape - dtype = str(self._arrayData.dtype) - sizeInMB = self._arrayData.nbytes / (1024 * 1024) - return f'{label}: {number} x {width}W x {height}H {dtype} [{sizeInMB:.2f}MB]' - - @overload - def __getitem__(self, index: int) -> DiffractionPatternArray: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[DiffractionPatternArray]: ... - - def __getitem__( - self, index: int | slice - ) -> DiffractionPatternArray | Sequence[DiffractionPatternArray]: - with self._arrayListLock: - return self._arrayList[index] - - def __len__(self) -> int: - with self._arrayListLock: - return len(self._arrayList) - - def reset(self, metadata: DiffractionMetadata, contentsTree: SimpleTreeNode) -> None: - with self._arrayListLock: - self._metadata = metadata - self._contentsTree = contentsTree - self._arrayList.clear() - - self._changedEvent.set() - - def realloc(self) -> None: - shape = ( - self._metadata.numberOfPatternsTotal, - self._diffractionPatternSizer.getHeightInPixels(), - self._diffractionPatternSizer.getWidthInPixels(), - ) - - with self._arrayListLock: - self._arrayList.clear() - - if self._settings.memmapEnabled.getValue(): - scratchDirectory = self._settings.scratchDirectory.getValue() - scratchDirectory.mkdir(mode=0o755, parents=True, exist_ok=True) - npyTempFile = tempfile.NamedTemporaryFile(dir=scratchDirectory, suffix='.npy') - logger.debug(f'Scratch data file {npyTempFile.name} is {shape}') - self._arrayData = numpy.memmap( - npyTempFile, dtype=self._metadata.patternDataType, shape=shape - ) - self._arrayData[:] = 0 - else: - logger.debug(f'Scratch memory is {shape}') - self._arrayData = numpy.zeros(shape, dtype=self._metadata.patternDataType) - - self._changedEvent.set() - - def insertArray(self, array: DiffractionPatternArray) -> None: - if array.getState() == DiffractionPatternState.LOADED: - data = self._diffractionPatternSizer(array.getData()) - - if self._settings.valueUpperBoundEnabled.getValue(): - valueLowerBound = self._settings.valueLowerBound.getValue() - valueUpperBound = self._settings.valueUpperBound.getValue() - data[data >= valueUpperBound] = 0 - - if self._settings.valueLowerBoundEnabled.getValue(): - valueLowerBound = self._settings.valueLowerBound.getValue() - data[data < valueLowerBound] = 0 - - if self._settings.flipXEnabled.getValue(): - data = numpy.flip(data, axis=-1) - - if self._settings.flipYEnabled.getValue(): - data = numpy.flip(data, axis=-2) - - offset = self._metadata.numberOfPatternsPerArray * array.getIndex() - sliceZ = slice(offset, offset + data.shape[0]) - dataView = self._arrayData[sliceZ, :, :] - dataView[:] = data - dataView.flags.writeable = False - - array = SimpleDiffractionPatternArray( - array.getLabel(), array.getIndex(), dataView, array.getState() - ) - - with self._arrayListLock: - self._arrayList.append(array) - self._arrayList.sort(key=lambda arr: arr.getIndex()) - - self._changedEvent.set() - - def getGoodPixelMask(self) -> BooleanArrayType: # FIXME - return numpy.full( - ( - self._diffractionPatternSizer.getHeightInPixels(), - self._diffractionPatternSizer.getWidthInPixels(), - ), - True, - ) - - def getAssembledIndexes(self) -> Sequence[int]: - indexes: list[int] = list() - - with self._arrayListLock: - for array in self._arrayList: - if array.getState() == DiffractionPatternState.LOADED: - offset = self._metadata.numberOfPatternsPerArray * array.getIndex() - size = array.getNumberOfPatterns() - indexes.extend(range(offset, offset + size)) - - return indexes - - def getAssembledData(self) -> DiffractionPatternArrayType: - indexes = self.getAssembledIndexes() - return self._arrayData[indexes] - - def setAssembledData( - self, - arrayData: DiffractionPatternArrayType, - arrayIndexes: DiffractionPatternIndexes, - ) -> None: - with self._arrayListLock: - numberOfPatterns, detectorHeight, detectorWidth = arrayData.shape - - self._metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatterns, - numberOfPatternsTotal=numberOfPatterns, - patternDataType=arrayData.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - ) - - self._contentsTree = SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) - - # TODO use arrayIndexes - self._arrayList = [ - SimpleDiffractionPatternArray( - label='Processed', - index=0, - data=arrayData[...], - state=DiffractionPatternState.LOADED, - ), - ] - self._arrayData = arrayData - - self.notifyObservers() - - def notifyObserversIfDatasetChanged(self) -> None: - if self._changedEvent.is_set(): - self._changedEvent.clear() - self.notifyObservers() diff --git a/src/ptychodus/model/patterns/api.py b/src/ptychodus/model/patterns/api.py index 11f00fe5..313f3a56 100644 --- a/src/ptychodus/model/patterns/api.py +++ b/src/ptychodus/model/patterns/api.py @@ -1,9 +1,6 @@ -from collections.abc import Sequence from pathlib import Path -from typing import Any import logging -import numpy from ptychodus.api.geometry import ImageExtent from ptychodus.api.patterns import ( @@ -17,127 +14,121 @@ from ptychodus.api.plugins import PluginChooser from ptychodus.api.tree import SimpleTreeNode -from .active import ActiveDiffractionDataset -from .builder import ActiveDiffractionDatasetBuilder -from .settings import PatternSettings +from .dataset import AssembledDiffractionDataset +from .settings import DetectorSettings, PatternSettings logger = logging.getLogger(__name__) -class PatternsAPI: - def __init__( - self, - settings: PatternSettings, - builder: ActiveDiffractionDatasetBuilder, - dataset: ActiveDiffractionDataset, - fileReaderChooser: PluginChooser[DiffractionFileReader], - fileWriterChooser: PluginChooser[DiffractionFileWriter], - ) -> None: - super().__init__() - self._settings = settings - self._builder = builder +class PatternsStreamingContext: + def __init__(self, dataset: AssembledDiffractionDataset, metadata: DiffractionMetadata) -> None: self._dataset = dataset - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser + self._metadata = metadata + + def start(self) -> None: + contents_tree = SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + stream_dataset = SimpleDiffractionDataset(self._metadata, contents_tree, []) + self._dataset.reload(stream_dataset) + self._dataset.start_loading() - def initializeStreaming(self, metadata: DiffractionMetadata) -> None: - contentsTree = SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) - arrayList: list[DiffractionPatternArray] = list() - dataset = SimpleDiffractionDataset(metadata, contentsTree, arrayList) - self._builder.switchTo(dataset) + def append_array(self, array: DiffractionPatternArray) -> None: + self._dataset.append_array(array) - def startAssemblingDiffractionPatterns(self) -> None: - self._builder.start() + def get_queue_size(self) -> int: + return self._dataset.queue_size - def assemble(self, array: DiffractionPatternArray) -> None: - self._builder.insertArray(array) + def stop(self) -> None: + self._dataset.finish_loading(block=True) + self._dataset.assemble_patterns() - def getAssemblyQueueSize(self) -> int: - return self._builder.getAssemblyQueueSize() - def stopAssemblingDiffractionPatterns(self, finishAssembling: bool) -> None: - self._builder.stop(finishAssembling) +class PatternsAPI: + def __init__( + self, + pattern_settings: PatternSettings, + detector_settings: DetectorSettings, + dataset: AssembledDiffractionDataset, + file_reader_chooser: PluginChooser[DiffractionFileReader], + file_writer_chooser: PluginChooser[DiffractionFileWriter], + ) -> None: + super().__init__() + self._pattern_settings = pattern_settings + self._detector_settings = detector_settings + self._dataset = dataset + self._file_reader_chooser = file_reader_chooser + self._file_writer_chooser = file_writer_chooser - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() + def create_streaming_context(self, metadata: DiffractionMetadata) -> PatternsStreamingContext: + return PatternsStreamingContext(self._dataset, metadata) - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName + def get_file_reader_chooser(self) -> PluginChooser[DiffractionFileReader]: + return self._file_reader_chooser - def openPatterns( + def open_patterns( self, - filePath: Path, + file_path: Path, *, - fileType: str | None = None, - cropCenter: CropCenter | None = None, - cropExtent: ImageExtent | None = None, - assemble: bool = True, - ) -> str | None: - if cropCenter is not None: - self._settings.cropCenterXInPixels.setValue(cropCenter.positionXInPixels) - self._settings.cropCenterYInPixels.setValue(cropCenter.positionYInPixels) - - if cropExtent is not None: - self._settings.cropWidthInPixels.setValue(cropExtent.widthInPixels) - self._settings.cropHeightInPixels.setValue(cropExtent.heightInPixels) - - fileType_ = self._settings.fileType.getValue() if fileType is None else fileType - self._fileReaderChooser.setCurrentPluginByName(fileType_) - - if filePath.is_file(): - fileReader = self._fileReaderChooser.currentPlugin.strategy - fileType = self._fileReaderChooser.currentPlugin.simpleName - logger.debug(f'Reading "{filePath}" as "{fileType}"') + file_type: str | None = None, + crop_center: CropCenter | None = None, + crop_extent: ImageExtent | None = None, + detector_extent: ImageExtent | None = None, + ) -> int: + if crop_center is not None: + self._pattern_settings.crop_center_x_px.set_value(crop_center.position_x_px) + self._pattern_settings.crop_center_y_px.set_value(crop_center.position_y_px) + + if crop_extent is not None: + self._pattern_settings.crop_width_px.set_value(crop_extent.width_px) + self._pattern_settings.crop_height_px.set_value(crop_extent.height_px) + + if detector_extent is not None: + self._detector_settings.width_px.set_value(detector_extent.width_px) + self._detector_settings.height_px.set_value(detector_extent.height_px) + + if file_path.is_file(): + if file_type is not None: + self._file_reader_chooser.set_current_plugin(file_type) + + file_type = self._file_reader_chooser.get_current_plugin().simple_name + logger.debug(f'Reading "{file_path}" as "{file_type}"') + file_reader = self._file_reader_chooser.get_current_plugin().strategy try: - dataset = fileReader.read(filePath) + dataset = file_reader.read(file_path) except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc + raise RuntimeError(f'Failed to read "{file_path}"') from exc else: - self._builder.switchTo(dataset) + self._dataset.reload(dataset) + return 0 else: - logger.warning(f'Refusing to read invalid file path {filePath}') - return None + logger.warning(f'Refusing to read invalid file path {file_path}') - if assemble: - self._builder.start() - self._builder.stop(finishAssembling=True) + return -1 - return self._fileReaderChooser.currentPlugin.simpleName + def start_assembling_diffraction_patterns(self) -> None: + self._dataset.start_loading() - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() + def finish_assembling_diffraction_patterns(self, *, block: bool) -> None: + self._dataset.finish_loading(block=block) - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName + if block: + self._dataset.assemble_patterns() - def savePatterns(self, filePath: Path, fileType: str) -> None: - self._fileWriterChooser.setCurrentPluginByName(fileType) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - writer = self._fileWriterChooser.currentPlugin.strategy - writer.write(filePath, self._dataset) + def close_patterns(self) -> None: + self._dataset.clear() - def importProcessedPatterns(self, filePath: Path) -> None: - if filePath.is_file(): - logger.debug(f'Reading processed patterns from "{filePath}"') + def get_file_writer_chooser(self) -> PluginChooser[DiffractionFileWriter]: + return self._file_writer_chooser - try: - contents = numpy.load(filePath) - except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc + def save_patterns(self, file_path: Path, file_type: str) -> None: + self._file_writer_chooser.set_current_plugin(file_type) + file_type = self._file_writer_chooser.get_current_plugin().simple_name + logger.debug(f'Writing "{file_path}" as "{file_type}"') + writer = self._file_writer_chooser.get_current_plugin().strategy + writer.write(file_path, self._dataset) - self._builder.stop(finishAssembling=False) - self._dataset.setAssembledData(contents['patterns'], contents['indexes']) - self._builder.start() - self._builder.stop(finishAssembling=True) - else: - logger.warning(f'Refusing to read invalid file path {filePath}') - - def exportProcessedPatterns(self, filePath: Path) -> None: - contents: dict[str, Any] = { - 'indexes': numpy.array(self._dataset.getAssembledIndexes()), - 'patterns': numpy.array(self._dataset.getAssembledData()), - } - logger.debug(f'Writing processed patterns to "{filePath}"') - numpy.savez(filePath, **contents) + def import_assembled_patterns(self, file_path: Path) -> None: + self._dataset.import_assembled_patterns(file_path) + + def export_assembled_patterns(self, file_path: Path) -> None: + self._dataset.export_assembled_patterns(file_path) diff --git a/src/ptychodus/model/patterns/builder.py b/src/ptychodus/model/patterns/builder.py deleted file mode 100644 index f677dc89..00000000 --- a/src/ptychodus/model/patterns/builder.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations -import logging -import queue -import threading - -import numpy - -from ptychodus.api.patterns import ( - DiffractionDataset, - DiffractionPatternArray, - DiffractionPatternState, - SimpleDiffractionPatternArray, -) - -from .active import ActiveDiffractionDataset -from .settings import PatternSettings - -__all__ = [ - 'ActiveDiffractionDatasetBuilder', -] - -logger = logging.getLogger(__name__) - - -class ActiveDiffractionDatasetBuilder: - def __init__(self, settings: PatternSettings, dataset: ActiveDiffractionDataset) -> None: - super().__init__() - self._settings = settings - self._dataset = dataset - self._unassembledDataset: DiffractionDataset | None = None - self._arrayQueue: queue.Queue[DiffractionPatternArray] = queue.Queue() - self._workers: list[threading.Thread] = list() - self._stopWorkEvent = threading.Event() - - @property - def isAssembling(self) -> bool: - return len(self._workers) > 0 - - def _getArrayAndAssemble(self) -> None: - while not self._stopWorkEvent.is_set(): - try: - array = self._arrayQueue.get(block=True, timeout=1) - - try: - self._assemble(array) - finally: - self._arrayQueue.task_done() - except queue.Empty: - pass - except Exception: - logger.exception('Error while assembling array!') - - def _assemble(self, array: DiffractionPatternArray) -> None: - logger.info(f'Assembling {array.getLabel()}...') - - try: - data = array.getData() - except Exception: - metadata = self._dataset.getMetadata() - data = numpy.zeros((0, 0, 0), dtype=metadata.patternDataType) - state = DiffractionPatternState.MISSING - else: - state = DiffractionPatternState.LOADED - - array = SimpleDiffractionPatternArray(array.getLabel(), array.getIndex(), data, state) - - self._dataset.insertArray(array) - - def insertArray(self, array: DiffractionPatternArray) -> None: - self._arrayQueue.put(array) - - def switchTo(self, dataset: DiffractionDataset) -> None: - if self.isAssembling: - self.stop(finishAssembling=False) - - self._dataset.reset(dataset.getMetadata(), dataset.getContentsTree()) - self._unassembledDataset = dataset - - def start(self) -> None: - if self.isAssembling: - self.stop(finishAssembling=False) - - if self._unassembledDataset is None: - logger.debug('Skipping data assembler reset.') - else: - logger.info('Resetting data assembler...') - - self._dataset.realloc() - - for array in self._unassembledDataset: - self.insertArray(array) - - logger.info('Data assembler reset.') - - logger.info('Starting data assembler...') - self._stopWorkEvent.clear() - - for idx in range(self._settings.numberOfDataThreads.getValue()): - thread = threading.Thread(target=self._getArrayAndAssemble) - thread.start() - self._workers.append(thread) - - logger.info('Data assembler started.') - - def stop(self, finishAssembling: bool) -> None: - if finishAssembling: - self._arrayQueue.join() - - logger.info('Stopping data assembler...') - self._stopWorkEvent.set() - - while self._workers: - thread = self._workers.pop() - thread.join() - - with self._arrayQueue.mutex: - self._arrayQueue.queue.clear() - - logger.info('Data assembler stopped.') - - def getAssemblyQueueSize(self) -> int: - return self._arrayQueue.qsize() diff --git a/src/ptychodus/model/patterns/core.py b/src/ptychodus/model/patterns/core.py index f0d22570..0ef5e6c2 100644 --- a/src/ptychodus/model/patterns/core.py +++ b/src/ptychodus/model/patterns/core.py @@ -1,179 +1,58 @@ -from __future__ import annotations -from collections.abc import Iterator -from dataclasses import dataclass -from pathlib import Path -from typing import Any import logging -import h5py - -from ptychodus.api.geometry import Interval from ptychodus.api.observer import Observable, Observer from ptychodus.api.patterns import ( DiffractionFileReader, DiffractionFileWriter, - DiffractionPatternArrayType, - DiffractionPatternState, ) from ptychodus.api.plugins import PluginChooser from ptychodus.api.settings import SettingsRegistry -from ptychodus.api.tree import SimpleTreeNode -from .active import ActiveDiffractionDataset from .api import PatternsAPI -from .builder import ActiveDiffractionDatasetBuilder -from .detector import Detector -from .io import DiffractionDatasetInputOutputPresenter -from .metadata import DiffractionMetadataPresenter -from .patterns import DiffractionPatternPresenter -from .settings import PatternSettings, ProductSettings +from .dataset import AssembledDiffractionDataset +from .settings import DetectorSettings, PatternSettings from .sizer import PatternSizer logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class DiffractionPatternArrayPresenter: - label: str - state: DiffractionPatternState - data: DiffractionPatternArrayType | None - - @classmethod - def createNull(cls) -> DiffractionPatternArrayPresenter: - return cls(str(), DiffractionPatternState.UNKNOWN, None) - - -class DiffractionDatasetPresenter(Observable, Observer): - def __init__(self, settings: PatternSettings, dataset: ActiveDiffractionDataset) -> None: - super().__init__() - self._settings = settings - self._dataset = dataset - - settings.addObserver(self) - dataset.addObserver(self) - - def __iter__(self) -> Iterator[DiffractionPatternArrayPresenter]: - for array in self._dataset: - yield DiffractionPatternArrayPresenter( - label=array.getLabel(), - state=array.getState(), - data=array.getData(), - ) - - def __len__(self) -> int: - return len(self._dataset) - - def getInfoText(self) -> str: - return self._dataset.getInfoText() - - def isMemmapEnabled(self) -> bool: - return self._settings.memmapEnabled.getValue() - - def setMemmapEnabled(self, value: bool) -> None: - self._settings.memmapEnabled.setValue(value) - - def getScratchDirectory(self) -> Path: - return self._settings.scratchDirectory.getValue() - - def setScratchDirectory(self, directory: Path) -> None: - self._settings.scratchDirectory.setValue(directory) - - def getNumberOfDataThreadsLimits(self) -> Interval[int]: - return Interval[int](1, 64) - - def getNumberOfDataThreads(self) -> int: - limits = self.getNumberOfDataThreadsLimits() - return limits.clamp(self._settings.numberOfDataThreads.getValue()) - - def setNumberOfDataThreads(self, number: int) -> None: - self._settings.numberOfDataThreads.setValue(number) - - @property - def isAssembled(self) -> bool: - return len(self._dataset) > 0 - - def getContentsTree(self) -> SimpleTreeNode: - return self._dataset.getContentsTree() - - def openArray(self, dataPath: str) -> Any: # TODO generalize for other file formats - filePath = self._dataset.getMetadata().filePath - data = None - - if filePath and h5py.is_hdf5(filePath) and dataPath: - try: - with h5py.File(filePath, 'r') as h5File: - if dataPath in h5File: - item = h5File.get(dataPath) - - if isinstance(item, h5py.Dataset): - data = item[()] # TODO decode strings as needed - else: - parentPath, attrName = dataPath.rsplit('/', 1) - - if parentPath in h5File: - item = h5File.get(parentPath) - - if attrName in item.attrs: - attr = item.attrs[attrName] - stringInfo = h5py.check_string_dtype(attr.dtype) - - if stringInfo: - data = attr.decode(stringInfo.encoding) - else: - data = attr - except OSError: - logger.exception('Failed to open dataset!') - - return data - - def update(self, observable: Observable) -> None: - if observable is self._settings: - self.notifyObservers() - elif observable is self._dataset: - self.notifyObservers() - - -class PatternsCore: +class PatternsCore(Observer): def __init__( self, - settingsRegistry: SettingsRegistry, - fileReaderChooser: PluginChooser[DiffractionFileReader], - fileWriterChooser: PluginChooser[DiffractionFileWriter], + settings_registry: SettingsRegistry, + file_reader_chooser: PluginChooser[DiffractionFileReader], + file_writer_chooser: PluginChooser[DiffractionFileWriter], + reinit_observable: Observable, ) -> None: - self.detector = Detector(settingsRegistry) - self.patternSettings = PatternSettings(settingsRegistry) - self.productSettings = ProductSettings(settingsRegistry) - - # TODO vvv refactor vvv - fileReaderChooser.setCurrentPluginByName(self.patternSettings.fileType.getValue()) - fileWriterChooser.setCurrentPluginByName(self.patternSettings.fileType.getValue()) - # TODO ^^^^^^^^^^^^^^^^ - - self.patternSizer = PatternSizer.createInstance(self.patternSettings, self.detector) - self.patternPresenter = DiffractionPatternPresenter.createInstance( - self.patternSettings, self.patternSizer - ) - - self.dataset = ActiveDiffractionDataset(self.patternSettings, self.patternSizer) - self._builder = ActiveDiffractionDatasetBuilder(self.patternSettings, self.dataset) - self.patternsAPI = PatternsAPI( - self.patternSettings, - self._builder, + super().__init__() + self.detector_settings = DetectorSettings(settings_registry) + self.pattern_settings = PatternSettings(settings_registry) + self.pattern_sizer = PatternSizer(self.detector_settings, self.pattern_settings) + self.dataset = AssembledDiffractionDataset(self.pattern_settings, self.pattern_sizer) + self.patterns_api = PatternsAPI( + self.pattern_settings, + self.detector_settings, self.dataset, - fileReaderChooser, - fileWriterChooser, + file_reader_chooser, + file_writer_chooser, ) - self.metadataPresenter = DiffractionMetadataPresenter( - self.dataset, self.detector, self.patternSettings, self.productSettings - ) - self.datasetPresenter = DiffractionDatasetPresenter(self.patternSettings, self.dataset) - self.datasetInputOutputPresenter = DiffractionDatasetInputOutputPresenter( - self.patternSettings, self.dataset, self.patternsAPI, settingsRegistry - ) + file_reader_chooser.synchronize_with_parameter(self.pattern_settings.file_type) + file_writer_chooser.set_current_plugin(self.pattern_settings.file_type.get_value()) + + self._reinit_observable = reinit_observable + reinit_observable.add_observer(self) def start(self) -> None: pass def stop(self) -> None: - self._builder.stop(finishAssembling=False) + self.dataset.finish_loading(block=False) + + def _update(self, observable: Observable) -> None: + if observable is self._reinit_observable: + self.patterns_api.open_patterns( + file_path=self.pattern_settings.file_path.get_value(), + file_type=self.pattern_settings.file_type.get_value(), + ) + self.dataset.start_loading() diff --git a/src/ptychodus/model/patterns/dataset.py b/src/ptychodus/model/patterns/dataset.py new file mode 100644 index 00000000..0ee5570d --- /dev/null +++ b/src/ptychodus/model/patterns/dataset.py @@ -0,0 +1,420 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from bisect import bisect +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import overload +import logging +import queue +import tempfile +import threading + +import numpy +import numpy.typing + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionMetadata, + DiffractionPatternArray, + PatternDataType, + PatternIndexesType, +) +from ptychodus.api.tree import SimpleTreeNode +from ptychodus.api.typing import BooleanArrayType +from ptychodus.api.units import BYTES_PER_MEGABYTE + +from .settings import PatternSettings +from .sizer import PatternSizer + +logger = logging.getLogger(__name__) + +__all__ = [ + 'AssembledDiffractionDataset', + 'AssembledDiffractionPatternArray', + 'DiffractionDatasetObserver', +] + + +class DiffractionDatasetObserver(ABC): + @abstractmethod + def handle_array_inserted(self, index: int) -> None: + pass + + @abstractmethod + def handle_array_changed(self, index: int) -> None: + pass + + @abstractmethod + def handle_dataset_reloaded(self) -> None: + pass + + +class AssembledDiffractionPatternArray(DiffractionPatternArray): + def __init__( + self, + label: str, + indexes: PatternIndexesType, + data: PatternDataType, + good_pixels: BooleanArrayType, + array_index: int, + ) -> None: + super().__init__() + self._label = label + self._indexes = indexes + self._data = data + self._good_pixels = good_pixels + self._array_index = array_index + + @classmethod + def create_null(cls) -> AssembledDiffractionPatternArray: + indexes = numpy.array([0]) + data = numpy.zeros((1, 1, 1), dtype=numpy.uint16) + good_pixels = numpy.full((1, 1), True) + return cls('null', indexes, data, good_pixels, 0) + + def get_label(self) -> str: + return self._label + + def get_indexes(self) -> PatternIndexesType: + return self._indexes + + def get_data(self) -> PatternDataType: + return self._data + + def get_pattern(self, index: int) -> PatternDataType: + return self._data[index] + + def get_pattern_counts(self, index: int) -> int: + pattern = self._data[index] + return pattern[self._good_pixels].sum() + + def get_average_pattern(self) -> PatternDataType: + return self._data.mean(axis=0) + + def get_mean_pattern_counts(self) -> float: + loaded_data = self._data[self._indexes >= 0] + total_counts = numpy.sum(loaded_data[:, self._good_pixels], axis=-1) + return total_counts.mean() + + def get_max_pattern_counts(self) -> int: + loaded_data = self._data[self._indexes >= 0] + total_counts = numpy.sum(loaded_data[:, self._good_pixels], axis=-1) + return total_counts.max() + + def get_array_index(self) -> int: + return self._array_index + + +@dataclass(frozen=True) +class ArrayLoaderTask: + array: DiffractionPatternArray + index: int + + +class ArrayLoader: + def __init__(self, settings: PatternSettings, sizer: PatternSizer) -> None: + self._settings = settings + self._sizer = sizer + self._input_queue: queue.Queue[ArrayLoaderTask] = queue.Queue() + self._output_queue: queue.Queue[ArrayLoaderTask] = queue.Queue() + self._workers: list[threading.Thread] = list() + self._stop_work_event = threading.Event() + + @property + def input_queue_size(self) -> int: + return self._input_queue.qsize() + + @property + def output_queue_size(self) -> int: + return self._output_queue.qsize() + + def submit_task(self, task: ArrayLoaderTask) -> None: + self._input_queue.put(task) + + def _load_arrays(self) -> None: + processor = self._sizer.get_processor() + + while not self._stop_work_event.is_set(): + try: + task = self._input_queue.get(block=True, timeout=1) + except queue.Empty: + continue + + try: + processed_array = processor(task.array) + except FileNotFoundError: + logger.warning(f'File not found for array index={task.index}.') + except OSError: + logger.error(f'OS error while reading array index={task.index}.') + except Exception: + logger.exception(f'Error while loading array index={task.index}!') + else: + completed_task = ArrayLoaderTask(processed_array, task.index) + self._output_queue.put(completed_task) + finally: + self._input_queue.task_done() + + def completed_tasks(self) -> Iterator[ArrayLoaderTask]: + while True: + try: + task = self._output_queue.get(block=False) + except queue.Empty: + break + else: + self._output_queue.task_done() + + yield task + + def start(self) -> None: + logger.info('Starting data loader...') + self._stop_work_event.clear() + + # clear assembly queue + for _ in self.completed_tasks(): + pass + + for index in range(self._settings.num_data_threads.get_value()): + thread = threading.Thread(target=self._load_arrays) + thread.start() + self._workers.append(thread) + + logger.info('Data loader started.') + + def stop(self, *, finish_loading: bool) -> None: + if self._stop_work_event.is_set(): + logger.info('Data loader already stopped.') + return + + logger.info('Stopping data loader...') + + if finish_loading: + self._input_queue.join() + else: + # clear loading queue + while True: + try: + self._input_queue.get(block=False) + except queue.Empty: + break + else: + self._input_queue.task_done() + + self._stop_work_event.set() + + while self._workers: + thread = self._workers.pop() + thread.join() + + logger.info('Data loader stopped.') + + +class AssembledDiffractionDataset(DiffractionDataset): + def __init__(self, settings: PatternSettings, sizer: PatternSizer) -> None: + super().__init__() + self._settings = settings + self._sizer = sizer + self._loader = ArrayLoader(settings, sizer) + self._observer_list: list[DiffractionDatasetObserver] = [] + + self._contents_tree = SimpleTreeNode.create_root([]) + self._metadata = DiffractionMetadata.create_null() + self._indexes: PatternIndexesType = numpy.zeros((), dtype=int) + self._data: PatternDataType = numpy.zeros((0, 0, 0), dtype=int) + self._arrays: list[AssembledDiffractionPatternArray] = list() + self._array_counter = 0 + + @property + def queue_size(self) -> int: + return self._loader.input_queue_size + self._loader.output_queue_size + + def start_loading(self) -> None: + pattern_extent = self._sizer.get_processed_image_extent() + data_shape = self._indexes.size, *pattern_extent.shape + data_dtype = self._metadata.pattern_dtype + + if self._settings.is_memmap_enabled.get_value(): + scratch_dir = self._settings.scratch_directory.get_value() + scratch_dir.mkdir(mode=0o755, parents=True, exist_ok=True) + npy_tmp_file = tempfile.NamedTemporaryFile(dir=scratch_dir, suffix='.npy') + logger.info(f'Scratch data file {npy_tmp_file.name} is {data_shape}') + self._data = numpy.memmap(npy_tmp_file, dtype=data_dtype, shape=data_shape) + self._data[:] = 0 + else: + logger.info(f'Scratch memory is {data_shape}') + self._data = numpy.zeros(data_shape, dtype=data_dtype) + + for observer in self._observer_list: + observer.handle_dataset_reloaded() + + self._loader.start() + + def finish_loading(self, *, block: bool = True) -> None: + self._loader.stop(finish_loading=block) + + def add_observer(self, observer: DiffractionDatasetObserver) -> None: + if observer not in self._observer_list: + self._observer_list.append(observer) + + def remove_observer(self, observer: DiffractionDatasetObserver) -> None: + try: + self._observer_list.remove(observer) + except ValueError: + pass + + def get_contents_tree(self) -> SimpleTreeNode: + return self._contents_tree + + def get_metadata(self) -> DiffractionMetadata: + return self._metadata + + def get_processed_bad_pixels(self) -> BooleanArrayType: + # TODO support loading from file + # TODO keep consist with processed patterns + pattern_extent = self._sizer.get_processed_image_extent() + return numpy.full(pattern_extent.shape, False) + + def get_assembled_indexes(self) -> PatternIndexesType: + return self._indexes[self._indexes >= 0] + + def get_assembled_patterns(self) -> PatternDataType: + return self._data[self._indexes >= 0] + + def get_maximum_pattern_counts(self) -> int: + patterns = self.get_assembled_patterns() + good_pixels = numpy.logical_not(self.get_processed_bad_pixels()) + try: + total_counts = numpy.sum(patterns[:, good_pixels], axis=-1) + except IndexError: + # patterns not loaded + return 0 + + return total_counts.max() + + @overload + def __getitem__(self, index: int) -> AssembledDiffractionPatternArray: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[AssembledDiffractionPatternArray]: ... + + def __getitem__( + self, index: int | slice + ) -> AssembledDiffractionPatternArray | Sequence[AssembledDiffractionPatternArray]: + return self._arrays[index] + + def __len__(self) -> int: + return len(self._arrays) + + def append_array(self, array: DiffractionPatternArray) -> None: + """Load a new array into the dataset. Assumes that arrays arrive in order.""" + task = ArrayLoaderTask(array, int(self._array_counter)) + self._array_counter += 1 + self._loader.submit_task(task) + + def assemble_patterns(self) -> None: + for task in self._loader.completed_tasks(): + array_size = self._metadata.num_patterns_per_array + array_slice = slice(task.index * array_size, (task.index + 1) * array_size) + + self._indexes[array_slice] = task.array.get_indexes() + pattern_indexes = self._indexes[array_slice] + pattern_indexes.flags.writeable = False + + self._data[array_slice, :, :] = task.array.get_data() + pattern_data = self._data[array_slice, :, :] + pattern_data.flags.writeable = False + + array = AssembledDiffractionPatternArray( + label=task.array.get_label(), + indexes=pattern_indexes, + data=pattern_data, + good_pixels=numpy.logical_not(self.get_processed_bad_pixels()), + array_index=task.index, + ) + + pos = bisect(self._arrays, array.get_array_index(), key=lambda x: x.get_array_index()) + self._arrays.insert(pos, array) + + for observer in self._observer_list: + observer.handle_array_inserted(pos) + + def clear(self) -> None: + self._loader.stop(finish_loading=False) + self._contents_tree = SimpleTreeNode.create_root([]) + self._metadata = DiffractionMetadata.create_null() + self._indexes = numpy.zeros((), dtype=int) + self._data = numpy.zeros((0, 0, 0), dtype=int) + self._arrays.clear() + self._array_counter = 0 + + for _ in self._loader.completed_tasks(): + pass + + for observer in self._observer_list: + observer.handle_dataset_reloaded() + + def reload(self, dataset: DiffractionDataset) -> None: + self.clear() + self._contents_tree = dataset.get_contents_tree() + self._metadata = dataset.get_metadata() + self._indexes = -numpy.ones(self._metadata.num_patterns_total, dtype=int) + + for observer in self._observer_list: + observer.handle_dataset_reloaded() + + for array in dataset: + self.append_array(array) + + def import_assembled_patterns(self, file_path: Path) -> None: + if file_path.is_file(): + self.clear() + logger.debug(f'Reading processed patterns from "{file_path}"') + + try: + contents = numpy.load(file_path) + except Exception as exc: + raise RuntimeError(f'Failed to read "{file_path}"') from exc + + self._indexes = contents['indexes'] + self._data = contents['patterns'] + num_patterns, detector_height, detector_width = self._data.shape + + self._contents_tree = SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + self._metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=self._data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + ) + self._arrays = [ + AssembledDiffractionPatternArray( + label='Imported', + indexes=self._indexes, + data=self._data, + good_pixels=numpy.logical_not(self.get_processed_bad_pixels()), + array_index=0, + ) + ] + self._array_counter = 1 + + for observer in self._observer_list: + observer.handle_dataset_reloaded() + else: + logger.warning(f'Refusing to read invalid file path {file_path}') + + def export_assembled_patterns(self, file_path: Path) -> None: + logger.debug(f'Writing processed patterns to "{file_path}"') + numpy.savez( + file_path, + indexes=self.get_assembled_indexes(), + patterns=self.get_assembled_patterns(), + ) + + def get_info_text(self) -> str: + file_path = self._metadata.file_path + label = file_path.stem if file_path else 'None' + number, height, width = self._data.shape + dtype = str(self._data.dtype) + size_MB = self._data.nbytes / BYTES_PER_MEGABYTE # noqa: N806 + return f'{label}: {number} x {width}W x {height}H {dtype} [{size_MB:.2f}MB]' diff --git a/src/ptychodus/model/patterns/detector.py b/src/ptychodus/model/patterns/detector.py deleted file mode 100644 index 662afd45..00000000 --- a/src/ptychodus/model/patterns/detector.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from ptychodus.api.geometry import ImageExtent, PixelGeometry -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.settings import SettingsRegistry - - -class Detector(Observable, Observer): - def __init__(self, registry: SettingsRegistry) -> None: - super().__init__() - self._settingsGroup = registry.createGroup('Detector') - self._settingsGroup.addObserver(self) - - self.widthInPixels = self._settingsGroup.createIntegerParameter( - 'WidthInPixels', 1024, minimum=0 - ) - self.pixelWidthInMeters = self._settingsGroup.createRealParameter( - 'PixelWidthInMeters', 75e-6, minimum=0.0 - ) - self.heightInPixels = self._settingsGroup.createIntegerParameter( - 'HeightInPixels', 1024, minimum=0 - ) - self.pixelHeightInMeters = self._settingsGroup.createRealParameter( - 'PixelHeightInMeters', 75e-6, minimum=0.0 - ) - self.bitDepth = self._settingsGroup.createIntegerParameter('BitDepth', 8, minimum=1) - - def getImageExtent(self) -> ImageExtent: - return ImageExtent( - widthInPixels=self.widthInPixels.getValue(), - heightInPixels=self.heightInPixels.getValue(), - ) - - def setImageExtent(self, imageExtent: ImageExtent) -> None: - self.widthInPixels.setValue(imageExtent.widthInPixels) - self.heightInPixels.setValue(imageExtent.heightInPixels) - - def getPixelGeometry(self) -> PixelGeometry: - return PixelGeometry( - widthInMeters=self.pixelWidthInMeters.getValue(), - heightInMeters=self.pixelHeightInMeters.getValue(), - ) - - def setPixelGeometry(self, pixelGeometry: PixelGeometry) -> None: - self.pixelWidthInMeters.setValue(pixelGeometry.widthInMeters) - self.pixelHeightInMeters.setValue(pixelGeometry.heightInMeters) - - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() diff --git a/src/ptychodus/model/patterns/io.py b/src/ptychodus/model/patterns/io.py deleted file mode 100644 index c3a22a35..00000000 --- a/src/ptychodus/model/patterns/io.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations -from collections.abc import Sequence -from pathlib import Path -import logging - -from ptychodus.api.observer import Observable, Observer - -from .active import ActiveDiffractionDataset -from .api import PatternsAPI -from .settings import PatternSettings - -logger = logging.getLogger(__name__) - - -class DiffractionDatasetInputOutputPresenter(Observable, Observer): - def __init__( - self, - settings: PatternSettings, - dataset: ActiveDiffractionDataset, - patternsAPI: PatternsAPI, - reinitObservable: Observable, - ) -> None: - super().__init__() - self._settings = settings - self._dataset = dataset - self._patternsAPI = patternsAPI - self._reinitObservable = reinitObservable - - reinitObservable.addObserver(self) - - def getOpenFileFilterList(self) -> Sequence[str]: - return self._patternsAPI.getOpenFileFilterList() - - def getOpenFileFilter(self) -> str: - return self._patternsAPI.getOpenFileFilter() - - def openDiffractionFile(self, filePath: Path, fileFilter: str) -> None: - try: - fileType = self._patternsAPI.openPatterns( - filePath=filePath, fileType=fileFilter, assemble=False - ) - except Exception: - logger.exception('Failed to load diffraction dataset.') - return - - if fileType is None: - logger.error('Failed to load diffraction dataset.') - else: - self._settings.fileType.setValue(fileType) - self._settings.filePath.setValue(filePath) - - self.notifyObservers() - - def _openDiffractionFileFromSettings(self) -> None: - self._patternsAPI.openPatterns( - filePath=self._settings.filePath.getValue(), - fileType=self._settings.fileType.getValue(), - assemble=True, - ) - - def startAssemblingDiffractionPatterns(self) -> None: - self._patternsAPI.startAssemblingDiffractionPatterns() - - def stopAssemblingDiffractionPatterns(self, finishAssembling: bool) -> None: - self._patternsAPI.stopAssemblingDiffractionPatterns(finishAssembling) - - def getSaveFileFilterList(self) -> Sequence[str]: - return self._patternsAPI.getSaveFileFilterList() - - def getSaveFileFilter(self) -> str: - return self._patternsAPI.getSaveFileFilter() - - def saveDiffractionFile(self, filePath: Path, fileFilter: str) -> None: - self._patternsAPI.savePatterns(filePath, fileFilter) - - def update(self, observable: Observable) -> None: - if observable is self._reinitObservable: - self._openDiffractionFileFromSettings() diff --git a/src/ptychodus/model/patterns/metadata.py b/src/ptychodus/model/patterns/metadata.py deleted file mode 100644 index dfa2c614..00000000 --- a/src/ptychodus/model/patterns/metadata.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import annotations - -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.patterns import DiffractionDataset, DiffractionMetadata - -from .detector import Detector -from .settings import PatternSettings, ProductSettings - - -class DiffractionMetadataPresenter(Observable, Observer): - def __init__( - self, - diffractionDataset: DiffractionDataset, - detector: Detector, - patternSettings: PatternSettings, - productSettings: ProductSettings, - ) -> None: - super().__init__() - self._diffractionDataset = diffractionDataset - self._detector = detector - self._patternSettings = patternSettings - self._productSettings = productSettings - - diffractionDataset.addObserver(self) - - @property - def _metadata(self) -> DiffractionMetadata: - return self._diffractionDataset.getMetadata() - - def canSyncDetectorPixelCount(self) -> bool: - return self._metadata.detectorExtent is not None - - def syncDetectorPixelCount(self) -> None: - detectorExtent = self._metadata.detectorExtent - - if detectorExtent: - self._detector.setImageExtent(detectorExtent) - - def canSyncDetectorPixelSize(self) -> bool: - return self._metadata.detectorPixelGeometry is not None - - def syncDetectorPixelSize(self) -> None: - pixelGeometry = self._metadata.detectorPixelGeometry - - if pixelGeometry: - self._detector.setPixelGeometry(pixelGeometry) - - def canSyncDetectorBitDepth(self) -> bool: - return self._metadata.detectorBitDepth is not None - - def syncDetectorBitDepth(self) -> None: - bitDepth = self._metadata.detectorBitDepth - - if bitDepth: - self._detector.bitDepth.setValue(bitDepth) - - def canSyncPatternCropCenter(self) -> bool: - return self._metadata.cropCenter is not None or self._metadata.detectorExtent is not None - - def canSyncPatternCropExtent(self) -> bool: - return self._metadata.detectorExtent is not None - - def syncPatternCrop(self, syncCenter: bool, syncExtent: bool) -> None: - if syncCenter: - cropCenter = self._metadata.cropCenter - - if cropCenter: - self._patternSettings.cropCenterXInPixels.setValue(cropCenter.positionXInPixels) - self._patternSettings.cropCenterYInPixels.setValue(cropCenter.positionYInPixels) - elif self._metadata.detectorExtent: - self._patternSettings.cropCenterXInPixels.setValue( - int(self._metadata.detectorExtent.widthInPixels) // 2 - ) - self._patternSettings.cropCenterYInPixels.setValue( - int(self._metadata.detectorExtent.heightInPixels) // 2 - ) - - if syncExtent and self._metadata.detectorExtent: - centerX = self._patternSettings.cropCenterXInPixels.getValue() - centerY = self._patternSettings.cropCenterYInPixels.getValue() - - extentX = int(self._metadata.detectorExtent.widthInPixels) - extentY = int(self._metadata.detectorExtent.heightInPixels) - - maxRadiusX = min(centerX, extentX - centerX) - maxRadiusY = min(centerY, extentY - centerY) - maxRadius = min(maxRadiusX, maxRadiusY) - cropDiameterInPixels = 1 - - while cropDiameterInPixels < maxRadius: - cropDiameterInPixels <<= 1 - - self._patternSettings.cropWidthInPixels.setValue(cropDiameterInPixels) - self._patternSettings.cropHeightInPixels.setValue(cropDiameterInPixels) - - def canSyncProbeEnergy(self) -> bool: - return self._metadata.probeEnergyInElectronVolts is not None - - def syncProbeEnergy(self) -> None: - energyInElectronVolts = self._metadata.probeEnergyInElectronVolts - - if energyInElectronVolts: - self._productSettings.probeEnergyInElectronVolts.setValue(energyInElectronVolts) - - def canSyncDetectorDistance(self) -> bool: - return self._metadata.detectorDistanceInMeters is not None - - def syncDetectorDistance(self) -> None: - distanceInMeters = self._metadata.detectorDistanceInMeters - - if distanceInMeters: - self._productSettings.detectorDistanceInMeters.setValue(distanceInMeters) - - def update(self, observable: Observable) -> None: - if observable is self._diffractionDataset: - self.notifyObservers() diff --git a/src/ptychodus/model/patterns/patterns.py b/src/ptychodus/model/patterns/patterns.py deleted file mode 100644 index 181c7d86..00000000 --- a/src/ptychodus/model/patterns/patterns.py +++ /dev/null @@ -1,113 +0,0 @@ -from __future__ import annotations -from typing import Final - -from ptychodus.api.geometry import Interval -from ptychodus.api.observer import Observable, Observer - -from .settings import PatternSettings -from .sizer import PatternSizer - - -class DiffractionPatternPresenter(Observable, Observer): - MAX_INT: Final[int] = 0x7FFFFFFF - - def __init__(self, settings: PatternSettings, sizer: PatternSizer) -> None: - super().__init__() - self._settings = settings - self._sizer = sizer - - @classmethod - def createInstance( - cls, settings: PatternSettings, sizer: PatternSizer - ) -> DiffractionPatternPresenter: - presenter = cls(settings, sizer) - sizer.addObserver(presenter) - return presenter - - def isCropEnabled(self) -> bool: - return self._sizer.isCropEnabled() - - def setCropEnabled(self, value: bool) -> None: - self._sizer.setCropEnabled(value) - - def getCropCenterXLimitsInPixels(self) -> Interval[int]: - return self._sizer.getCenterXLimitsInPixels() - - def getCropCenterXInPixels(self) -> int: - return self._sizer.getCenterXInPixels() - - def setCropCenterXInPixels(self, value: int) -> None: - self._settings.cropCenterXInPixels.setValue(value) - - def getCropCenterYLimitsInPixels(self) -> Interval[int]: - return self._sizer.getCenterYLimitsInPixels() - - def getCropCenterYInPixels(self) -> int: - return self._sizer.getCenterYInPixels() - - def setCropCenterYInPixels(self, value: int) -> None: - self._settings.cropCenterYInPixels.setValue(value) - - def getCropWidthLimitsInPixels(self) -> Interval[int]: - return self._sizer.getWidthLimitsInPixels() - - def getCropWidthInPixels(self) -> int: - return self._sizer.getWidthInPixels() - - def setCropWidthInPixels(self, value: int) -> None: - self._settings.cropWidthInPixels.setValue(value) - - def getCropHeightLimitsInPixels(self) -> Interval[int]: - return self._sizer.getHeightLimitsInPixels() - - def getCropHeightInPixels(self) -> int: - return self._sizer.getHeightInPixels() - - def setCropHeightInPixels(self, value: int) -> None: - self._settings.cropHeightInPixels.setValue(value) - - def isFlipXEnabled(self) -> bool: - return self._settings.flipXEnabled.getValue() - - def setFlipXEnabled(self, value: bool) -> None: - self._settings.flipXEnabled.setValue(value) - - def isFlipYEnabled(self) -> bool: - return self._settings.flipYEnabled.getValue() - - def setFlipYEnabled(self, value: bool) -> None: - self._settings.flipYEnabled.setValue(value) - - def isValueLowerBoundEnabled(self) -> bool: - return self._settings.valueLowerBoundEnabled.getValue() - - def setValueLowerBoundEnabled(self, value: bool) -> None: - self._settings.valueLowerBoundEnabled.setValue(value) - - def getValueLowerBoundLimits(self) -> Interval[int]: - return Interval[int](0, self.MAX_INT) - - def getValueLowerBound(self) -> int: - return self._settings.valueLowerBound.getValue() - - def setValueLowerBound(self, value: int) -> None: - self._settings.valueLowerBound.setValue(value) - - def isValueUpperBoundEnabled(self) -> bool: - return self._settings.valueUpperBoundEnabled.getValue() - - def setValueUpperBoundEnabled(self, value: bool) -> None: - self._settings.valueUpperBoundEnabled.setValue(value) - - def getValueUpperBoundLimits(self) -> Interval[int]: - return Interval[int](0, self.MAX_INT) - - def getValueUpperBound(self) -> int: - return self._settings.valueUpperBound.getValue() - - def setValueUpperBound(self, value: int) -> None: - self._settings.valueUpperBound.setValue(value) - - def update(self, observable: Observable) -> None: - if observable is self._sizer: - self.notifyObservers() diff --git a/src/ptychodus/model/patterns/processor.py b/src/ptychodus/model/patterns/processor.py new file mode 100644 index 00000000..6ba68cff --- /dev/null +++ b/src/ptychodus/model/patterns/processor.py @@ -0,0 +1,100 @@ +from __future__ import annotations +from dataclasses import dataclass + +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + CropCenter, + DiffractionPatternArray, + PatternDataType, + SimpleDiffractionPatternArray, +) + + +@dataclass(frozen=True) +class DiffractionPatternFilterValues: + lower_bound: int | None + upper_bound: int | None + + def apply(self, data: PatternDataType) -> PatternDataType: + if self.lower_bound is not None: + data[data < self.lower_bound] = 0 + + if self.upper_bound is not None: + data[data >= self.upper_bound] = 0 + + return data + + +class DiffractionPatternCrop: + def __init__(self, center: CropCenter, extent: ImageExtent) -> None: + center_x = center.position_x_px + radius_x = extent.width_px // 2 + self.slice_x = slice(center_x - radius_x, center_x + radius_x) + + center_y = center.position_y_px + radius_y = extent.height_px // 2 + self.slice_y = slice(center_y - radius_y, center_y + radius_y) + + def apply(self, data: PatternDataType) -> PatternDataType: + return data[:, self.slice_y, self.slice_x] + + +@dataclass(frozen=True) +class DiffractionPatternBinning: + bin_size_x: int + bin_size_y: int + + def apply(self, data: PatternDataType) -> PatternDataType: + binned_width = data.shape[-1] // self.bin_size_x + binned_height = data.shape[-2] // self.bin_size_y + shape = (-1, binned_height, self.bin_size_y, binned_width, self.bin_size_x) + return numpy.sum(data.reshape(shape), axis=(-3, -1), keepdims=False) + + +@dataclass(frozen=True) +class DiffractionPatternPadding: + pad_x: int + pad_y: int + + def apply(self, data: PatternDataType) -> PatternDataType: + pad_width = (0, 0, self.pad_y, self.pad_y, self.pad_x, self.pad_x) + return numpy.pad(data, pad_width, mode='constant', constant_values=0) + + +@dataclass(frozen=True) +class DiffractionPatternProcessor: + crop: DiffractionPatternCrop | None + filter_values: DiffractionPatternFilterValues | None + binning: DiffractionPatternBinning | None + padding: DiffractionPatternPadding | None + flip_x: bool + flip_y: bool + + def __call__(self, array: DiffractionPatternArray) -> DiffractionPatternArray: + data = array.get_data() + + if data.ndim != 3: + raise ValueError(f'Invalid diffraction pattern dimensions! (shape={data.shape})') + + if self.crop is not None: + data = self.crop.apply(data) + + if self.filter_values is not None: + data = self.filter_values.apply(data) + + if self.binning is not None: + # TODO handle binning with bad pixels + data = self.binning.apply(data) + + if self.padding is not None: + data = self.padding.apply(data) + + if self.flip_y: + data = numpy.flip(data, axis=-2) + + if self.flip_x: + data = numpy.flip(data, axis=-1) + + return SimpleDiffractionPatternArray(array.get_label(), array.get_indexes(), data) diff --git a/src/ptychodus/model/patterns/settings.py b/src/ptychodus/model/patterns/settings.py index 2a7ed846..ac71fe14 100644 --- a/src/ptychodus/model/patterns/settings.py +++ b/src/ptychodus/model/patterns/settings.py @@ -4,72 +4,81 @@ from ptychodus.api.settings import SettingsRegistry -class PatternSettings(Observable, Observer): +class DetectorSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Patterns') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Detector') + self._group.add_observer(self) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'HDF5') - self.filePath = self._settingsGroup.createPathParameter( - 'FilePath', Path('/path/to/data.h5') + self.width_px = self._group.create_integer_parameter('WidthInPixels', 1024, minimum=1) + self.pixel_width_m = self._group.create_real_parameter( + 'PixelWidthInMeters', 75e-6, minimum=0.0 + ) + self.height_px = self._group.create_integer_parameter('HeightInPixels', 1024, minimum=1) + self.pixel_height_m = self._group.create_real_parameter( + 'PixelHeightInMeters', 75e-6, minimum=0.0 ) - self.memmapEnabled = self._settingsGroup.createBooleanParameter('MemmapEnabled', False) - self.scratchDirectory = self._settingsGroup.createPathParameter( + self.bit_depth = self._group.create_integer_parameter('BitDepth', 8, minimum=1) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PatternSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('Patterns') + self._group.add_observer(self) + + self.file_type = self._group.create_string_parameter('FileType', 'NeXus') + self.file_path = self._group.create_path_parameter('FilePath', Path('/path/to/data.h5')) + self.is_memmap_enabled = self._group.create_boolean_parameter('MemmapEnabled', False) + self.scratch_directory = self._group.create_path_parameter( 'ScratchDirectory', Path.home() / '.ptychodus' ) - self.numberOfDataThreads = self._settingsGroup.createIntegerParameter( - 'NumberOfDataThreads', 8 + self.num_data_threads = self._group.create_integer_parameter( + 'NumberOfDataThreads', 8, minimum=1, maximum=64 ) - self.cropEnabled = self._settingsGroup.createBooleanParameter('CropEnabled', True) - self.cropCenterXInPixels = self._settingsGroup.createIntegerParameter( - 'CropCenterXInPixels', 32 - ) - self.cropCenterYInPixels = self._settingsGroup.createIntegerParameter( - 'CropCenterYInPixels', 32 + self.is_crop_enabled = self._group.create_boolean_parameter('CropEnabled', True) + self.crop_center_x_px = self._group.create_integer_parameter( + 'CropCenterXInPixels', 32, minimum=0 ) - self.cropWidthInPixels = self._settingsGroup.createIntegerParameter('CropWidthInPixels', 64) - self.cropHeightInPixels = self._settingsGroup.createIntegerParameter( - 'CropHeightInPixels', 64 + self.crop_center_y_px = self._group.create_integer_parameter( + 'CropCenterYInPixels', 32, minimum=0 ) - self.flipXEnabled = self._settingsGroup.createBooleanParameter('FlipXEnabled', False) - self.flipYEnabled = self._settingsGroup.createBooleanParameter('FlipYEnabled', False) - self.valueLowerBoundEnabled = self._settingsGroup.createBooleanParameter( - 'ValueLowerBoundEnabled', False + self.crop_width_px = self._group.create_integer_parameter( + 'CropWidthInPixels', 64, minimum=1 ) - self.valueLowerBound = self._settingsGroup.createIntegerParameter('ValueLowerBound', 0) - self.valueUpperBoundEnabled = self._settingsGroup.createBooleanParameter( - 'ValueUpperBoundEnabled', False + self.crop_height_px = self._group.create_integer_parameter( + 'CropHeightInPixels', 64, minimum=1 ) - self.valueUpperBound = self._settingsGroup.createIntegerParameter('ValueUpperBound', 65535) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + self.is_binning_enabled = self._group.create_boolean_parameter('BinningEnabled', False) + self.bin_size_x = self._group.create_integer_parameter('BinSizeX', 1, minimum=1) + self.bin_size_y = self._group.create_integer_parameter('BinSizeY', 1, minimum=1) + self.is_padding_enabled = self._group.create_boolean_parameter('PaddingEnabled', False) + self.pad_x = self._group.create_integer_parameter('PadX', 0, minimum=0) + self.pad_y = self._group.create_integer_parameter('PadY', 0, minimum=0) -class ProductSettings(Observable, Observer): - def __init__(self, registry: SettingsRegistry) -> None: - super().__init__() - self._settingsGroup = registry.createGroup('Products') - self._settingsGroup.addObserver(self) + self.is_flip_x_enabled = self._group.create_boolean_parameter('FlipXEnabled', False) + self.is_flip_y_enabled = self._group.create_boolean_parameter('FlipYEnabled', False) - self.name = self._settingsGroup.createStringParameter('Name', 'Unnamed') - self.fileType = self._settingsGroup.createStringParameter('FileType', 'HDF5') - self.detectorDistanceInMeters = self._settingsGroup.createRealParameter( - 'DetectorDistanceInMeters', 1.0, minimum=0.0 + self.is_value_lower_bound_enabled = self._group.create_boolean_parameter( + 'ValueLowerBoundEnabled', False ) - self.probeEnergyInElectronVolts = self._settingsGroup.createRealParameter( - 'ProbeEnergyInElectronVolts', 10000.0, minimum=0.0 + self.value_lower_bound = self._group.create_integer_parameter( + 'ValueLowerBound', 0, minimum=0 ) - self.probePhotonsPerSecond = self._settingsGroup.createRealParameter( - 'ProbePhotonsPerSecond', 0.0, minimum=0.0 + self.is_value_upper_bound_enabled = self._group.create_boolean_parameter( + 'ValueUpperBoundEnabled', False ) - self.exposureTimeInSeconds = self._settingsGroup.createRealParameter( - 'ExposureTimeInSeconds', 0.0, minimum=0.0 + self.value_upper_bound = self._group.create_integer_parameter( + 'ValueUpperBound', 65535, minimum=0 ) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/patterns/sizer.py b/src/ptychodus/model/patterns/sizer.py index 11011c06..00c8d054 100644 --- a/src/ptychodus/model/patterns/sizer.py +++ b/src/ptychodus/model/patterns/sizer.py @@ -1,125 +1,224 @@ -from __future__ import annotations - -from ptychodus.api.geometry import ImageExtent, Interval +from ptychodus.api.geometry import ImageExtent, Interval, PixelGeometry from ptychodus.api.observer import Observable, Observer -from ptychodus.api.patterns import DiffractionPatternArrayType - -from .detector import Detector -from .settings import PatternSettings +from ptychodus.api.parametric import BooleanParameter, IntegerParameter, RealParameter +from ptychodus.api.patterns import CropCenter + +from .processor import ( + DiffractionPatternBinning, + DiffractionPatternCrop, + DiffractionPatternFilterValues, + DiffractionPatternPadding, + DiffractionPatternProcessor, +) +from .settings import DetectorSettings, PatternSettings + + +class PatternAxisSizer(Observable, Observer): + def __init__( + self, + detector_size: IntegerParameter, + detector_pixel_size_m: RealParameter, + crop_enabled: BooleanParameter, + crop_size: IntegerParameter, + crop_center: IntegerParameter, + binning_enabled: BooleanParameter, + bin_size: IntegerParameter, + padding_enabled: BooleanParameter, + pad_size: IntegerParameter, + ) -> None: + super().__init__() + self._detector_size = detector_size + self._detector_pixel_size_m = detector_pixel_size_m + self._crop_enabled = crop_enabled + self._crop_size = crop_size + self._crop_center = crop_center + self._binning_enabled = binning_enabled + self._bin_size = bin_size + self._padding_enabled = padding_enabled + self._pad_size = pad_size + + detector_size.add_observer(self) + detector_pixel_size_m.add_observer(self) + crop_enabled.add_observer(self) + crop_size.add_observer(self) + crop_center.add_observer(self) + binning_enabled.add_observer(self) + bin_size.add_observer(self) + padding_enabled.add_observer(self) + pad_size.add_observer(self) + + def get_detector_size(self) -> int: + return self._detector_size.get_value() + + def get_crop_size_limits(self) -> Interval[int]: + return Interval[int](1, self.get_detector_size()) + + def get_crop_size(self) -> int: + if self._crop_enabled.get_value(): + limits = self.get_crop_size_limits() + return limits.clamp(self._crop_size.get_value()) + + return self.get_detector_size() + + def get_crop_center_limits(self) -> Interval[int]: + xmin = (self.get_crop_size() + 1) // 2 + xmax = self.get_detector_size() - 1 - xmin + return Interval[int](xmin, xmax) + + def get_crop_center(self) -> int: + limits = self.get_crop_center_limits() + return limits.clamp(self._crop_center.get_value()) + + def get_bin_size_limits(self) -> Interval[int]: + return Interval[int](1, self.get_crop_size()) + + def get_bin_size(self) -> int: + if self._binning_enabled.get_value(): + limits = self.get_bin_size_limits() + return limits.clamp(self._bin_size.get_value()) + + return 1 + + def validate_bin_size(self) -> None: + crop_size = self.get_crop_size() + bin_size = self.get_bin_size() + + if crop_size % bin_size != 0: + raise ValueError(f'Invalid binning size! ({crop_size=}, {bin_size=})') + + def get_pad_size(self) -> int: + if self._padding_enabled.get_value(): + return self._pad_size.get_value() + + return 0 + + def get_processed_size(self) -> int: + return self.get_crop_size() // self.get_bin_size() + self.get_pad_size() + + def get_processed_pixel_size_m(self) -> float: + return self.get_bin_size() * self._detector_pixel_size_m.get_value() + + def get_processed_size_m(self) -> float: + return self.get_processed_size() * self.get_processed_pixel_size_m() + + def _update(self, observable: Observable) -> None: + if observable in ( + self._detector_size, + self._detector_pixel_size_m, + self._crop_enabled, + self._crop_size, + self._crop_center, + self._binning_enabled, + self._bin_size, + self._padding_enabled, + self._pad_size, + ): + self.notify_observers() class PatternSizer(Observable, Observer): - def __init__(self, settings: PatternSettings, detector: Detector) -> None: + def __init__( + self, detector_settings: DetectorSettings, pattern_settings: PatternSettings + ) -> None: super().__init__() - self._settings = settings - self._detector = detector - self._sliceX = slice(0) - self._sliceY = slice(0) - - @classmethod - def createInstance(cls, settings: PatternSettings, detector: Detector) -> PatternSizer: - sizer = cls(settings, detector) - sizer._updateSlicesAndNotifyObservers() - settings.addObserver(sizer) - detector.addObserver(sizer) - return sizer - - def isCropEnabled(self) -> bool: - return self._settings.cropEnabled.getValue() - - def setCropEnabled(self, value: bool) -> None: - self._settings.cropEnabled.setValue(value) - - def getWidthLimitsInPixels(self) -> Interval[int]: - return Interval[int](1, self._detector.getImageExtent().widthInPixels) - - def getWidthInPixels(self) -> int: - limitsInPixels = self.getWidthLimitsInPixels() - return ( - limitsInPixels.clamp(self._settings.cropWidthInPixels.getValue()) - if self.isCropEnabled() - else limitsInPixels.upper + self._pattern_settings = pattern_settings + self.axis_x = PatternAxisSizer( + detector_settings.width_px, + detector_settings.pixel_width_m, + pattern_settings.is_crop_enabled, + pattern_settings.crop_width_px, + pattern_settings.crop_center_x_px, + pattern_settings.is_binning_enabled, + pattern_settings.bin_size_x, + pattern_settings.is_padding_enabled, + pattern_settings.pad_x, ) - - def getCenterXLimitsInPixels(self) -> Interval[int]: - return Interval[int](0, self._detector.getImageExtent().widthInPixels) - - def getCenterXInPixels(self) -> int: - limitsInPixels = self.getCenterXLimitsInPixels() - return ( - limitsInPixels.clamp(self._settings.cropCenterXInPixels.getValue()) - if self.isCropEnabled() - else limitsInPixels.midrange + self.axis_y = PatternAxisSizer( + detector_settings.height_px, + detector_settings.pixel_height_m, + pattern_settings.is_crop_enabled, + pattern_settings.crop_height_px, + pattern_settings.crop_center_y_px, + pattern_settings.is_binning_enabled, + pattern_settings.bin_size_y, + pattern_settings.is_padding_enabled, + pattern_settings.pad_y, ) - def _getSafeCenterXInPixels(self) -> int: - lower = self.getWidthInPixels() // 2 - upper = self._detector.getImageExtent().widthInPixels - 1 - lower - limits = Interval[int](lower, upper) - return limits.clamp(self.getCenterXInPixels()) - - def getPixelWidthInMeters(self) -> float: - return self._detector.pixelWidthInMeters.getValue() + self.axis_x.add_observer(self) + self.axis_y.add_observer(self) - def getWidthInMeters(self) -> float: - return self.getWidthInPixels() * self.getPixelWidthInMeters() + def get_processed_width_m(self) -> float: + return self.axis_x.get_processed_size_m() - def getHeightLimitsInPixels(self) -> Interval[int]: - return Interval[int](1, self._detector.getImageExtent().heightInPixels) + def get_processed_height_m(self) -> float: + return self.axis_y.get_processed_size_m() - def getHeightInPixels(self) -> int: - limitsInPixels = self.getHeightLimitsInPixels() - return ( - limitsInPixels.clamp(self._settings.cropHeightInPixels.getValue()) - if self.isCropEnabled() - else limitsInPixels.upper + def get_processed_image_extent(self) -> ImageExtent: + return ImageExtent( + width_px=self.axis_x.get_processed_size(), + height_px=self.axis_y.get_processed_size(), ) - def getCenterYLimitsInPixels(self) -> Interval[int]: - return Interval[int](0, self._detector.getImageExtent().heightInPixels) - - def getCenterYInPixels(self) -> int: - limitsInPixels = self.getCenterYLimitsInPixels() - return ( - limitsInPixels.clamp(self._settings.cropCenterYInPixels.getValue()) - if self.isCropEnabled() - else limitsInPixels.midrange + def get_processed_pixel_geometry(self) -> PixelGeometry: + return PixelGeometry( + width_m=self.axis_x.get_processed_pixel_size_m(), + height_m=self.axis_y.get_processed_pixel_size_m(), ) - def _getSafeCenterYInPixels(self) -> int: - lower = self.getHeightInPixels() // 2 - upper = self._detector.getImageExtent().heightInPixels - 1 - lower - limits = Interval[int](lower, upper) - return limits.clamp(self.getCenterYInPixels()) + def get_processor(self) -> DiffractionPatternProcessor: + value_lower_bound: int | None = None + value_upper_bound: int | None = None + crop: DiffractionPatternCrop | None = None + binning: DiffractionPatternBinning | None = None + padding: DiffractionPatternPadding | None = None - def getPixelHeightInMeters(self) -> float: - return self._detector.pixelHeightInMeters.getValue() + if self._pattern_settings.is_value_upper_bound_enabled.get_value(): + value_lower_bound = self._pattern_settings.value_lower_bound.get_value() - def getHeightInMeters(self) -> float: - return self.getHeightInPixels() * self.getPixelHeightInMeters() + if self._pattern_settings.is_value_upper_bound_enabled.get_value(): + value_upper_bound = self._pattern_settings.value_upper_bound.get_value() - def getImageExtent(self) -> ImageExtent: - return ImageExtent( - widthInPixels=self.getWidthInPixels(), - heightInPixels=self.getHeightInPixels(), + filter_values = DiffractionPatternFilterValues( + lower_bound=value_lower_bound, + upper_bound=value_upper_bound, ) - def __call__(self, data: DiffractionPatternArrayType) -> DiffractionPatternArrayType: - return data[:, self._sliceY, self._sliceX] if self.isCropEnabled() else data - - def _updateSlicesAndNotifyObservers(self) -> None: - centerXInPixels = self._getSafeCenterXInPixels() - radiusXInPixels = self.getWidthInPixels() // 2 - self._sliceX = slice(centerXInPixels - radiusXInPixels, centerXInPixels + radiusXInPixels) - - centerYInPixels = self._getSafeCenterYInPixels() - radiusYInPixels = self.getHeightInPixels() // 2 - self._sliceY = slice(centerYInPixels - radiusYInPixels, centerYInPixels + radiusYInPixels) - - self.notifyObservers() + if self._pattern_settings.is_crop_enabled.get_value(): + crop = DiffractionPatternCrop( + center=CropCenter( + self.axis_x.get_crop_center(), + self.axis_y.get_crop_center(), + ), + extent=ImageExtent( + self.axis_x.get_crop_size(), + self.axis_y.get_crop_size(), + ), + ) + + if self._pattern_settings.is_binning_enabled.get_value(): + self.axis_x.validate_bin_size() + self.axis_y.validate_bin_size() + binning = DiffractionPatternBinning( + bin_size_x=self.axis_x.get_bin_size(), + bin_size_y=self.axis_y.get_bin_size(), + ) + + if self._pattern_settings.is_padding_enabled.get_value(): + padding = DiffractionPatternPadding( + pad_x=self.axis_x.get_pad_size(), + pad_y=self.axis_y.get_pad_size(), + ) + + return DiffractionPatternProcessor( + filter_values=filter_values, + crop=crop, + binning=binning, + padding=padding, + flip_x=self._pattern_settings.is_flip_x_enabled.get_value(), + flip_y=self._pattern_settings.is_flip_y_enabled.get_value(), + ) - def update(self, observable: Observable) -> None: - if observable is self._settings: - self._updateSlicesAndNotifyObservers() - elif observable is self._detector: - self._updateSlicesAndNotifyObservers() + def _update(self, observable: Observable) -> None: + if observable in (self.axis_x, self.axis_y): + self.notify_observers() diff --git a/src/ptychodus/model/phase_unwrapper.py b/src/ptychodus/model/phase_unwrapper.py new file mode 100644 index 00000000..32969e3d --- /dev/null +++ b/src/ptychodus/model/phase_unwrapper.py @@ -0,0 +1,463 @@ +# mypy: ignore-errors + +import numpy +from numpy.typing import NDArray +from scipy import signal, ndimage +from typing import Literal, Optional, Tuple + + +class PhaseUnwrapper: + def __init__( + self, + fourier_shift_step: float = 0.5, + image_grad_method: Literal[ + 'fourier_shift', 'fourier_differentiation', 'nearest' + ] = 'fourier_differentiation', + image_integration_method: Literal['fourier', 'discrete', 'deconvolution'] = 'fourier', + weight_map: Optional[NDArray] = None, + eps: float = 1e-9, + ) -> None: + """Get the unwrapped phase of a complex 2D image. + + Parameters + ---------- + fourier_shift_step : float + The finite-difference step size used to calculate the gradient, + if the Fourier shift method is used. + image_grad_method : str + The method used to calculate the phase gradient. + - "fourier_shift": Use Fourier shift to perform shift. + - "nearest": Use nearest neighbor to perform shift. + - "fourier_differentiation": Use Fourier differentiation. + image_integration_method : str + The method used to integrate the image back from gradients. + - "fourier": Use Fourier integration as implemented in PtychoShelves. + - "deconvolution": Deconvolve ramp filter. + - "discrete": Use cumulative sum. + weight_map : Optional[NDArray] + A weight map multiplied to the input image. + eps : float + A small number to avoid division by zero. + """ + self.fourier_shift_step = fourier_shift_step + self.image_grad_method = image_grad_method + self.image_integration_method = image_integration_method + self.weight_map = weight_map + self.eps = eps + + def unwrap(self, img: NDArray) -> NDArray: + """Run unwrapping. + + Parameters + ---------- + img : NDArray + A 2D complex array giving the image to be unwrapped. + + Returns + ------- + NDArray + A 2D real array giving the unwrapped phase of the input image. + """ + if not numpy.iscomplexobj(img): + raise ValueError('Input array must be complex.') + + if self.weight_map is not None: + weight_map = float(numpy.clip(self.weight_map, 0.0, 1.0)) + else: + weight_map = 1.0 + + img = weight_map * img / (numpy.abs(img) + self.eps) + bc_center = numpy.angle(img[img.shape[0] // 2, img.shape[1] // 2]) + + # Pad image to avoid FFT boundary artifacts. + padding = [64, 64] + if any(numpy.array(padding) > 0): + img = numpy.pad( + img, ((padding[0], padding[0]), (padding[1], padding[1])), mode='reflect' + ) + img = vignett(img, margin=10, sigma=2.5) + + gy, gx = get_phase_gradient( + img, + fourier_shift_step=self.fourier_shift_step, + image_grad_method=self.image_grad_method, + ) + + if self.image_integration_method == 'discrete' and any(numpy.array(padding) > 0): + gy = gy[padding[0] : -padding[0], padding[1] : -padding[1]] + gx = gx[padding[0] : -padding[0], padding[1] : -padding[1]] + if self.image_integration_method == 'discrete': + phase = numpy.real(integrate_image_2d(gy, gx, bc_center=bc_center)) + elif self.image_integration_method == 'fourier': + phase = numpy.real(integrate_image_2d_fourier(gy, gx)) + elif self.image_integration_method == 'deconvolution': + phase = numpy.real(integrate_image_2d_deconvolution(gy, gx, bc_center=bc_center)) + else: + raise ValueError(f'Unknown integration method: {self.image_integration_method}') + + if self.image_integration_method != 'discrete' and any(numpy.array(padding) > 0): + gy = gy[padding[0] : -padding[0], padding[1] : -padding[1]] + gx = gx[padding[0] : -padding[0], padding[1] : -padding[1]] + phase = phase[padding[0] : -padding[0], padding[1] : -padding[1]] + + return phase + + +def vignett(img: NDArray, margin: int = 20, sigma: float = 1.0) -> NDArray: + """Vignett an image so that it gradually decays near the boundary. + For each dimension of the image, a mask with a width of `2 * margin` + and with half of it filled with 0s and half with 1s is + generated and convolved with a Gaussian kernel of size + `margin` and standard deviation `sigma`. The blurred mask is cropped and + multiplied to the near-edge regions of the image. + + Parameters + ---------- + img : Tensor + The input image. + margin : int + The margin of image where the decay takes place. + sigma : float + The standard deviation of the Gaussian kernel. + """ + img = img.copy() + for i_dim in range(img.ndim): + if img.shape[i_dim] <= 2 * margin: + continue + + mask_shape = ( + [img.shape[i] for i in range(i_dim)] + + [2 * margin] + + [img.shape[i] for i in range(i_dim + 1, img.ndim)] + ) + mask = numpy.zeros(mask_shape) + mask_slicer = [slice(None)] * i_dim + [slice(margin, None)] + mask[tuple(mask_slicer)] = 1.0 + + gauss_win = signal.windows.gaussian(margin // 2, std=sigma) + gauss_win = gauss_win / numpy.sum(gauss_win) + mask = ndimage.convolve1d(mask, gauss_win, axis=i_dim, mode='constant') + mask_final_slicer = [slice(None)] * i_dim + [slice(len(gauss_win), len(gauss_win) + margin)] + + mask = mask[tuple(mask_final_slicer)] + + mask = numpy.where(mask < 1e-3, 0, mask) + + slicer = tuple([slice(None)] * i_dim + [slice(0, margin)]) + img[slicer] = img[slicer] * mask + + slicer = tuple([slice(None)] * i_dim + [slice(-margin, None)]) + img[slicer] = img[slicer] * numpy.flip(mask, axis=i_dim) + return img + + +def nearest_neighbor_gradient( + image: NDArray, direction: Literal['forward', 'backward'], dim: Tuple[int, ...] = (0, 1) +) -> Tuple[NDArray, NDArray]: + """ + Calculate the nearest neighbor gradient of a 2D image. + + Parameters + ---------- + image : NDArray + a (... H, W) tensor of images. + direction : str + 'forward' or 'backward'. + dim : tuple of int, optional + Dimensions to calculate gradient. Default is (0, 1). + + Returns + ------- + tuple of NDArray + a tuple of 2 images with the gradient in y and x directions. + """ + if not hasattr(dim, '__len__'): + dim = (dim,) + grad_x = None + grad_y = None + if direction == 'forward': + if 1 in dim: + grad_x = numpy.concatenate([image[:, 1:], image[:, -1:]], axis=1) - image + if 0 in dim: + grad_y = numpy.concatenate([image[1:, :], image[-1:, :]], axis=0) - image + elif direction == 'backward': + if 1 in dim: + grad_x = image - numpy.concatenate([image[:, :1], image[:, :-1]], axis=1) + if 0 in dim: + grad_y = image - numpy.concatenate([image[:1, :], image[:-1, :]], axis=0) + else: + raise ValueError("direction must be 'forward' or 'backward'") + return grad_y, grad_x + + +def gaussian_gradient(image: NDArray, sigma: float = 1.0, kernel_size=5) -> Tuple[NDArray, NDArray]: + """ + Calculate the gradient of a 2D image with a Gaussian-derivative kernel. + + Parameters + ---------- + image : NDArray + A (... H, W) tensor of images. + sigma : float + Sigma of the Gaussian. + + Returns + ------- + tuple of NDArray + A tuple of 2 images with the gradient in y and x directions. + """ + r = numpy.arange(kernel_size) - (kernel_size - 1) / 2.0 + kernel = -r / (numpy.sqrt(2 * numpy.pi) * sigma**3) * numpy.exp(-(r**2) / (2 * sigma**2)) + grad_y = ndimage.convolve(image, kernel.reshape(-1, 1), mode='nearest') + grad_x = ndimage.convolve(image, kernel.reshape(1, -1), mode='nearest') + + # Gate the gradients + grads = [grad_y, grad_x] + for i, g in enumerate(grads): + m = numpy.logical_and(numpy.abs(grad_y) < 1e-6, numpy.abs(grad_y) != 0) + if numpy.count_nonzero(m) > 0: + print('Gradient magnitudes between 0 and 1e-6 are set to 0.') + g = g * numpy.logical_not(m) + grads[i] = g + grad_y, grad_x = grads + return grad_y, grad_x + + +def fourier_gradient(image: NDArray) -> Tuple[NDArray, NDArray]: + """Calculate gradient using NumPy FFT operations""" + u = numpy.fft.fftfreq(image.shape[0]) + v = numpy.fft.fftfreq(image.shape[1]) + u, v = numpy.meshgrid(u, v, indexing='ij') + + grad_y = numpy.fft.ifft(numpy.fft.fft(image, axis=-2) * (2j * numpy.pi * u), axis=-2) + grad_x = numpy.fft.ifft(numpy.fft.fft(image, axis=-1) * (2j * numpy.pi * v), axis=-1) + + return grad_y, grad_x + + +def get_phase_gradient( + img: NDArray, + fourier_shift_step: float = 0, + image_grad_method: Literal[ + 'fourier_shift', 'fourier_differentiation', 'nearest' + ] = 'fourier_shift', + eps: float = 1e-6, +) -> Tuple[NDArray, NDArray]: + """ + Get the gradient of the phase of a complex 2D image by first calculating + the spatial gradient of the complex image, then taking the phase of the + complex gradient -- i.e., it takes the phase of the gradient rather than + the gradient of the phase. This avoids the sharp gradients due to phase + wrapping when directly taking the gradient of the phase. + + Parameters + ---------- + img : NDArray + A [N, H, W] or [H, W] tensor giving a batch of images or a single image. + step : float + The finite-difference step size used to calculate the gradient, if + the Fourier shift method is used. + finite_diff_method : enums.ImageGradientMethods + The method used to calculate the phase gradient. + - "fourier_shift": Use Fourier shift to perform shift. + - "nearest": Use nearest neighbor to perform shift. + - "fourier_differentiation": Use Fourier differentiation. + eps : float + A stablizing constant. + + Returns + ------- + Tuple[NDArray, NDArray] + A tuple of 2 images with the gradient in y and x directions. + """ + if fourier_shift_step <= 0 and image_grad_method == 'fourier_shift': + raise ValueError('Step must be positive.') + + if image_grad_method == 'fourier_differentiation': + gy, gx = fourier_gradient(img) + gy = numpy.imag(numpy.conj(img) * gy) + gx = numpy.imag(numpy.conj(img) * gx) + else: + # Use finite difference. + if img.ndim == 2: + img = img[None, ...] + pad = int(numpy.ceil(fourier_shift_step)) + 1 + img = numpy.pad(img, ((0, 0), (pad, pad), (pad, pad)), mode='reflect') + + sy1 = numpy.array([[-fourier_shift_step, 0]]).repeat(img.shape[0], axis=0) + sy2 = numpy.array([[fourier_shift_step, 0]]).repeat(img.shape[0], axis=0) + if image_grad_method == 'fourier_shift': + # If the image contains zero-valued pixels, Fourier shift can result in small + # non-zero values that dangles around 0. This can cause the phase + # of the shifted image to dangle between pi and -pi. In that case, use + # `finite_diff_method="nearest" instead`, or use `step=1`. + complex_prod = fourier_shift(img, sy1) * fourier_shift(img, sy2).conj() + elif image_grad_method == 'nearest': + complex_prod = img * numpy.concatenate([img[:, :1, :], img[:, :-1, :]], axis=1).conj() + else: + raise ValueError(f'Unknown finite-difference method: {image_grad_method}') + complex_prod = numpy.where( + numpy.abs(complex_prod) < numpy.abs(complex_prod).max() * 1e-6, 0, complex_prod + ) + gy = numpy.angle(complex_prod) / (2 * fourier_shift_step) + gy = gy[0, pad:-pad, pad:-pad] + + sx1 = numpy.array([[0, -fourier_shift_step]]).repeat(img.shape[0], axis=0) + sx2 = numpy.array([[0, fourier_shift_step]]).repeat(img.shape[0], axis=0) + if image_grad_method == 'fourier_shift': + complex_prod = fourier_shift(img, sx1) * fourier_shift(img, sx2).conj() + elif image_grad_method == 'nearest': + complex_prod = img * numpy.concatenate([img[:, :, :1], img[:, :, :-1]], axis=2).conj() + complex_prod = numpy.where( + numpy.abs(complex_prod) < numpy.abs(complex_prod).max() * 1e-6, 0, complex_prod + ) + gx = numpy.angle(complex_prod) / (2 * fourier_shift_step) + gx = gx[0, pad:-pad, pad:-pad] + return gy, gx + + +def integrate_image_2d_fourier(grad_y: NDArray, grad_x: NDArray) -> NDArray: + """ + Integrate an image with the gradient in y and x directions using Fourier + differentiation. + + Parameters + ---------- + grad_y, grad_x: NDArray + A (H, W) tensor of gradients in y or x directions. + + Returns + ------- + NDArray + The integrated image. + """ + shape = grad_y.shape + f = numpy.fft.fft2(grad_x + 1j * grad_y) + y, x = numpy.fft.fftfreq(shape[0]), numpy.fft.fftfreq(shape[1]) + + r = 1.0 + r = r / (2j * numpy.pi * (x + 1j * y[:, None]) + 1e-15) + r[0, 0] = 0 + integrated_image = f * r + integrated_image = numpy.fft.ifft2(integrated_image) + if not numpy.iscomplexobj(grad_x): + integrated_image = integrated_image.real + return integrated_image + + +def integrate_image_2d_deconvolution( + grad_y: NDArray, + grad_x: NDArray, + tf_y: Optional[NDArray] = None, + tf_x: Optional[NDArray] = None, + bc_center: float = 0, +) -> NDArray: + """ + Integrate an image with the gradient in y and x directions by deconvolving + the differentiation kernel, whose transfer function is assumed to be a + ramp function. + + Adapted from Tripathi, A., McNulty, I., Munson, T., & Wild, S. M. (2016). + Single-view phase retrieval of an extended sample by exploiting edge detection + and sparsity. Optics Express, 24(21), 24719–24738. doi:10.1364/OE.24.024719 + + Parameters + ---------- + grad_y, grad_x: NDArray + A (H, W) tensor of gradients in y or x directions. + tf_y, tf_x: NDArray + A (H, W) tensor of transfer functions in y or x directions. If not + provided, they are assumed to be 2i * pi * u (or v), which are the + effective transfer functions in Fourier differentiation. + bc_center: float + The value of the boundary condition at the center of the image. + + Returns + ------- + NDArray + The integrated image. + """ + u, v = numpy.fft.fftfreq(grad_x.shape[0]), numpy.fft.fftfreq(grad_x.shape[1]) + u, v = numpy.meshgrid(u, v, indexing='ij') + if tf_y is None or tf_x is None: + tf_y = 2j * numpy.pi * u + tf_x = 2j * numpy.pi * v + f_grad_y = numpy.fft.fft2(grad_y) + f_grad_x = numpy.fft.fft2(grad_x) + img = (f_grad_y * tf_y + f_grad_x * tf_x) / (numpy.abs(tf_y) ** 2 + numpy.abs(tf_x) ** 2 + 1e-5) + img = -numpy.fft.ifft2(img) + img = img + bc_center - img[img.shape[0] // 2, img.shape[1] // 2] + return img + + +def integrate_image_2d(grad_y: NDArray, grad_x: NDArray, bc_center: float = 0) -> NDArray: + """ + Integrate an image with the gradient in y and x directions. + + Parameters + ---------- + grad_y : NDArray + The gradient in y direction. + grad_x : NDArray + The gradient in x direction. + bc_center : float + The boundary condition at the center of the image, by default 0 + + Returns + ------- + NDArray + The integrated image. + """ + left_boundary = numpy.cumsum(grad_y[:, 0], axis=0) + int_img = numpy.cumsum(grad_x, axis=1) + left_boundary[:, None] + int_img = int_img + bc_center - int_img[int_img.shape[0] // 2, int_img.shape[1] // 2] + return int_img + + +def fourier_shift( + images: NDArray, shifts: NDArray, strictly_preserve_zeros: bool = False +) -> NDArray: + """ + Apply Fourier shift to a batch of images. + + Parameters + ---------- + images : NDArray + A [N, H, W] array of images. + shifts : NDArray + A [N, 2] array of shifts in pixels. + strictly_preserve_zeros : bool + If True, mask of strictly zero pixels will be generated and shifted + by the same amount. Pixels that have a non-zero value in the shifted + mask will be set to zero in the shifted image. This preserves the zero + pixels in the original image, preventing FFT from introducing small + non-zero values due to machine precision. + + Returns + ------- + NDArray + Shifted images. + """ + if strictly_preserve_zeros: + zero_mask = images == 0 + zero_mask = zero_mask.float() + zero_mask_shifted = fourier_shift(zero_mask, shifts, strictly_preserve_zeros=False) + ft_images = numpy.fft.fft2(images) + freq_y, freq_x = numpy.meshgrid( + numpy.fft.fftfreq(images.shape[-2]), numpy.fft.fftfreq(images.shape[-1]), indexing='ij' + ) + freq_x = freq_x.repeat(images.shape[0], axis=0) + freq_y = freq_y.repeat(images.shape[0], axis=0) + mult = numpy.exp( + 1j + * -2 + * numpy.pi + * (freq_x * shifts[:, 1].reshape([-1, 1, 1]) + freq_y * shifts[:, 0].reshape([-1, 1, 1])) + ) + ft_images = ft_images * mult + shifted_images = numpy.fft.ifft2(ft_images) + if not numpy.iscomplexobj(images): + shifted_images = shifted_images.real + if strictly_preserve_zeros: + shifted_images[zero_mask_shifted > 0] = 0 + return shifted_images diff --git a/src/ptychodus/model/product/__init__.py b/src/ptychodus/model/product/__init__.py index ebc3ae1b..350e65c9 100644 --- a/src/ptychodus/model/product/__init__.py +++ b/src/ptychodus/model/product/__init__.py @@ -1,14 +1,16 @@ -from .api import ObjectAPI, ProbeAPI, ProductAPI, ScanAPI +from .api import ObjectAPI, ProbeAPI, ProductAPI, ScanAPI, PositionsStreamingContext from .core import ProductCore from .item import ProductRepositoryItem, ProductRepositoryObserver -from .objectRepository import ObjectRepository -from .probeRepository import ProbeRepository -from .productRepository import ProductRepository -from .scanRepository import ScanRepository +from .object_repository import ObjectRepository +from .probe_repository import ProbeRepository +from .repository import ProductRepository +from .scan_repository import ScanRepository +from .settings import ProductSettings __all__ = [ 'ObjectAPI', 'ObjectRepository', + 'PositionsStreamingContext', 'ProbeAPI', 'ProbeRepository', 'ProductAPI', @@ -16,6 +18,7 @@ 'ProductRepository', 'ProductRepositoryItem', 'ProductRepositoryObserver', + 'ProductSettings', 'ScanAPI', 'ScanRepository', ] diff --git a/src/ptychodus/model/product/api.py b/src/ptychodus/model/product/api.py index ee7aefe5..36cd16e1 100644 --- a/src/ptychodus/model/product/api.py +++ b/src/ptychodus/model/product/api.py @@ -4,39 +4,69 @@ import logging from ptychodus.api.plugins import PluginChooser -from ptychodus.api.product import ProductFileReader, ProductFileWriter +from ptychodus.api.product import Product, ProductFileReader, ProductFileWriter -from ..patterns import ProductSettings -from .object.builderFactory import ObjectBuilderFactory +from .item import ProductRepositoryItem +from .item_factory import ProductRepositoryItemFactory +from .object.builder_factory import ObjectBuilderFactory from .object.settings import ObjectSettings -from .objectRepository import ObjectRepository -from .probe.builderFactory import ProbeBuilderFactory +from .object_repository import ObjectRepository +from .probe.builder_factory import ProbeBuilderFactory from .probe.settings import ProbeSettings -from .probeRepository import ProbeRepository -from .productRepository import ProductRepository -from .scan.builderFactory import ScanBuilderFactory +from .probe_repository import ProbeRepository +from .repository import ProductRepository +from .scan.builder_factory import ScanBuilderFactory from .scan.settings import ScanSettings -from .scanRepository import ScanRepository +from .scan_repository import ScanRepository +from .settings import ProductSettings logger = logging.getLogger(__name__) +class PositionsStreamingContext: + def __init__(self) -> None: + self._positions_x_m: list[float] = [] + self._triggers_x: list[int] = [] + self._positions_y_m: list[float] = [] + self._triggers_y: list[int] = [] + + def start(self) -> None: + self._positions_x_m.clear() + self._triggers_x.clear() + self._positions_y_m.clear() + self._triggers_y.clear() + + def append_positions_x(self, values_m: Sequence[float], trigger_counts: Sequence[int]) -> None: + self._positions_x_m.extend(values_m) + self._triggers_x.extend(trigger_counts) + + def append_positions_y(self, values_m: Sequence[float], trigger_counts: Sequence[int]) -> None: + self._positions_y_m.extend(values_m) + self._triggers_y.extend(trigger_counts) + + def stop(self) -> None: + pass # TODO + + class ScanAPI: def __init__( self, settings: ScanSettings, repository: ScanRepository, - builderFactory: ScanBuilderFactory, + builder_factory: ScanBuilderFactory, ) -> None: self._settings = settings self._repository = repository - self._builderFactory = builderFactory + self._builder_factory = builder_factory - def builderNames(self) -> Iterator[str]: - return iter(self._builderFactory) + def create_streaming_context(self) -> PositionsStreamingContext: + return PositionsStreamingContext() - def buildScan( - self, index: int, builderName: str, builderParameters: Mapping[str, Any] = {} + def builder_names(self) -> Iterator[str]: + return iter(self._builder_factory) + + def build_scan( + self, index: int, builder_name: str, builder_parameters: Mapping[str, Any] = {} ) -> None: try: item = self._repository[index] @@ -45,25 +75,24 @@ def buildScan( return try: - builder = self._builderFactory.create(builderName) + builder = self._builder_factory.create(builder_name) except KeyError: - logger.warning(f'Failed to create builder {builderName}!') + logger.warning(f'Failed to create builder {builder_name}!') return - for parameterName, parameterValue in builderParameters.items(): + for parameter_name, parameter_value in builder_parameters.items(): try: - parameter = builder.parameters()[parameterName] + parameter = builder.parameters()[parameter_name] except KeyError: logger.warning( - f'Scan builder "{builder.getName()}" does not have' - f' parameter "{parameterName}"!' + f'Scan builder "{builder.get_name()}" does not have parameter "{parameter_name}"!' ) else: - parameter.setValue(parameterValue) + parameter.set_value(parameter_value) - item.setBuilder(builder) + item.set_builder(builder) - def buildScanFromSettings(self, index: int) -> None: + def build_scan_from_settings(self, index: int) -> None: try: item = self._repository[index] except IndexError: @@ -71,23 +100,23 @@ def buildScanFromSettings(self, index: int) -> None: return try: - builder = self._builderFactory.createFromSettings() + builder = self._builder_factory.create_from_settings() except KeyError: logger.warning('Failed to create builder from settings!') return - item.setBuilder(builder) + item.set_builder(builder) - def getOpenFileFilterList(self) -> Sequence[str]: - return self._builderFactory.getOpenFileFilterList() + def get_open_file_filters(self) -> Iterator[str]: + return self._builder_factory.get_open_file_filters() - def getOpenFileFilter(self) -> str: - return self._builderFactory.getOpenFileFilter() + def get_open_file_filter(self) -> str: + return self._builder_factory.get_open_file_filter() - def openScan(self, index: int, filePath: Path, *, fileType: str | None = None) -> None: - builder = self._builderFactory.createScanFromFile( - filePath, - self._settings.fileType.getValue() if fileType is None else fileType, + def open_scan(self, index: int, file_path: Path, *, file_type: str | None = None) -> None: + builder = self._builder_factory.create_scan_from_file( + file_path, + self._settings.file_type.get_value() if file_type is None else file_type, ) try: @@ -95,38 +124,38 @@ def openScan(self, index: int, filePath: Path, *, fileType: str | None = None) - except IndexError: logger.warning(f'Failed to open scan {index}!') else: - item.setBuilder(builder) + item.set_builder(builder) - def copyScan(self, sourceIndex: int, destinationIndex: int) -> None: - logger.debug(f'Copying {sourceIndex} -> {destinationIndex}') + def copy_scan(self, source_index: int, destination_index: int) -> None: + logger.debug(f'Copying {source_index} -> {destination_index}') try: - sourceItem = self._repository[sourceIndex] + source_item = self._repository[source_index] except IndexError: - logger.warning(f'Failed to access source scan {sourceIndex} for copying!') + logger.warning(f'Failed to access source scan {source_index} for copying!') return try: - destinationItem = self._repository[destinationIndex] + destination_item = self._repository[destination_index] except IndexError: - logger.warning(f'Failed to access destination scan {destinationIndex} for copying!') + logger.warning(f'Failed to access destination scan {destination_index} for copying!') return - destinationItem.assignItem(sourceItem) + destination_item.assign_item(source_item) - def getSaveFileFilterList(self) -> Sequence[str]: - return self._builderFactory.getSaveFileFilterList() + def get_save_file_filters(self) -> Iterator[str]: + return self._builder_factory.get_save_file_filters() - def getSaveFileFilter(self) -> str: - return self._builderFactory.getSaveFileFilter() + def get_save_file_filter(self) -> str: + return self._builder_factory.get_save_file_filter() - def saveScan(self, index: int, filePath: Path, fileType: str) -> None: + def save_scan(self, index: int, file_path: Path, file_type: str) -> None: try: item = self._repository[index] except IndexError: logger.warning(f'Failed to save scan {index}!') else: - self._builderFactory.saveScan(filePath, fileType, item.getScan()) + self._builder_factory.save_scan(file_path, file_type, item.get_scan()) class ProbeAPI: @@ -134,17 +163,17 @@ def __init__( self, settings: ProbeSettings, repository: ProbeRepository, - builderFactory: ProbeBuilderFactory, + builder_factory: ProbeBuilderFactory, ) -> None: self._settings = settings self._repository = repository - self._builderFactory = builderFactory + self._builder_factory = builder_factory - def builderNames(self) -> Iterator[str]: - return iter(self._builderFactory) + def builder_names(self) -> Iterator[str]: + return iter(self._builder_factory) - def buildProbe( - self, index: int, builderName: str, builderParameters: Mapping[str, Any] = {} + def build_probe( + self, index: int, builder_name: str, builder_parameters: Mapping[str, Any] = {} ) -> None: try: item = self._repository[index] @@ -153,25 +182,25 @@ def buildProbe( return try: - builder = self._builderFactory.create(builderName) + builder = self._builder_factory.create(builder_name) except KeyError: - logger.warning(f'Failed to create builder {builderName}!') + logger.warning(f'Failed to create builder {builder_name}!') return - for parameterName, parameterValue in builderParameters.items(): + for parameter_name, parameter_value in builder_parameters.items(): try: - parameter = builder.parameters()[parameterName] + parameter = builder.parameters()[parameter_name] except KeyError: logger.warning( - f'Probe builder "{builder.getName()}" does not have' - f' parameter "{parameterName}"!' + f'Probe builder "{builder.get_name()}" does not have' + f' parameter "{parameter_name}"!' ) else: - parameter.setValue(parameterValue) + parameter.set_value(parameter_value) - item.setBuilder(builder) + item.set_builder(builder) - def buildProbeFromSettings(self, index: int) -> None: + def build_probe_from_settings(self, index: int) -> None: try: item = self._repository[index] except IndexError: @@ -179,23 +208,23 @@ def buildProbeFromSettings(self, index: int) -> None: return try: - builder = self._builderFactory.createFromSettings() + builder = self._builder_factory.create_from_settings() except KeyError: logger.warning('Failed to create builder from settings!') return - item.setBuilder(builder) + item.set_builder(builder) - def getOpenFileFilterList(self) -> Sequence[str]: - return self._builderFactory.getOpenFileFilterList() + def get_open_file_filters(self) -> Iterator[str]: + return self._builder_factory.get_open_file_filters() - def getOpenFileFilter(self) -> str: - return self._builderFactory.getOpenFileFilter() + def get_open_file_filter(self) -> str: + return self._builder_factory.get_open_file_filter() - def openProbe(self, index: int, filePath: Path, *, fileType: str | None = None) -> None: - builder = self._builderFactory.createProbeFromFile( - filePath, - self._settings.fileType.getValue() if fileType is None else fileType, + def open_probe(self, index: int, file_path: Path, *, file_type: str | None = None) -> None: + builder = self._builder_factory.create_probe_from_file( + file_path, + self._settings.file_type.get_value() if file_type is None else file_type, ) try: @@ -203,38 +232,38 @@ def openProbe(self, index: int, filePath: Path, *, fileType: str | None = None) except IndexError: logger.warning(f'Failed to open probe {index}!') else: - item.setBuilder(builder) + item.set_builder(builder) - def copyProbe(self, sourceIndex: int, destinationIndex: int) -> None: - logger.debug(f'Copying {sourceIndex} -> {destinationIndex}') + def copy_probe(self, source_index: int, destination_index: int) -> None: + logger.debug(f'Copying {source_index} -> {destination_index}') try: - sourceItem = self._repository[sourceIndex] + source_item = self._repository[source_index] except IndexError: - logger.warning(f'Failed to access source probe {sourceIndex} for copying!') + logger.warning(f'Failed to access source probe {source_index} for copying!') return try: - destinationItem = self._repository[destinationIndex] + destination_item = self._repository[destination_index] except IndexError: - logger.warning(f'Failed to access destination probe {destinationIndex} for copying!') + logger.warning(f'Failed to access destination probe {destination_index} for copying!') return - destinationItem.assignItem(sourceItem) + destination_item.assign_item(source_item) - def getSaveFileFilterList(self) -> Sequence[str]: - return self._builderFactory.getSaveFileFilterList() + def get_save_file_filters(self) -> Iterator[str]: + return self._builder_factory.get_save_file_filters() - def getSaveFileFilter(self) -> str: - return self._builderFactory.getSaveFileFilter() + def get_save_file_filter(self) -> str: + return self._builder_factory.get_save_file_filter() - def saveProbe(self, index: int, filePath: Path, fileType: str) -> None: + def save_probe(self, index: int, file_path: Path, file_type: str) -> None: try: item = self._repository[index] except IndexError: logger.warning(f'Failed to save probe {index}!') else: - self._builderFactory.saveProbe(filePath, fileType, item.getProbe()) + self._builder_factory.save_probe(file_path, file_type, item.get_probes()) class ObjectAPI: @@ -242,17 +271,17 @@ def __init__( self, settings: ObjectSettings, repository: ObjectRepository, - builderFactory: ObjectBuilderFactory, + builder_factory: ObjectBuilderFactory, ) -> None: self._settings = settings self._repository = repository - self._builderFactory = builderFactory + self._builder_factory = builder_factory - def builderNames(self) -> Iterator[str]: - return iter(self._builderFactory) + def builder_names(self) -> Iterator[str]: + return iter(self._builder_factory) - def buildObject( - self, index: int, builderName: str, builderParameters: Mapping[str, Any] = {} + def build_object( + self, index: int, builder_name: str, builder_parameters: Mapping[str, Any] = {} ) -> None: try: item = self._repository[index] @@ -261,25 +290,25 @@ def buildObject( return try: - builder = self._builderFactory.create(builderName) + builder = self._builder_factory.create(builder_name) except KeyError: - logger.warning(f'Failed to create builder {builderName}!') + logger.warning(f'Failed to create builder {builder_name}!') return - for parameterName, parameterValue in builderParameters.items(): + for parameter_name, parameter_value in builder_parameters.items(): try: - parameter = builder.parameters()[parameterName] + parameter = builder.parameters()[parameter_name] except KeyError: logger.warning( - f'Object builder "{builder.getName()}" does not have' - f' parameter "{parameterName}"!' + f'Object builder "{builder.get_name()}" does not have' + f' parameter "{parameter_name}"!' ) else: - parameter.setValue(parameterValue) + parameter.set_value(parameter_value) - item.setBuilder(builder) + item.set_builder(builder) - def buildObjectFromSettings(self, index: int) -> None: + def build_object_from_settings(self, index: int) -> None: try: item = self._repository[index] except IndexError: @@ -287,23 +316,23 @@ def buildObjectFromSettings(self, index: int) -> None: return try: - builder = self._builderFactory.createFromSettings() + builder = self._builder_factory.create_from_settings() except KeyError: logger.warning('Failed to create builder from settings!') return - item.setBuilder(builder) + item.set_builder(builder) - def getOpenFileFilterList(self) -> Sequence[str]: - return self._builderFactory.getOpenFileFilterList() + def get_open_file_filters(self) -> Iterator[str]: + return self._builder_factory.get_open_file_filters() - def getOpenFileFilter(self) -> str: - return self._builderFactory.getOpenFileFilter() + def get_open_file_filter(self) -> str: + return self._builder_factory.get_open_file_filter() - def openObject(self, index: int, filePath: Path, *, fileType: str | None = None) -> None: - builder = self._builderFactory.createObjectFromFile( - filePath, - self._settings.fileType.getValue() if fileType is None else fileType, + def open_object(self, index: int, file_path: Path, *, file_type: str | None = None) -> None: + builder = self._builder_factory.create_object_from_file( + file_path, + self._settings.file_type.get_value() if file_type is None else file_type, ) try: @@ -311,38 +340,38 @@ def openObject(self, index: int, filePath: Path, *, fileType: str | None = None) except IndexError: logger.warning(f'Failed to open object {index}!') else: - item.setBuilder(builder) + item.set_builder(builder) - def copyObject(self, sourceIndex: int, destinationIndex: int) -> None: - logger.debug(f'Copying {sourceIndex} -> {destinationIndex}') + def copy_object(self, source_index: int, destination_index: int) -> None: + logger.debug(f'Copying {source_index} -> {destination_index}') try: - sourceItem = self._repository[sourceIndex] + source_item = self._repository[source_index] except IndexError: - logger.warning(f'Failed to access source object {sourceIndex} for copying!') + logger.warning(f'Failed to access source object {source_index} for copying!') return try: - destinationItem = self._repository[destinationIndex] + destination_item = self._repository[destination_index] except IndexError: - logger.warning(f'Failed to access destination object {destinationIndex} for copying!') + logger.warning(f'Failed to access destination object {destination_index} for copying!') return - destinationItem.assignItem(sourceItem) + destination_item.assign_item(source_item) - def getSaveFileFilterList(self) -> Sequence[str]: - return self._builderFactory.getSaveFileFilterList() + def get_save_file_filters(self) -> Iterator[str]: + return self._builder_factory.get_save_file_filters() - def getSaveFileFilter(self) -> str: - return self._builderFactory.getSaveFileFilter() + def get_save_file_filter(self) -> str: + return self._builder_factory.get_save_file_filter() - def saveObject(self, index: int, filePath: Path, fileType: str) -> None: + def save_object(self, index: int, file_path: Path, file_type: str) -> None: try: item = self._repository[index] except IndexError: logger.warning(f'Failed to save object {index}!') else: - self._builderFactory.saveObject(filePath, fileType, item.getObject()) + self._builder_factory.save_object(file_path, file_type, item.get_object()) class ProductAPI: @@ -350,82 +379,97 @@ def __init__( self, settings: ProductSettings, repository: ProductRepository, - fileReaderChooser: PluginChooser[ProductFileReader], - fileWriterChooser: PluginChooser[ProductFileWriter], + item_factory: ProductRepositoryItemFactory, + file_reader_chooser: PluginChooser[ProductFileReader], + file_writer_chooser: PluginChooser[ProductFileWriter], ) -> None: self._settings = settings self._repository = repository - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser + self._item_factory = item_factory + self._file_reader_chooser = file_reader_chooser + self._file_writer_chooser = file_writer_chooser - def insertNewProduct( + def insert_new_product( self, name: str = 'Unnamed', *, comments: str = '', - detectorDistanceInMeters: float | None = None, - probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, - exposureTimeInSeconds: float | None = None, - likeIndex: int = -1, + detector_distance_m: float | None = None, + probe_energy_eV: float | None = None, # noqa: N803 + probe_photon_count: float | None = None, + exposure_time_s: float | None = None, + mass_attenuation_m2_kg: float | None = None, + tomography_angle_deg: float | None = None, ) -> int: - return self._repository.insertNewProduct( + item = self._item_factory.create_from_values( name=name, comments=comments, - detectorDistanceInMeters=detectorDistanceInMeters, - probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, - exposureTimeInSeconds=exposureTimeInSeconds, - likeIndex=likeIndex, + detector_distance_m=detector_distance_m, + probe_energy_eV=probe_energy_eV, + probe_photon_count=probe_photon_count, + exposure_time_s=exposure_time_s, + mass_attenuation_m2_kg=mass_attenuation_m2_kg, + tomography_angle_deg=tomography_angle_deg, ) + return self._repository.insert_product(item) + + def insert_product(self, product: Product) -> int: + item = self._item_factory.create_from_product(product) + return self._repository.insert_product(item) + + def insert_product_from_settings(self) -> int: + item = self._item_factory.create_from_settings() + return self._repository.insert_product(item) - def getItemName(self, productIndex: int) -> str: - item = self._repository[productIndex] - return item.getName() + def get_item(self, product_index: int) -> ProductRepositoryItem: + return self._repository[product_index] - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() + def get_open_file_filters(self) -> Iterator[str]: + for plugin in self._file_reader_chooser: + yield plugin.display_name - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName + def get_open_file_filter(self) -> str: + return self._file_reader_chooser.get_current_plugin().display_name - def openProduct(self, filePath: Path, *, fileType: str | None = None) -> int: - if filePath.is_file(): - self._fileReaderChooser.setCurrentPluginByName( - self._settings.fileType.getValue() if fileType is None else fileType - ) - fileType = self._fileReaderChooser.currentPlugin.simpleName - logger.debug(f'Reading "{filePath}" as "{fileType}"') - fileReader = self._fileReaderChooser.currentPlugin.strategy + def open_product(self, file_path: Path, *, file_type: str | None = None) -> int: + if file_path.is_file(): + if file_type is not None: + self._file_reader_chooser.set_current_plugin(file_type) + + file_type = self._file_reader_chooser.get_current_plugin().simple_name + logger.debug(f'Reading "{file_path}" as "{file_type}"') + file_reader = self._file_reader_chooser.get_current_plugin().strategy try: - product = fileReader.read(filePath) + product = file_reader.read(file_path) except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc + raise RuntimeError(f'Failed to read "{file_path}"') from exc else: - return self._repository.insertProduct(product) + item = self._item_factory.create_from_product(product) + return self._repository.insert_product(item) else: - logger.warning(f'Refusing to create product with invalid file path "{filePath}"') + logger.warning(f'Refusing to create product with invalid file path "{file_path}"') return -1 - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() + def get_save_file_filters(self) -> Iterator[str]: + for plugin in self._file_writer_chooser: + yield plugin.display_name - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName + def get_save_file_filter(self) -> str: + return self._file_writer_chooser.get_current_plugin().display_name - def saveProduct(self, index: int, filePath: Path, *, fileType: str | None = None) -> None: + def save_product(self, index: int, file_path: Path, *, file_type: str | None = None) -> None: try: item = self._repository[index] except IndexError: logger.warning(f'Failed to save product {index}!') return - self._fileWriterChooser.setCurrentPluginByName( - self._settings.fileType.getValue() if fileType is None else fileType - ) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - writer = self._fileWriterChooser.currentPlugin.strategy - writer.write(filePath, item.getProduct()) + if file_type is not None: + self._file_writer_chooser.set_current_plugin(file_type) + + file_type = self._file_writer_chooser.get_current_plugin().simple_name + logger.debug(f'Writing "{file_path}" as "{file_type}"') + writer = self._file_writer_chooser.get_current_plugin().strategy + writer.write(file_path, item.get_product()) diff --git a/src/ptychodus/model/product/core.py b/src/ptychodus/model/product/core.py index f97cfa3e..f163679a 100644 --- a/src/ptychodus/model/product/core.py +++ b/src/ptychodus/model/product/core.py @@ -5,109 +5,116 @@ from ptychodus.api.plugins import PluginChooser from ptychodus.api.probe import FresnelZonePlate, ProbeFileReader, ProbeFileWriter from ptychodus.api.product import ProductFileReader, ProductFileWriter -from ptychodus.api.scan import ScanFileReader, ScanFileWriter +from ptychodus.api.scan import PositionFileReader, PositionFileWriter from ptychodus.api.settings import SettingsRegistry -from ..patterns import ActiveDiffractionDataset, Detector, PatternSizer, ProductSettings +from ..patterns import AssembledDiffractionDataset, PatternSizer from .api import ObjectAPI, ProbeAPI, ProductAPI, ScanAPI +from .item_factory import ProductRepositoryItemFactory from .object import ObjectBuilderFactory, ObjectRepositoryItemFactory, ObjectSettings -from .objectRepository import ObjectRepository +from .object_repository import ObjectRepository from .probe import ProbeBuilderFactory, ProbeRepositoryItemFactory, ProbeSettings -from .probeRepository import ProbeRepository -from .productRepository import ProductRepository +from .probe_repository import ProbeRepository +from .repository import ProductRepository from .scan import ScanBuilderFactory, ScanRepositoryItemFactory, ScanSettings -from .scanRepository import ScanRepository +from .scan_repository import ScanRepository +from .settings import ProductSettings class ProductCore(Observer): def __init__( self, rng: numpy.random.Generator, - settingsRegistry: SettingsRegistry, - detector: Detector, - settings: ProductSettings, - patternSizer: PatternSizer, - patterns: ActiveDiffractionDataset, - scanFileReaderChooser: PluginChooser[ScanFileReader], - scanFileWriterChooser: PluginChooser[ScanFileWriter], - fresnelZonePlateChooser: PluginChooser[FresnelZonePlate], - probeFileReaderChooser: PluginChooser[ProbeFileReader], - probeFileWriterChooser: PluginChooser[ProbeFileWriter], - objectFileReaderChooser: PluginChooser[ObjectFileReader], - objectFileWriterChooser: PluginChooser[ObjectFileWriter], - productFileReaderChooser: PluginChooser[ProductFileReader], - productFileWriterChooser: PluginChooser[ProductFileWriter], - reinitObservable: Observable, + settings_registry: SettingsRegistry, + pattern_sizer: PatternSizer, + dataset: AssembledDiffractionDataset, + scan_file_reader_chooser: PluginChooser[PositionFileReader], + scan_file_writer_chooser: PluginChooser[PositionFileWriter], + fresnel_zone_plate_chooser: PluginChooser[FresnelZonePlate], + probe_file_reader_chooser: PluginChooser[ProbeFileReader], + probe_file_writer_chooser: PluginChooser[ProbeFileWriter], + object_file_reader_chooser: PluginChooser[ObjectFileReader], + object_file_writer_chooser: PluginChooser[ObjectFileWriter], + product_file_reader_chooser: PluginChooser[ProductFileReader], + product_file_writer_chooser: PluginChooser[ProductFileWriter], + reinit_observable: Observable, ) -> None: super().__init__() - self._scanSettings = ScanSettings(settingsRegistry) - self._scanBuilderFactory = ScanBuilderFactory( - self._scanSettings, scanFileReaderChooser, scanFileWriterChooser + self.settings = ProductSettings(settings_registry) + + self._scan_settings = ScanSettings(settings_registry) + self._scan_builder_factory = ScanBuilderFactory( + self._scan_settings, scan_file_reader_chooser, scan_file_writer_chooser ) - self._scanRepositoryItemFactory = ScanRepositoryItemFactory( - rng, self._scanSettings, self._scanBuilderFactory + self._scan_repository_item_factory = ScanRepositoryItemFactory( + rng, self._scan_settings, self._scan_builder_factory ) - self._probeSettings = ProbeSettings(settingsRegistry) - self._probeBuilderFactory = ProbeBuilderFactory( - self._probeSettings, - detector, - patterns, - fresnelZonePlateChooser, - probeFileReaderChooser, - probeFileWriterChooser, + self._probe_settings = ProbeSettings(settings_registry) + self._probe_builder_factory = ProbeBuilderFactory( + self._probe_settings, + dataset, + fresnel_zone_plate_chooser, + probe_file_reader_chooser, + probe_file_writer_chooser, ) - self._probeRepositoryItemFactory = ProbeRepositoryItemFactory( - rng, self._probeSettings, self._probeBuilderFactory + self._probe_repository_item_factory = ProbeRepositoryItemFactory( + rng, self._probe_settings, self._probe_builder_factory ) - self._objectSettings = ObjectSettings(settingsRegistry) - self._objectBuilderFactory = ObjectBuilderFactory( - rng, self._objectSettings, objectFileReaderChooser, objectFileWriterChooser + self._object_settings = ObjectSettings(settings_registry) + self._object_builder_factory = ObjectBuilderFactory( + rng, self._object_settings, object_file_reader_chooser, object_file_writer_chooser ) - self._objectRepositoryItemFactory = ObjectRepositoryItemFactory( - rng, self._objectSettings, self._objectBuilderFactory + self._object_repository_item_factory = ObjectRepositoryItemFactory( + rng, self._object_settings, self._object_builder_factory ) - self.productRepository = ProductRepository( - settings, - patternSizer, - patterns, - self._scanRepositoryItemFactory, - self._probeRepositoryItemFactory, - self._objectRepositoryItemFactory, + self.product_repository = ProductRepository() + self._item_factory = ProductRepositoryItemFactory( + self.settings, + pattern_sizer, + dataset, + self._scan_repository_item_factory, + self._probe_repository_item_factory, + self._object_repository_item_factory, + self.product_repository, + product_file_reader_chooser, + ) + self.product_api = ProductAPI( + self.settings, + self.product_repository, + self._item_factory, + product_file_reader_chooser, + product_file_writer_chooser, ) - self.productAPI = ProductAPI( - settings, - self.productRepository, - productFileReaderChooser, - productFileWriterChooser, + self.scan_repository = ScanRepository(self.product_repository) + self.scan_api = ScanAPI( + self._scan_settings, self.scan_repository, self._scan_builder_factory ) - self.scanRepository = ScanRepository(self.productRepository) - self.scanAPI = ScanAPI(self._scanSettings, self.scanRepository, self._scanBuilderFactory) - self.probeRepository = ProbeRepository(self.productRepository) - self.probeAPI = ProbeAPI( - self._probeSettings, self.probeRepository, self._probeBuilderFactory + self.probe_repository = ProbeRepository(self.product_repository) + self.probe_api = ProbeAPI( + self._probe_settings, self.probe_repository, self._probe_builder_factory ) - self.objectRepository = ObjectRepository(self.productRepository) - self.objectAPI = ObjectAPI( - self._objectSettings, self.objectRepository, self._objectBuilderFactory + self.object_repository = ObjectRepository(self.product_repository) + self.object_api = ObjectAPI( + self._object_settings, self.object_repository, self._object_builder_factory ) # TODO vvv refactor vvv - productFileReaderChooser.setCurrentPluginByName(settings.fileType.getValue()) - productFileWriterChooser.setCurrentPluginByName(settings.fileType.getValue()) - scanFileReaderChooser.setCurrentPluginByName(self._scanSettings.fileType.getValue()) - scanFileWriterChooser.setCurrentPluginByName(self._scanSettings.fileType.getValue()) - probeFileReaderChooser.setCurrentPluginByName(self._probeSettings.fileType.getValue()) - probeFileWriterChooser.setCurrentPluginByName(self._probeSettings.fileType.getValue()) - objectFileReaderChooser.setCurrentPluginByName(self._objectSettings.fileType.getValue()) - objectFileWriterChooser.setCurrentPluginByName(self._objectSettings.fileType.getValue()) + product_file_reader_chooser.synchronize_with_parameter(self.settings.file_type) + product_file_writer_chooser.set_current_plugin(self.settings.file_type.get_value()) + scan_file_reader_chooser.synchronize_with_parameter(self._scan_settings.file_type) + scan_file_writer_chooser.set_current_plugin(self._scan_settings.file_type.get_value()) + probe_file_reader_chooser.synchronize_with_parameter(self._probe_settings.file_type) + probe_file_writer_chooser.set_current_plugin(self._probe_settings.file_type.get_value()) + object_file_reader_chooser.synchronize_with_parameter(self._object_settings.file_type) + object_file_writer_chooser.set_current_plugin(self._object_settings.file_type.get_value()) # TODO ^^^^^^^^^^^^^^^^ - self._reinitObservable = reinitObservable - reinitObservable.addObserver(self) + self._reinit_observable = reinit_observable + reinit_observable.add_observer(self) - def update(self, observable: Observable) -> None: - if observable is self._reinitObservable: - self.productRepository.insertProductFromSettings() + def _update(self, observable: Observable) -> None: + if observable is self._reinit_observable: + self.product_api.insert_product_from_settings() diff --git a/src/ptychodus/model/product/geometry.py b/src/ptychodus/model/product/geometry.py new file mode 100644 index 00000000..48940f19 --- /dev/null +++ b/src/ptychodus/model/product/geometry.py @@ -0,0 +1,161 @@ +import numpy + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import ObjectGeometry, ObjectGeometryProvider +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.probe import ProbeGeometry, ProbeGeometryProvider +from ptychodus.api.product import ( + ELECTRON_VOLT_J, + LIGHT_SPEED_M_PER_S, + PLANCK_CONSTANT_J_PER_HZ, +) + +from ..patterns import PatternSizer +from .metadata import MetadataRepositoryItem +from .scan import ScanRepositoryItem + + +class ProductGeometry(ProbeGeometryProvider, ObjectGeometryProvider, Observable, Observer): + def __init__( + self, + pattern_sizer: PatternSizer, + metadata_item: MetadataRepositoryItem, + scan_item: ScanRepositoryItem, + ) -> None: + super().__init__() + self._pattern_sizer = pattern_sizer + self._metadata_item = metadata_item + self._scan_item = scan_item + + self._pattern_sizer.add_observer(self) + self._metadata_item.add_observer(self) + self._scan_item.add_observer(self) + + @property + def probe_photon_count(self) -> float: + return self._metadata_item.probe_photon_count.get_value() + + @property + def probe_energy_J(self) -> float: # noqa: N802 + return self._metadata_item.probe_energy_eV.get_value() * ELECTRON_VOLT_J + + @property + def probe_wavelength_m(self) -> float: + hc_Jm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S # noqa: N806 + + try: + return hc_Jm / self.probe_energy_J + except ZeroDivisionError: + return 0.0 + + @property + def probe_wavelengths_per_m(self) -> float: + """wavenumber""" + return 1.0 / self.probe_wavelength_m + + @property + def probe_radians_per_m(self) -> float: + """angular wavenumber""" + return 2.0 * numpy.pi / self.probe_wavelength_m + + @property + def probe_photons_per_s(self) -> float: + try: + return self.probe_photon_count / self._metadata_item.exposure_time_s.get_value() + except ZeroDivisionError: + return 0.0 + + @property + def probe_power_W(self) -> float: # noqa: N802 + return self.probe_energy_J * self.probe_photons_per_s + + @property + def num_scan_points(self) -> int: + return len(self._scan_item.get_scan()) + + @property + def detector_distance_m(self) -> float: + return self._metadata_item.detector_distance_m.get_value() + + @property + def _lambda_z_m2(self) -> float: + return self.probe_wavelength_m * self.detector_distance_m + + @property + def object_plane_pixel_width_m(self) -> float: + return self._lambda_z_m2 / self._pattern_sizer.get_processed_width_m() + + @property + def object_plane_pixel_height_m(self) -> float: + return self._lambda_z_m2 / self._pattern_sizer.get_processed_height_m() + + def get_detector_pixel_geometry(self): + return self._pattern_sizer.get_processed_pixel_geometry() + + def get_object_plane_pixel_geometry(self) -> PixelGeometry: + return PixelGeometry( + width_m=self.object_plane_pixel_width_m, + height_m=self.object_plane_pixel_height_m, + ) + + @property + def fresnel_number(self) -> float: + width_m = self._pattern_sizer.get_processed_width_m() + height_m = self._pattern_sizer.get_processed_height_m() + area_m2 = width_m * height_m + return area_m2 / self._lambda_z_m2 + + def get_probe_geometry(self) -> ProbeGeometry: + extent = self._pattern_sizer.get_processed_image_extent() + return ProbeGeometry( + width_px=extent.width_px, + height_px=extent.height_px, + pixel_width_m=self.object_plane_pixel_width_m, + pixel_height_m=self.object_plane_pixel_height_m, + ) + + def is_probe_geometry_valid(self, geometry: ProbeGeometry) -> bool: + expected = self.get_probe_geometry() + width_is_valid = geometry.pixel_width_m > 0.0 and geometry.width_m == expected.width_m + height_is_valid = geometry.pixel_height_m > 0.0 and geometry.height_m == expected.height_m + return width_is_valid and height_is_valid + + def get_object_geometry(self) -> ObjectGeometry: + probe_geometry = self.get_probe_geometry() + width_m = probe_geometry.width_m + height_m = probe_geometry.height_m + center_x_m = 0.0 + center_y_m = 0.0 + + scan_bbox = self._scan_item.get_bounding_box() + + if scan_bbox is not None: + width_m += scan_bbox.width_m + height_m += scan_bbox.height_m + center_x_m = scan_bbox.center_x_m + center_y_m = scan_bbox.center_y_m + + width_px = width_m / self.object_plane_pixel_width_m + height_px = height_m / self.object_plane_pixel_height_m + + return ObjectGeometry( + width_px=int(numpy.ceil(width_px)), + height_px=int(numpy.ceil(height_px)), + pixel_width_m=self.object_plane_pixel_width_m, + pixel_height_m=self.object_plane_pixel_height_m, + center_x_m=center_x_m, + center_y_m=center_y_m, + ) + + def is_object_geometry_valid(self, geometry: ObjectGeometry) -> bool: + expected_geometry = self.get_object_geometry() + pixel_size_is_valid = geometry.pixel_width_m > 0.0 and geometry.pixel_height_m > 0.0 + return pixel_size_is_valid and geometry.contains(expected_geometry) + + def _update(self, observable: Observable) -> None: + if observable is self._metadata_item: + self.notify_observers() + elif observable is self._scan_item: + self.notify_observers() + elif observable is self._pattern_sizer: + self.notify_observers() diff --git a/src/ptychodus/model/product/item.py b/src/ptychodus/model/product/item.py index 096dbbbe..4e4b633a 100644 --- a/src/ptychodus/model/product/item.py +++ b/src/ptychodus/model/product/item.py @@ -7,35 +7,35 @@ from ptychodus.api.parametric import ParameterGroup from ptychodus.api.product import Product -from .metadata import MetadataRepositoryItem +from .geometry import ProductGeometry +from .metadata import MetadataRepositoryItem, UniqueNameFactory from .object import ObjectRepositoryItem from .probe import ProbeRepositoryItem -from .productGeometry import ProductGeometry -from .productValidator import ProductValidator from .scan import ScanRepositoryItem +from .validator import ProductValidator logger = logging.getLogger(__name__) -class ProductRepositoryItemObserver(ABC): +class ProductRepositoryItemObserver(UniqueNameFactory): @abstractmethod - def handleMetadataChanged(self, item: ProductRepositoryItem) -> None: + def handle_metadata_changed(self, item: ProductRepositoryItem) -> None: pass @abstractmethod - def handleScanChanged(self, item: ProductRepositoryItem) -> None: + def handle_scan_changed(self, item: ProductRepositoryItem) -> None: pass @abstractmethod - def handleProbeChanged(self, item: ProductRepositoryItem) -> None: + def handle_probe_changed(self, item: ProductRepositoryItem) -> None: pass @abstractmethod - def handleObjectChanged(self, item: ProductRepositoryItem) -> None: + def handle_object_changed(self, item: ProductRepositoryItem) -> None: pass @abstractmethod - def handleCostsChanged(self, item: ProductRepositoryItem) -> None: + def handle_costs_changed(self, item: ProductRepositoryItem) -> None: pass @@ -43,130 +43,124 @@ class ProductRepositoryItem(ParameterGroup): def __init__( self, parent: ProductRepositoryItemObserver, - metadata: MetadataRepositoryItem, - scan: ScanRepositoryItem, + metadata_item: MetadataRepositoryItem, + scan_item: ScanRepositoryItem, geometry: ProductGeometry, - probe: ProbeRepositoryItem, - object_: ObjectRepositoryItem, + probe_item: ProbeRepositoryItem, + object_item: ObjectRepositoryItem, validator: ProductValidator, costs: Sequence[float], ) -> None: super().__init__() self._parent = parent - self._metadata = metadata - self._scan = scan + self._metadata_item = metadata_item + self._scan_item = scan_item self._geometry = geometry - self._probe = probe - self._object = object_ + self._probe_item = probe_item + self._object_item = object_item self._validator = validator self._costs = list(costs) - self._addGroup('metadata', self._metadata, observe=True) - self._addGroup('scan', self._scan, observe=True) - self._addGroup('probe', self._probe, observe=True) - self._addGroup('object', self._object, observe=True) + self._add_group('metadata', self._metadata_item, observe=True) + self._add_group('scan', self._scan_item, observe=True) + self._add_group('probe', self._probe_item, observe=True) + self._add_group('object', self._object_item, observe=True) - def assignItem(self, item: ProductRepositoryItem, *, notify: bool = True) -> None: - self._metadata.assignItem(item.getMetadata()) - self._scan.assignItem(item.getScan()) - self._probe.assignItem(item.getProbe()) - self._object.assignItem(item.getObject()) - self._costs = list(item.getCosts()) - - if notify: - self._parent.handleCostsChanged(self) + self._index = -1 # used by ProductRepository def assign(self, product: Product) -> None: - self._metadata.assign(product.metadata) - self._scan.assign(product.scan) - self._probe.assign(product.probe) - self._object.assign(product.object_) + self._metadata_item.assign(product.metadata) + self._scan_item.assign(product.positions) + self._probe_item.assign(product.probes) + self._object_item.assign(product.object_) self._costs = list(product.costs) - self._parent.handleCostsChanged(self) + self._parent.handle_costs_changed(self) - def syncToSettings(self) -> None: - self._metadata.syncToSettings() - self._scan.syncToSettings() - self._probe.syncToSettings() - self._object.syncToSettings() + def sync_to_settings(self) -> None: + self._metadata_item.sync_to_settings() + self._scan_item.sync_to_settings() + self._probe_item.sync_to_settings() + self._object_item.sync_to_settings() - def getName(self) -> str: - return self._metadata.getName() + def get_name(self) -> str: + return self._metadata_item.name.get_value() - def setName(self, name: str) -> None: - self._metadata.setName(name) + def set_name(self, name: str) -> None: + self._metadata_item.name.set_value(name) - def getMetadata(self) -> MetadataRepositoryItem: - return self._metadata + def get_metadata_item(self) -> MetadataRepositoryItem: + return self._metadata_item - def getScan(self) -> ScanRepositoryItem: - return self._scan + def get_scan_item(self) -> ScanRepositoryItem: + return self._scan_item - def getGeometry(self) -> ProductGeometry: + def get_geometry(self) -> ProductGeometry: return self._geometry - def getProbe(self) -> ProbeRepositoryItem: - return self._probe + def get_probe_item(self) -> ProbeRepositoryItem: + return self._probe_item - def getObject(self) -> ObjectRepositoryItem: - return self._object + def get_object_item(self) -> ObjectRepositoryItem: + return self._object_item - def _invalidateCosts(self) -> None: + def _invalidate_costs(self) -> None: self._costs = list() - self._parent.handleCostsChanged(self) + self._parent.handle_costs_changed(self) - def getCosts(self) -> Sequence[float]: + def get_costs(self) -> Sequence[float]: return self._costs - def getProduct(self) -> Product: + def get_product(self) -> Product: return Product( - metadata=self._metadata.getMetadata(), - scan=self._scan.getScan(), - probe=self._probe.getProbe(), - object_=self._object.getObject(), - costs=self.getCosts(), + metadata=self._metadata_item.get_metadata(), + positions=self._scan_item.get_scan(), + probes=self._probe_item.get_probes(), + object_=self._object_item.get_object(), + costs=self.get_costs(), ) - def update(self, observable: Observable) -> None: - if observable is self._metadata: - self._invalidateCosts() - self._parent.handleMetadataChanged(self) - elif observable is self._scan: - self._invalidateCosts() - self._parent.handleScanChanged(self) - elif observable is self._probe: - self._invalidateCosts() - self._parent.handleProbeChanged(self) - elif observable is self._object: - self._invalidateCosts() - self._parent.handleObjectChanged(self) + def _update(self, observable: Observable) -> None: + if observable is self._metadata_item: + self._invalidate_costs() + self._parent.handle_metadata_changed(self) + elif observable is self._scan_item: + self._invalidate_costs() + self._parent.handle_scan_changed(self) + elif observable is self._probe_item: + self._invalidate_costs() + self._parent.handle_probe_changed(self) + elif observable is self._object_item: + self._invalidate_costs() + self._parent.handle_object_changed(self) + else: + super()._update(observable) class ProductRepositoryObserver(ABC): @abstractmethod - def handleItemInserted(self, index: int, item: ProductRepositoryItem) -> None: + def handle_item_inserted(self, index: int, item: ProductRepositoryItem) -> None: pass @abstractmethod - def handleMetadataChanged(self, index: int, item: MetadataRepositoryItem) -> None: + def handle_metadata_changed(self, index: int, item: MetadataRepositoryItem) -> None: pass @abstractmethod - def handleScanChanged(self, index: int, item: ScanRepositoryItem) -> None: + def handle_scan_changed(self, index: int, item: ScanRepositoryItem) -> None: pass @abstractmethod - def handleProbeChanged(self, index: int, item: ProbeRepositoryItem) -> None: + def handle_probe_changed(self, index: int, item: ProbeRepositoryItem) -> None: pass @abstractmethod - def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: + def handle_object_changed(self, index: int, item: ObjectRepositoryItem) -> None: pass @abstractmethod - def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: + def handle_costs_changed(self, index: int, costs: Sequence[float]) -> None: pass @abstractmethod - def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: + def handle_item_removed(self, index: int, item: ProductRepositoryItem) -> None: pass diff --git a/src/ptychodus/model/product/item_factory.py b/src/ptychodus/model/product/item_factory.py new file mode 100644 index 00000000..0b25d1b3 --- /dev/null +++ b/src/ptychodus/model/product/item_factory.py @@ -0,0 +1,148 @@ +import logging + +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.product import Product, ProductFileReader + +from ..patterns import AssembledDiffractionDataset, PatternSizer +from .geometry import ProductGeometry +from .item import ProductRepositoryItem +from .metadata import MetadataRepositoryItem +from .object import ObjectRepositoryItemFactory +from .probe import ProbeRepositoryItemFactory +from .repository import ProductRepository +from .scan import ScanRepositoryItemFactory +from .settings import ProductSettings +from .validator import ProductValidator + +logger = logging.getLogger(__name__) + + +class ProductRepositoryItemFactory: + def __init__( + self, + settings: ProductSettings, + pattern_sizer: PatternSizer, + dataset: AssembledDiffractionDataset, + scan_item_factory: ScanRepositoryItemFactory, + probe_item_factory: ProbeRepositoryItemFactory, + object_item_factory: ObjectRepositoryItemFactory, + repository: ProductRepository, + file_reader_chooser: PluginChooser[ProductFileReader], + ) -> None: + super().__init__() + self._settings = settings + self._pattern_sizer = pattern_sizer + self._dataset = dataset + self._scan_item_factory = scan_item_factory + self._probe_item_factory = probe_item_factory + self._object_item_factory = object_item_factory + self._repository = repository + self._file_reader_chooser = file_reader_chooser + + def create_from_values( + self, + *, + name: str = '', + comments: str = '', + detector_distance_m: float | None = None, + probe_energy_eV: float | None = None, # noqa: N803 + probe_photon_count: float | None = None, + exposure_time_s: float | None = None, + mass_attenuation_m2_kg: float | None = None, + tomography_angle_deg: float | None = None, + ) -> ProductRepositoryItem: + metadata_item = MetadataRepositoryItem( + self._settings, + self._repository, + name=name, + comments=comments, + detector_distance_m=detector_distance_m, + probe_energy_eV=probe_energy_eV, + probe_photon_count=probe_photon_count, + exposure_time_s=exposure_time_s, + mass_attenuation_m2_kg=mass_attenuation_m2_kg, + tomography_angle_deg=tomography_angle_deg, + ) + + if metadata_item.probe_photon_count.get_value() <= 0: + metadata_item.probe_photon_count.set_value(self._dataset.get_maximum_pattern_counts()) + + scan_item = self._scan_item_factory.create() + geometry = ProductGeometry(self._pattern_sizer, metadata_item, scan_item) + probe_item = self._probe_item_factory.create(geometry) + object_item = self._object_item_factory.create(geometry) + validator = ProductValidator(self._dataset, scan_item, geometry, probe_item, object_item) + + return ProductRepositoryItem( + parent=self._repository, + metadata_item=metadata_item, + scan_item=scan_item, + geometry=geometry, + probe_item=probe_item, + object_item=object_item, + validator=validator, + costs=list(), + ) + + def create_from_product(self, product: Product) -> ProductRepositoryItem: + metadata_item = MetadataRepositoryItem( + self._settings, + self._repository, + name=product.metadata.name, + comments=product.metadata.comments, + detector_distance_m=product.metadata.detector_distance_m, + probe_energy_eV=product.metadata.probe_energy_eV, + probe_photon_count=product.metadata.probe_photon_count, + exposure_time_s=product.metadata.exposure_time_s, + mass_attenuation_m2_kg=product.metadata.mass_attenuation_m2_kg, + tomography_angle_deg=product.metadata.tomography_angle_deg, + ) + + scan_item = self._scan_item_factory.create(product.positions) + geometry = ProductGeometry(self._pattern_sizer, metadata_item, scan_item) + probe_item = self._probe_item_factory.create(geometry, product.probes) + object_item = self._object_item_factory.create(geometry, product.object_) + validator = ProductValidator(self._dataset, scan_item, geometry, probe_item, object_item) + + return ProductRepositoryItem( + parent=self._repository, + metadata_item=metadata_item, + scan_item=scan_item, + geometry=geometry, + probe_item=probe_item, + object_item=object_item, + validator=validator, + costs=product.costs, + ) + + def create_from_settings(self) -> ProductRepositoryItem: + file_path = self._settings.file_path.get_value() + + if file_path.is_file(): + file_type = self._file_reader_chooser.get_current_plugin().simple_name + logger.debug(f'Reading "{file_path}" as "{file_type}"') + file_reader = self._file_reader_chooser.get_current_plugin().strategy + + try: + product = file_reader.read(file_path) + except Exception as exc: + raise RuntimeError(f'Failed to read "{file_path}"') from exc + else: + return self.create_from_product(product) + + metadata_item = MetadataRepositoryItem(self._settings, self._repository) + scan_item = self._scan_item_factory.create_from_settings() + geometry = ProductGeometry(self._pattern_sizer, metadata_item, scan_item) + probe_item = self._probe_item_factory.create_from_settings(geometry) + object_item = self._object_item_factory.create_from_settings(geometry) + + return ProductRepositoryItem( + parent=self._repository, + metadata_item=metadata_item, + scan_item=scan_item, + geometry=geometry, + probe_item=probe_item, + object_item=object_item, + validator=ProductValidator(self._dataset, scan_item, geometry, probe_item, object_item), + costs=list(), + ) diff --git a/src/ptychodus/model/product/metadata.py b/src/ptychodus/model/product/metadata.py index 7cd517b3..9ca47871 100644 --- a/src/ptychodus/model/product/metadata.py +++ b/src/ptychodus/model/product/metadata.py @@ -1,115 +1,139 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod, ABC import logging -from ptychodus.api.parametric import ParameterGroup +from ptychodus.api.parametric import Parameter, ParameterGroup from ptychodus.api.product import ProductMetadata -from ..patterns import ProductSettings +from .settings import ProductSettings logger = logging.getLogger(__name__) class UniqueNameFactory(ABC): @abstractmethod - def createUniqueName(self, candidateName: str) -> str: + def create_unique_name(self, candidate_name: str) -> str: pass +class UniqueStringParameter(Parameter[str]): + def __init__( + self, value: str | None, name_factory: UniqueNameFactory, parent: Parameter[str] + ) -> None: + super().__init__(parent) + self._value = name_factory.create_unique_name(value or parent.get_value()) + self._name_factory = name_factory + + def get_value(self) -> str: + return self._value + + def set_value(self, value: str, *, notify: bool = True) -> None: + if value: + if self._value != value: + self._value = self._name_factory.create_unique_name(value) + + if notify: + self.notify_observers() + else: + self.notify_observers() + + def get_value_as_string(self) -> str: + return str(self._value) + + def set_value_from_string(self, value: str) -> None: + self.set_value(str(value)) + + def copy(self) -> UniqueStringParameter: + return UniqueStringParameter(self.get_value(), self._name_factory, self) + + class MetadataRepositoryItem(ParameterGroup): def __init__( self, settings: ProductSettings, - nameFactory: UniqueNameFactory, + name_factory: UniqueNameFactory, *, name: str = '', comments: str = '', - detectorDistanceInMeters: float | None = None, - probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, - exposureTimeInSeconds: float | None = None, + detector_distance_m: float | None = None, + probe_energy_eV: float | None = None, # noqa: N803 + probe_photon_count: float | None = None, + exposure_time_s: float | None = None, + mass_attenuation_m2_kg: float | None = None, + tomography_angle_deg: float | None = None, ) -> None: super().__init__() self._settings = settings - self._nameFactory = nameFactory - self._name = settings.name.copy() - self._setName(name if name else settings.name.getValue()) - self._addParameter('name', self._name) - self.comments = self.createStringParameter('comments', comments) + self.name = UniqueStringParameter(name, name_factory, settings.name.copy()) + self._add_parameter('name', self.name) - self.detectorDistanceInMeters = settings.detectorDistanceInMeters.copy() + self.comments = self.create_string_parameter('comments', comments) - if detectorDistanceInMeters is not None: - self.detectorDistanceInMeters.setValue(detectorDistanceInMeters) + self.detector_distance_m = settings.detector_distance_m.copy() - self._addParameter('detector_distance_m', self.detectorDistanceInMeters) + if detector_distance_m is not None: + self.detector_distance_m.set_value(detector_distance_m) - self.probeEnergyInElectronVolts = settings.probeEnergyInElectronVolts.copy() + self._add_parameter('detector_distance_m', self.detector_distance_m) - if probeEnergyInElectronVolts is not None: - self.probeEnergyInElectronVolts.setValue(probeEnergyInElectronVolts) + self.probe_energy_eV = settings.probe_energy_eV.copy() - self._addParameter('probe_energy_eV', self.probeEnergyInElectronVolts) + if probe_energy_eV is not None: + self.probe_energy_eV.set_value(probe_energy_eV) - self.probePhotonsPerSecond = settings.probePhotonsPerSecond.copy() + self._add_parameter('probe_energy_eV', self.probe_energy_eV) - if probePhotonsPerSecond is not None: - self.probePhotonsPerSecond.setValue(probePhotonsPerSecond) + self.probe_photon_count = settings.probe_photon_count.copy() - self._addParameter('probe_photons_per_second', self.probePhotonsPerSecond) + if probe_photon_count is not None: + self.probe_photon_count.set_value(probe_photon_count) - self.exposureTimeInSeconds = settings.exposureTimeInSeconds.copy() + self._add_parameter('probe_photon_count', self.probe_photon_count) - if exposureTimeInSeconds is not None: - self.exposureTimeInSeconds.setValue(exposureTimeInSeconds) + self.exposure_time_s = settings.exposure_time_s.copy() - self._addParameter('exposure_time_s', self.exposureTimeInSeconds) + if exposure_time_s is not None: + self.exposure_time_s.set_value(exposure_time_s) - self._index = -1 + self._add_parameter('exposure_time_s', self.exposure_time_s) - def assignItem(self, item: MetadataRepositoryItem, *, notify: bool = True) -> None: - self.setName(item.getName()) - self.comments.setValue(item.comments.getValue()) - self.detectorDistanceInMeters.setValue(item.detectorDistanceInMeters.getValue()) - self.probeEnergyInElectronVolts.setValue(item.probeEnergyInElectronVolts.getValue()) - self.probePhotonsPerSecond.setValue(item.probePhotonsPerSecond.getValue()) - self.exposureTimeInSeconds.setValue(item.exposureTimeInSeconds.getValue()) + self.mass_attenuation_m2_kg = settings.mass_attenuation_m2_kg.copy() - def assign(self, metadata: ProductMetadata) -> None: - self.setName(metadata.name) - self.comments.setValue(metadata.comments) - self.detectorDistanceInMeters.setValue(metadata.detectorDistanceInMeters) - self.probeEnergyInElectronVolts.setValue(metadata.probeEnergyInElectronVolts) - self.probePhotonsPerSecond.setValue(metadata.probePhotonsPerSecond) - self.exposureTimeInSeconds.setValue(metadata.exposureTimeInSeconds) - - def syncToSettings(self) -> None: - for parameter in self.parameters().values(): - parameter.syncValueToParent() + if mass_attenuation_m2_kg is not None: + self.mass_attenuation_m2_kg.set_value(mass_attenuation_m2_kg) - def getName(self) -> str: - return self._name.getValue() + self._add_parameter('mass_attenuation_m2_kg', self.mass_attenuation_m2_kg) - def _setName(self, name: str) -> None: - uniqueName = self._nameFactory.createUniqueName(name) - self._name.setValue(uniqueName) + self.tomography_angle_deg = settings.tomography_angle_deg.copy() - def setName(self, name: str) -> None: - if name: - self._setName(name) - else: - self._name.notifyObservers() + if tomography_angle_deg is not None: + self.tomography_angle_deg.set_value(tomography_angle_deg) + + self._add_parameter('tomography_angle_deg', self.tomography_angle_deg) - def getIndex(self) -> int: - return self._index + def assign(self, metadata: ProductMetadata) -> None: + self.name.set_value(metadata.name) + self.comments.set_value(metadata.comments) + self.detector_distance_m.set_value(metadata.detector_distance_m) + self.probe_energy_eV.set_value(metadata.probe_energy_eV) + self.probe_photon_count.set_value(metadata.probe_photon_count) + self.exposure_time_s.set_value(metadata.exposure_time_s) + self.mass_attenuation_m2_kg.set_value(metadata.mass_attenuation_m2_kg) + self.tomography_angle_deg.set_value(metadata.tomography_angle_deg) + + def sync_to_settings(self) -> None: + for parameter in self.parameters().values(): + parameter.sync_value_to_parent() - def getMetadata(self) -> ProductMetadata: + def get_metadata(self) -> ProductMetadata: return ProductMetadata( - name=self._name.getValue(), - comments=self.comments.getValue(), - detectorDistanceInMeters=self.detectorDistanceInMeters.getValue(), - probeEnergyInElectronVolts=self.probeEnergyInElectronVolts.getValue(), - probePhotonsPerSecond=self.probePhotonsPerSecond.getValue(), - exposureTimeInSeconds=self.exposureTimeInSeconds.getValue(), + name=self.name.get_value(), + comments=self.comments.get_value(), + detector_distance_m=self.detector_distance_m.get_value(), + probe_energy_eV=self.probe_energy_eV.get_value(), + probe_photon_count=self.probe_photon_count.get_value(), + exposure_time_s=self.exposure_time_s.get_value(), + mass_attenuation_m2_kg=self.mass_attenuation_m2_kg.get_value(), + tomography_angle_deg=self.tomography_angle_deg.get_value(), ) diff --git a/src/ptychodus/model/product/metadataFactory.py b/src/ptychodus/model/product/metadataFactory.py deleted file mode 100644 index 45e86cce..00000000 --- a/src/ptychodus/model/product/metadataFactory.py +++ /dev/null @@ -1,91 +0,0 @@ -from collections.abc import Sequence -import logging - -from ptychodus.api.product import ProductMetadata - -from ..patterns import ProductSettings -from .item import ProductRepositoryItem, ProductRepositoryObserver -from .metadata import MetadataRepositoryItem, UniqueNameFactory -from .object import ObjectRepositoryItem -from .probe import ProbeRepositoryItem -from .scan import ScanRepositoryItem - -logger = logging.getLogger(__name__) - - -class MetadataRepositoryItemFactory(UniqueNameFactory, ProductRepositoryObserver): - def __init__( - self, repository: Sequence[ProductRepositoryItem], settings: ProductSettings - ) -> None: - self._repository = repository - self._settings = settings - - def create(self, metadata: ProductMetadata) -> MetadataRepositoryItem: - return MetadataRepositoryItem( - self._settings, - self, - name=metadata.name, - comments=metadata.comments, - detectorDistanceInMeters=metadata.detectorDistanceInMeters, - probeEnergyInElectronVolts=metadata.probeEnergyInElectronVolts, - probePhotonsPerSecond=metadata.probePhotonsPerSecond, - exposureTimeInSeconds=metadata.exposureTimeInSeconds, - ) - - def createDefault( - self, - *, - name: str = '', - comments: str = '', - detectorDistanceInMeters: float | None = None, - probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, - exposureTimeInSeconds: float | None = None, - ) -> MetadataRepositoryItem: - return MetadataRepositoryItem( - self._settings, - self, - name=name, - comments=comments, - detectorDistanceInMeters=detectorDistanceInMeters, - probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, - exposureTimeInSeconds=exposureTimeInSeconds, - ) - - def createUniqueName(self, candidateName: str) -> str: - reservedNames = set([item.getName() for item in self._repository]) - name = candidateName if candidateName else 'Unnamed' - match = 0 - - while name in reservedNames: - match += 1 - name = candidateName + f'-{match}' - - return name - - def _updateLUT(self) -> None: - for index, item in enumerate(self._repository): - metadata = item.getMetadata() - metadata._index = index - - def handleItemInserted(self, index: int, item: ProductRepositoryItem) -> None: - self._updateLUT() - - def handleMetadataChanged(self, index: int, item: MetadataRepositoryItem) -> None: - pass - - def handleScanChanged(self, index: int, item: ScanRepositoryItem) -> None: - pass - - def handleProbeChanged(self, index: int, item: ProbeRepositoryItem) -> None: - pass - - def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: - pass - - def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: - pass - - def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: - self._updateLUT() diff --git a/src/ptychodus/model/product/object/__init__.py b/src/ptychodus/model/product/object/__init__.py index dd98c16e..ea2c54ef 100644 --- a/src/ptychodus/model/product/object/__init__.py +++ b/src/ptychodus/model/product/object/__init__.py @@ -1,7 +1,7 @@ from .builder import ObjectBuilder -from .builderFactory import ObjectBuilderFactory +from .builder_factory import ObjectBuilderFactory from .item import ObjectRepositoryItem -from .itemFactory import ObjectRepositoryItemFactory +from .item_factory import ObjectRepositoryItemFactory from .random import RandomObjectBuilder from .settings import ObjectSettings diff --git a/src/ptychodus/model/product/object/builder.py b/src/ptychodus/model/product/object/builder.py index 7a63de7b..727b3ab9 100644 --- a/src/ptychodus/model/product/object/builder.py +++ b/src/ptychodus/model/product/object/builder.py @@ -16,15 +16,15 @@ class ObjectBuilder(ParameterGroup): def __init__(self, settings: ObjectSettings, name: str) -> None: super().__init__() self._name = settings.builder.copy() - self._name.setValue(name) - self._addParameter('name', self._name) + self._name.set_value(name) + self._add_parameter('name', self._name) - def getName(self) -> str: - return self._name.getValue() + def get_name(self) -> str: + return self._name.get_value() - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() @abstractmethod def copy(self) -> ObjectBuilder: @@ -33,8 +33,8 @@ def copy(self) -> ObjectBuilder: @abstractmethod def build( self, - geometryProvider: ObjectGeometryProvider, - layerDistanceInMeters: Sequence[float], + geometry_provider: ObjectGeometryProvider, + layer_spacing_m: Sequence[float], ) -> Object: pass @@ -50,43 +50,67 @@ def copy(self) -> FromMemoryObjectBuilder: def build( self, - geometryProvider: ObjectGeometryProvider, - layerDistanceInMeters: Sequence[float], + geometry_provider: ObjectGeometryProvider, + layer_spacing_m: Sequence[float], ) -> Object: return self._object class FromFileObjectBuilder(ObjectBuilder): def __init__( - self, settings: ObjectSettings, filePath: Path, fileType: str, fileReader: ObjectFileReader + self, + settings: ObjectSettings, + file_path: Path, + file_type: str, + file_reader: ObjectFileReader, ) -> None: super().__init__(settings, 'from_file') self._settings = settings - self.filePath = settings.filePath.copy() - self.filePath.setValue(filePath) - self._addParameter('file_path', self.filePath) - self.fileType = settings.fileType.copy() - self.fileType.setValue(fileType) - self._addParameter('file_type', self.fileType) - self._fileReader = fileReader + self.file_path = settings.file_path.copy() + self.file_path.set_value(file_path) + self._add_parameter('file_path', self.file_path) + self.file_type = settings.file_type.copy() + self.file_type.set_value(file_type) + self._add_parameter('file_type', self.file_type) + self._file_reader = file_reader def copy(self) -> FromFileObjectBuilder: return FromFileObjectBuilder( - self._settings, self.filePath.getValue(), self.fileType.getValue(), self._fileReader + self._settings, + self.file_path.get_value(), + self.file_type.get_value(), + self._file_reader, ) def build( self, - geometryProvider: ObjectGeometryProvider, - layerDistanceInMeters: Sequence[float], + geometry_provider: ObjectGeometryProvider, + layer_spacing_m: Sequence[float], ) -> Object: - filePath = self.filePath.getValue() - fileType = self.fileType.getValue() - logger.debug(f'Reading "{filePath}" as "{fileType}"') + file_path = self.file_path.get_value() + file_type = self.file_type.get_value() + logger.debug(f'Reading "{file_path}" as "{file_type}"') try: - object_ = self._fileReader.read(filePath) + object_from_file = self._file_reader.read(file_path) except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc + raise RuntimeError(f'Failed to read "{file_path}"') from exc + + object_geometry = geometry_provider.get_object_geometry() + + try: + pixel_geometry = object_from_file.get_pixel_geometry() + except ValueError: + pixel_geometry = object_geometry.get_pixel_geometry() - return object_ + try: + center = object_from_file.get_center() + except ValueError: + center = object_geometry.get_center() + + return Object( + object_from_file.get_array(), + pixel_geometry, + center, + object_from_file.layer_spacing_m, + ) diff --git a/src/ptychodus/model/product/object/builderFactory.py b/src/ptychodus/model/product/object/builderFactory.py deleted file mode 100644 index baa8a93b..00000000 --- a/src/ptychodus/model/product/object/builderFactory.py +++ /dev/null @@ -1,81 +0,0 @@ -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence -from pathlib import Path -import logging - -import numpy - -from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter -from ptychodus.api.plugins import PluginChooser - -from .builder import FromFileObjectBuilder, ObjectBuilder -from .random import RandomObjectBuilder -from .settings import ObjectSettings - -logger = logging.getLogger(__name__) - - -class ObjectBuilderFactory(Iterable[str]): - def __init__( - self, - rng: numpy.random.Generator, - settings: ObjectSettings, - fileReaderChooser: PluginChooser[ObjectFileReader], - fileWriterChooser: PluginChooser[ObjectFileWriter], - ) -> None: - self._settings = settings - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser - self._builders: Mapping[str, Callable[[], ObjectBuilder]] = { - 'random': lambda: RandomObjectBuilder(rng, settings), - } - - def __iter__(self) -> Iterator[str]: - return iter(self._builders) - - def create(self, name: str) -> ObjectBuilder: - try: - factory = self._builders[name] - except KeyError as exc: - raise KeyError(f'Unknown object builder "{name}"!') from exc - - return factory() - - def createDefault(self) -> ObjectBuilder: - return next(iter(self._builders.values()))() - - def createFromSettings(self) -> ObjectBuilder: - name = self._settings.builder.getValue() - nameRepaired = name.casefold() - - if nameRepaired == 'from_file': - return self.createObjectFromFile( - self._settings.filePath.getValue(), - self._settings.fileType.getValue(), - ) - - return self.create(nameRepaired) - - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() - - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName - - def createObjectFromFile(self, filePath: Path, fileFilter: str) -> ObjectBuilder: - self._fileReaderChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileReaderChooser.currentPlugin.simpleName - fileReader = self._fileReaderChooser.currentPlugin.strategy - return FromFileObjectBuilder(self._settings, filePath, fileType, fileReader) - - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() - - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName - - def saveObject(self, filePath: Path, fileFilter: str, object_: Object) -> None: - self._fileWriterChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - fileWriter = self._fileWriterChooser.currentPlugin.strategy - fileWriter.write(filePath, object_) diff --git a/src/ptychodus/model/product/object/builder_factory.py b/src/ptychodus/model/product/object/builder_factory.py new file mode 100644 index 00000000..65b4b099 --- /dev/null +++ b/src/ptychodus/model/product/object/builder_factory.py @@ -0,0 +1,83 @@ +from collections.abc import Callable, Iterable, Iterator, Mapping +from pathlib import Path +import logging + +import numpy + +from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter +from ptychodus.api.plugins import PluginChooser + +from .builder import FromFileObjectBuilder, ObjectBuilder +from .random import RandomObjectBuilder +from .settings import ObjectSettings + +logger = logging.getLogger(__name__) + + +class ObjectBuilderFactory(Iterable[str]): + def __init__( + self, + rng: numpy.random.Generator, + settings: ObjectSettings, + file_reader_chooser: PluginChooser[ObjectFileReader], + file_writer_chooser: PluginChooser[ObjectFileWriter], + ) -> None: + self._settings = settings + self._file_reader_chooser = file_reader_chooser + self._file_writer_chooser = file_writer_chooser + self._builders: Mapping[str, Callable[[], ObjectBuilder]] = { + 'random': lambda: RandomObjectBuilder(rng, settings), + } + + def __iter__(self) -> Iterator[str]: + return iter(self._builders) + + def create(self, name: str) -> ObjectBuilder: + try: + factory = self._builders[name] + except KeyError as exc: + raise KeyError(f'Unknown object builder "{name}"!') from exc + + return factory() + + def create_default(self) -> ObjectBuilder: + return next(iter(self._builders.values()))() + + def create_from_settings(self) -> ObjectBuilder: + name = self._settings.builder.get_value() + name_repaired = name.casefold() + + if name_repaired == 'from_file': + return self.create_object_from_file( + self._settings.file_path.get_value(), + self._settings.file_type.get_value(), + ) + + return self.create(name_repaired) + + def get_open_file_filters(self) -> Iterator[str]: + for plugin in self._file_reader_chooser: + yield plugin.display_name + + def get_open_file_filter(self) -> str: + return self._file_reader_chooser.get_current_plugin().display_name + + def create_object_from_file(self, file_path: Path, file_filter: str) -> ObjectBuilder: + self._file_reader_chooser.set_current_plugin(file_filter) + file_type = self._file_reader_chooser.get_current_plugin().simple_name + file_reader = self._file_reader_chooser.get_current_plugin().strategy + return FromFileObjectBuilder(self._settings, file_path, file_type, file_reader) + + def get_save_file_filters(self) -> Iterator[str]: + for plugin in self._file_writer_chooser: + yield plugin.display_name + + def get_save_file_filter(self) -> str: + return self._file_writer_chooser.get_current_plugin().display_name + + def save_object(self, file_path: Path, file_filter: str, object_: Object) -> None: + self._file_writer_chooser.set_current_plugin(file_filter) + file_type = self._file_writer_chooser.get_current_plugin().simple_name + logger.debug(f'Writing "{file_path}" as "{file_type}"') + file_writer = self._file_writer_chooser.get_current_plugin().strategy + file_writer.write(file_path, object_) diff --git a/src/ptychodus/model/product/object/item.py b/src/ptychodus/model/product/object/item.py index d4e02bce..cc720bf7 100644 --- a/src/ptychodus/model/product/object/item.py +++ b/src/ptychodus/model/product/object/item.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import numpy from ptychodus.api.object import Object, ObjectGeometryProvider from ptychodus.api.observer import Observable @@ -16,89 +15,95 @@ class ObjectRepositoryItem(ParameterGroup): def __init__( self, - geometryProvider: ObjectGeometryProvider, + geometry_provider: ObjectGeometryProvider, settings: ObjectSettings, builder: ObjectBuilder, ) -> None: super().__init__() - self._geometryProvider = geometryProvider + self._geometry_provider = geometry_provider self._settings = settings self._builder = builder - self._object = Object() + self._object = Object(array=None, pixel_geometry=None, center=None) - self._addGroup('builder', builder, observe=True) - # TODO sync layer distance to/from settings - self.layerDistanceInMeters = self.createRealArrayParameter('layer_distance_m', [numpy.inf]) + self.layer_spacing_m = settings.object_layer_spacing_m.copy() + self._add_parameter('layer_spacing_m', self.layer_spacing_m) - self._rebuild() + self._add_group('builder', builder, observe=True) + self.rebuild() - def assignItem(self, item: ObjectRepositoryItem) -> None: - self.layerDistanceInMeters.setValue(item.layerDistanceInMeters.getValue(), notify=False) - self.setBuilder(item.getBuilder().copy()) + def assign_item(self, item: ObjectRepositoryItem) -> None: + self.layer_spacing_m.set_value(item.layer_spacing_m.get_value(), notify=False) + self.set_builder(item.get_builder().copy()) + self.rebuild() def assign(self, object_: Object) -> None: builder = FromMemoryObjectBuilder(self._settings, object_) - self.setBuilder(builder) + self.set_builder(builder) - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() - self._builder.syncToSettings() + self._builder.sync_to_settings() - def getNumberOfLayers(self) -> int: - return len(self.layerDistanceInMeters) + def get_num_layers(self) -> int: + return len(self.layer_spacing_m) + 1 - def setNumberOfLayers(self, number: int) -> None: - numRequested = max(1, number) - distanceInMeters = list(self.layerDistanceInMeters.getValue()) - numExisting = len(distanceInMeters) - defaultDistanceInMeters = float(self._settings.objectLayerDistanceInMeters.getValue()) + def set_num_layers(self, num_layers: int) -> None: + num_spaces = max(0, num_layers - 1) + distance_m = list(self.layer_spacing_m.get_value()) - if numExisting < 2: - distanceInMeters = [defaultDistanceInMeters] * numRequested - elif numExisting < numRequested: - distanceInMeters[-1] = distanceInMeters[-2] # overwrite inf - distanceInMeters.extend(distanceInMeters[-1:] * (numRequested - numExisting)) - elif numExisting > numRequested: - distanceInMeters = distanceInMeters[:numRequested] + try: + default_distance_m = distance_m[-1] + except IndexError: + default_distance_m = 0.0 + + while len(distance_m) < num_spaces: + distance_m.append(default_distance_m) - distanceInMeters[-1] = numpy.inf - self.layerDistanceInMeters.setValue(distanceInMeters) - self._rebuild() + if len(distance_m) > num_spaces: + distance_m = distance_m[:num_spaces] - def getObject(self) -> Object: + self.layer_spacing_m.set_value(distance_m) + self.rebuild() + + def get_object(self) -> Object: return self._object - def getBuilder(self) -> ObjectBuilder: + def get_builder(self) -> ObjectBuilder: return self._builder - def setBuilder(self, builder: ObjectBuilder) -> None: - self._removeGroup('builder') - self._builder.removeObserver(self) + def set_builder(self, builder: ObjectBuilder) -> None: + self._remove_group('builder') + self._builder.remove_observer(self) self._builder = builder - self._builder.addObserver(self) - self._addGroup('builder', self._builder, observe=True) - self._rebuild() - - def _rebuild(self) -> None: - layerDistanceInMeters = list(self.layerDistanceInMeters.getValue()) - - if len(layerDistanceInMeters) < 1: - layerDistanceInMeters.append(numpy.inf) + self._builder.add_observer(self) + self._add_group('builder', self._builder, observe=True) + self.rebuild() + def rebuild(self, *, recenter: bool = False) -> None: try: - object_ = self._builder.build(self._geometryProvider, layerDistanceInMeters) + object_ = self._builder.build(self._geometry_provider, self.layer_spacing_m.get_value()) except Exception as exc: - logger.error(''.join(exc.args)) + logger.exception('Failed to rebuild object!') return - self._object = object_ - self.layerDistanceInMeters.setValue(object_.layerDistanceInMeters) - self.notifyObservers() + if recenter: + object_geometry = self._geometry_provider.get_object_geometry() + self._object = Object( + array=object_.get_array(), + layer_spacing_m=object_.layer_spacing_m, + pixel_geometry=object_.get_pixel_geometry(), + center=object_geometry.get_center(), + ) + else: + self._object = object_ + + self.layer_spacing_m.set_value(object_.layer_spacing_m) + self.notify_observers() - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._builder: - self._rebuild() + self.rebuild() else: - super().update(observable) + super()._update(observable) diff --git a/src/ptychodus/model/product/object/itemFactory.py b/src/ptychodus/model/product/object/itemFactory.py deleted file mode 100644 index b79d9ed8..00000000 --- a/src/ptychodus/model/product/object/itemFactory.py +++ /dev/null @@ -1,43 +0,0 @@ -import logging - -import numpy - -from ptychodus.api.object import Object, ObjectGeometryProvider - -from .builder import FromMemoryObjectBuilder -from .builderFactory import ObjectBuilderFactory -from .item import ObjectRepositoryItem -from .settings import ObjectSettings - -logger = logging.getLogger(__name__) - - -class ObjectRepositoryItemFactory: - def __init__( - self, - rng: numpy.random.Generator, - settings: ObjectSettings, - builderFactory: ObjectBuilderFactory, - ) -> None: - self._rng = rng - self._settings = settings - self._builderFactory = builderFactory - - def create( - self, geometryProvider: ObjectGeometryProvider, object_: Object | None = None - ) -> ObjectRepositoryItem: - builder = ( - self._builderFactory.createDefault() - if object_ is None - else FromMemoryObjectBuilder(self._settings, object_) - ) - return ObjectRepositoryItem(geometryProvider, self._settings, builder) - - def createFromSettings(self, geometryProvider: ObjectGeometryProvider) -> ObjectRepositoryItem: - try: - builder = self._builderFactory.createFromSettings() - except Exception as exc: - logger.error(''.join(exc.args)) - builder = self._builderFactory.createDefault() - - return ObjectRepositoryItem(geometryProvider, self._settings, builder) diff --git a/src/ptychodus/model/product/object/item_factory.py b/src/ptychodus/model/product/object/item_factory.py new file mode 100644 index 00000000..07c2d619 --- /dev/null +++ b/src/ptychodus/model/product/object/item_factory.py @@ -0,0 +1,48 @@ +import logging + +import numpy + +from ptychodus.api.object import Object, ObjectGeometryProvider + +from .builder import FromMemoryObjectBuilder +from .builder_factory import ObjectBuilderFactory +from .item import ObjectRepositoryItem +from .settings import ObjectSettings + +logger = logging.getLogger(__name__) + + +class ObjectRepositoryItemFactory: + def __init__( + self, + rng: numpy.random.Generator, + settings: ObjectSettings, + builder_factory: ObjectBuilderFactory, + ) -> None: + self._rng = rng + self._settings = settings + self._builder_factory = builder_factory + + def create( + self, geometry_provider: ObjectGeometryProvider, object_: Object | None = None + ) -> ObjectRepositoryItem: + # TODO layers_builder = MultilayerObjectBuilder() + + if object_ is None: + builder = self._builder_factory.create_default() + else: + builder = FromMemoryObjectBuilder(self._settings, object_) + # TODO layers_builder.set_identity() + + return ObjectRepositoryItem(geometry_provider, self._settings, builder) + + def create_from_settings( + self, geometry_provider: ObjectGeometryProvider + ) -> ObjectRepositoryItem: + try: + builder = self._builder_factory.create_from_settings() + except Exception as exc: + logger.error(''.join(exc.args)) + builder = self._builder_factory.create_default() + + return ObjectRepositoryItem(geometry_provider, self._settings, builder) diff --git a/src/ptychodus/model/product/object/random.py b/src/ptychodus/model/product/object/random.py index 0a0b4982..0eeedc28 100644 --- a/src/ptychodus/model/product/object/random.py +++ b/src/ptychodus/model/product/object/random.py @@ -4,6 +4,7 @@ import numpy from ptychodus.api.object import Object, ObjectGeometryProvider +from ptychodus.model.phase_unwrapper import PhaseUnwrapper from .builder import ObjectBuilder from .settings import ObjectSettings @@ -15,53 +16,100 @@ def __init__(self, rng: numpy.random.Generator, settings: ObjectSettings) -> Non self._rng = rng self._settings = settings - self.extraPaddingX = settings.extraPaddingX.copy() - self._addParameter('extra_padding_x', self.extraPaddingX) - self.extraPaddingY = settings.extraPaddingY.copy() - self._addParameter('extra_padding_y', self.extraPaddingY) + self.extra_padding_x = settings.extra_padding_x.copy() + self._add_parameter('extra_padding_x', self.extra_padding_x) + self.extra_padding_y = settings.extra_padding_y.copy() + self._add_parameter('extra_padding_y', self.extra_padding_y) - self.amplitudeMean = settings.amplitudeMean.copy() - self._addParameter('amplitude_mean', self.amplitudeMean) - self.amplitudeDeviation = settings.amplitudeDeviation.copy() - self._addParameter('amplitude_deviation', self.amplitudeDeviation) + self.amplitude_mean = settings.amplitude_mean.copy() + self._add_parameter('amplitude_mean', self.amplitude_mean) + self.amplitude_deviation = settings.amplitude_deviation.copy() + self._add_parameter('amplitude_deviation', self.amplitude_deviation) - self.phaseDeviation = settings.phaseDeviation.copy() - self._addParameter('phase_deviation', self.phaseDeviation) + self.phase_deviation = settings.phase_deviation.copy() + self._add_parameter('phase_deviation', self.phase_deviation) def copy(self) -> RandomObjectBuilder: builder = RandomObjectBuilder(self._rng, self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder def build( self, - geometryProvider: ObjectGeometryProvider, - layerDistanceInMeters: Sequence[float], + geometry_provider: ObjectGeometryProvider, + layer_spacing_m: Sequence[float], ) -> Object: - geometry = geometryProvider.getObjectGeometry() - heightInPixels = geometry.heightInPixels + 2 * self.extraPaddingY.getValue() - widthInPixels = geometry.widthInPixels + 2 * self.extraPaddingX.getValue() - objectShape = (len(layerDistanceInMeters), heightInPixels, widthInPixels) + geometry = geometry_provider.get_object_geometry() + height_px = geometry.height_px + 2 * self.extra_padding_y.get_value() + width_px = geometry.width_px + 2 * self.extra_padding_x.get_value() + object_shape = (1 + len(layer_spacing_m), height_px, width_px) amplitude = self._rng.normal( - self.amplitudeMean.getValue(), - self.amplitudeDeviation.getValue(), - objectShape, + self.amplitude_mean.get_value(), + self.amplitude_deviation.get_value(), + object_shape, ) phase = self._rng.normal( 0.0, - self.phaseDeviation.getValue(), - objectShape, + self.phase_deviation.get_value(), + object_shape, ) return Object( array=numpy.clip(amplitude, 0.0, 1.0) * numpy.exp(1j * phase), - layerDistanceInMeters=layerDistanceInMeters, - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, - centerXInMeters=geometry.centerXInMeters, - centerYInMeters=geometry.centerYInMeters, + layer_spacing_m=layer_spacing_m, + pixel_geometry=geometry.get_pixel_geometry(), + center=geometry.get_center(), + ) + + +class UserObjectBuilder(ObjectBuilder): # TODO use + def __init__(self, object_: Object, settings: ObjectSettings) -> None: + """Create an object from an existing object with a potentially + different number of slices. + + If the new object is supposed to be a multislice object with a + different number of slices than the existing object, the object is + created as + `abs(o) ** (1 / nSlices) * exp(i * unwrapPhase(o) / nSlices)`. + Otherwise, the object is copied as is. + """ + super().__init__(settings, 'user') + self._existing_object = object_ + self._settings = settings + + def copy(self) -> UserObjectBuilder: + builder = UserObjectBuilder(self._existing_object, self._settings) + + for key, value in self.parameters().items(): + builder.parameters()[key].set_value(value.get_value()) + + return builder + + def build( + self, + geometry_provider: ObjectGeometryProvider, + layer_spacing_m: Sequence[float], + ) -> Object: + geometry = self._existing_object.get_geometry() + existing_object_array = self._existing_object.get_array() + num_slices = len(layer_spacing_m) + 1 + + if num_slices > 1 and num_slices != existing_object_array.shape[0]: + amplitude = numpy.abs(existing_object_array[0:1]) ** (1.0 / num_slices) + amplitude = amplitude.repeat(num_slices, axis=0) + phase = PhaseUnwrapper().unwrap(existing_object_array[0])[None, ...] / num_slices + phase = phase.repeat(num_slices, axis=0) + data = numpy.clip(amplitude, 0.0, 1.0) * numpy.exp(1j * phase) + else: + data = existing_object_array + + return Object( + array=data, + layer_spacing_m=layer_spacing_m, + pixel_geometry=geometry.get_pixel_geometry(), + center=geometry.get_center(), ) diff --git a/src/ptychodus/model/product/object/settings.py b/src/ptychodus/model/product/object/settings.py index 15780279..79ed959e 100644 --- a/src/ptychodus/model/product/object/settings.py +++ b/src/ptychodus/model/product/object/settings.py @@ -9,35 +9,29 @@ class ObjectSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Object') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Object') + self._group.add_observer(self) - self.builder = self._settingsGroup.createStringParameter('Builder', 'Random') - self.filePath = self._settingsGroup.createPathParameter( - 'FilePath', Path('/path/to/object.npy') - ) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'NPY') + self.builder = self._group.create_string_parameter('Builder', 'Random') + self.file_path = self._group.create_path_parameter('FilePath', Path('/path/to/object.npy')) + self.file_type = self._group.create_string_parameter('FileType', 'NPY') - self.objectLayerDistanceInMeters = self._settingsGroup.createRealParameter( - 'ObjectLayerDistanceInMeters', 1e-7 + self.object_layer_spacing_m = self._group.create_real_sequence_parameter( + 'ObjectLayerSpacingInMeters', [] ) - self.extraPaddingX = self._settingsGroup.createIntegerParameter( - 'ExtraPaddingX', 1, minimum=0 - ) - self.extraPaddingY = self._settingsGroup.createIntegerParameter( - 'ExtraPaddingY', 1, minimum=0 - ) - self.amplitudeMean = self._settingsGroup.createRealParameter( + self.extra_padding_x = self._group.create_integer_parameter('ExtraPaddingX', 1, minimum=0) + self.extra_padding_y = self._group.create_integer_parameter('ExtraPaddingY', 1, minimum=0) + self.amplitude_mean = self._group.create_real_parameter( 'AmplitudeMean', 1.0, minimum=0.0, maximum=1.0 ) - self.amplitudeDeviation = self._settingsGroup.createRealParameter( + self.amplitude_deviation = self._group.create_real_parameter( 'AmplitudeDeviation', 0.0, minimum=0.0, maximum=1.0 ) - self.phaseDeviation = self._settingsGroup.createRealParameter( + self.phase_deviation = self._group.create_real_parameter( 'PhaseDeviation', 0.0, minimum=0.0, maximum=numpy.pi ) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/product/objectRepository.py b/src/ptychodus/model/product/objectRepository.py deleted file mode 100644 index 6e615f70..00000000 --- a/src/ptychodus/model/product/objectRepository.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Sequence -from typing import overload -import logging - -from ptychodus.api.observer import ObservableSequence - -from .item import ProductRepositoryItem, ProductRepositoryObserver -from .metadata import MetadataRepositoryItem -from .object import ObjectRepositoryItem -from .probe import ProbeRepositoryItem -from .productRepository import ProductRepository -from .scan import ScanRepositoryItem - -logger = logging.getLogger(__name__) - - -class ObjectRepository(ObservableSequence[ObjectRepositoryItem], ProductRepositoryObserver): - def __init__(self, repository: ProductRepository) -> None: - super().__init__() - self._repository = repository - self._repository.addObserver(self) - - def getName(self, index: int) -> str: - return self._repository[index].getName() - - def setName(self, index: int, name: str) -> None: - self._repository[index].setName(name) - - @overload - def __getitem__(self, index: int) -> ObjectRepositoryItem: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[ObjectRepositoryItem]: ... - - def __getitem__( - self, index: int | slice - ) -> ObjectRepositoryItem | Sequence[ObjectRepositoryItem]: - if isinstance(index, slice): - return [item.getObject() for item in self._repository[index]] - else: - return self._repository[index].getObject() - - def __len__(self) -> int: - return len(self._repository) - - def handleItemInserted(self, index: int, item: ProductRepositoryItem) -> None: - self.notifyObserversItemInserted(index, item.getObject()) - - def handleMetadataChanged(self, index: int, item: MetadataRepositoryItem) -> None: - pass - - def handleScanChanged(self, index: int, item: ScanRepositoryItem) -> None: - pass - - def handleProbeChanged(self, index: int, item: ProbeRepositoryItem) -> None: - pass - - def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: - self.notifyObserversItemChanged(index, item) - - def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: - pass - - def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: - self.notifyObserversItemRemoved(index, item.getObject()) diff --git a/src/ptychodus/model/product/object_repository.py b/src/ptychodus/model/product/object_repository.py new file mode 100644 index 00000000..76e41c8b --- /dev/null +++ b/src/ptychodus/model/product/object_repository.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +from typing import overload +import logging + +from ptychodus.api.observer import ObservableSequence + +from .item import ProductRepositoryItem, ProductRepositoryObserver +from .metadata import MetadataRepositoryItem +from .object import ObjectRepositoryItem +from .probe import ProbeRepositoryItem +from .repository import ProductRepository +from .scan import ScanRepositoryItem + +logger = logging.getLogger(__name__) + + +class ObjectRepository(ObservableSequence[ObjectRepositoryItem], ProductRepositoryObserver): + def __init__(self, repository: ProductRepository) -> None: + super().__init__() + self._repository = repository + self._repository.add_observer(self) + + def get_name(self, index: int) -> str: + return self._repository[index].get_name() + + def set_name(self, index: int, name: str) -> None: + self._repository[index].set_name(name) + + @overload + def __getitem__(self, index: int) -> ObjectRepositoryItem: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[ObjectRepositoryItem]: ... + + def __getitem__( + self, index: int | slice + ) -> ObjectRepositoryItem | Sequence[ObjectRepositoryItem]: + if isinstance(index, slice): + return [item.get_object_item() for item in self._repository[index]] + else: + return self._repository[index].get_object_item() + + def __len__(self) -> int: + return len(self._repository) + + def handle_item_inserted(self, index: int, item: ProductRepositoryItem) -> None: + self.notify_observers_item_inserted(index, item.get_object_item()) + + def handle_metadata_changed(self, index: int, item: MetadataRepositoryItem) -> None: + pass + + def handle_scan_changed(self, index: int, item: ScanRepositoryItem) -> None: + pass + + def handle_probe_changed(self, index: int, item: ProbeRepositoryItem) -> None: + pass + + def handle_object_changed(self, index: int, item: ObjectRepositoryItem) -> None: + self.notify_observers_item_changed(index, item) + + def handle_costs_changed(self, index: int, costs: Sequence[float]) -> None: + pass + + def handle_item_removed(self, index: int, item: ProductRepositoryItem) -> None: + self.notify_observers_item_removed(index, item.get_object_item()) diff --git a/src/ptychodus/model/product/probe/__init__.py b/src/ptychodus/model/product/probe/__init__.py index e38b2c94..5c860b7f 100644 --- a/src/ptychodus/model/product/probe/__init__.py +++ b/src/ptychodus/model/product/probe/__init__.py @@ -1,14 +1,14 @@ -from .averagePattern import AveragePatternProbeBuilder -from .builder import ProbeBuilder -from .builderFactory import ProbeBuilderFactory +from .average_pattern import AveragePatternProbeBuilder +from .builder import ProbeSequenceBuilder +from .builder_factory import ProbeBuilderFactory from .disk import DiskProbeBuilder from .fzp import FresnelZonePlateProbeBuilder from .item import ProbeRepositoryItem -from .itemFactory import ProbeRepositoryItemFactory +from .item_factory import ProbeRepositoryItemFactory from .multimodal import MultimodalProbeBuilder, ProbeModeDecayType from .rect import RectangularProbeBuilder from .settings import ProbeSettings -from .superGaussian import SuperGaussianProbeBuilder +from .super_gaussian import SuperGaussianProbeBuilder from .zernike import ZernikeProbeBuilder __all__ = [ @@ -16,13 +16,11 @@ 'DiskProbeBuilder', 'FresnelZonePlateProbeBuilder', 'MultimodalProbeBuilder', - 'ProbeBuilder', + 'ProbeSequenceBuilder', 'ProbeBuilderFactory', 'ProbeModeDecayType', - 'ProbePresenter', 'ProbeRepositoryItem', 'ProbeRepositoryItemFactory', - 'ProbeRepositoryPresenter', 'ProbeSettings', 'RectangularProbeBuilder', 'SuperGaussianProbeBuilder', diff --git a/src/ptychodus/model/product/probe/averagePattern.py b/src/ptychodus/model/product/probe/averagePattern.py deleted file mode 100644 index 30bc42da..00000000 --- a/src/ptychodus/model/product/probe/averagePattern.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import numpy - -from ptychodus.api.probe import Probe, ProbeGeometryProvider -from ptychodus.api.propagator import FresnelTransformPropagator, PropagatorParameters - -from ...patterns import ActiveDiffractionDataset, Detector -from .builder import ProbeBuilder -from .settings import ProbeSettings - - -class AveragePatternProbeBuilder(ProbeBuilder): - def __init__( - self, settings: ProbeSettings, detector: Detector, patterns: ActiveDiffractionDataset - ) -> None: - super().__init__(settings, 'average_pattern') - self._settings = settings - self._detector = detector - self._patterns = patterns - - def copy(self) -> AveragePatternProbeBuilder: - return AveragePatternProbeBuilder(self._settings, self._detector, self._patterns) - - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: - geometry = geometryProvider.getProbeGeometry() - detectorIntensity = numpy.average(self._patterns.getAssembledData(), axis=0) - - pixelGeometry = self._detector.getPixelGeometry() - propagatorParameters = PropagatorParameters( - wavelength_m=geometryProvider.probeWavelengthInMeters, - width_px=detectorIntensity.shape[-1], - height_px=detectorIntensity.shape[-2], - pixel_width_m=pixelGeometry.widthInMeters, - pixel_height_m=pixelGeometry.heightInMeters, - propagation_distance_m=-geometryProvider.detectorDistanceInMeters, - ) - propagator = FresnelTransformPropagator(propagatorParameters) - array = propagator.propagate(numpy.sqrt(detectorIntensity).astype(complex)) - - return Probe( - array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, - ) diff --git a/src/ptychodus/model/product/probe/average_pattern.py b/src/ptychodus/model/product/probe/average_pattern.py new file mode 100644 index 00000000..59e117f4 --- /dev/null +++ b/src/ptychodus/model/product/probe/average_pattern.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import numpy + +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider +from ptychodus.api.propagator import FresnelTransformPropagator, PropagatorParameters + +from ...patterns import AssembledDiffractionDataset +from .builder import ProbeSequenceBuilder +from .settings import ProbeSettings + + +class AveragePatternProbeBuilder(ProbeSequenceBuilder): + def __init__( + self, + settings: ProbeSettings, + dataset: AssembledDiffractionDataset, + ) -> None: + super().__init__(settings, 'average_pattern') + self._settings = settings + self._dataset = dataset + + def copy(self) -> AveragePatternProbeBuilder: + return AveragePatternProbeBuilder(self._settings, self._dataset) + + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: + geometry = geometry_provider.get_probe_geometry() + detector_intensity = numpy.mean(self._dataset.get_assembled_patterns(), axis=0) + + pixel_geometry = geometry_provider.get_detector_pixel_geometry() + propagator_parameters = PropagatorParameters( + wavelength_m=geometry_provider.probe_wavelength_m, + width_px=detector_intensity.shape[-1], + height_px=detector_intensity.shape[-2], + pixel_width_m=pixel_geometry.width_m, + pixel_height_m=pixel_geometry.height_m, + propagation_distance_m=-geometry_provider.detector_distance_m, + ) + propagator = FresnelTransformPropagator(propagator_parameters) + array = propagator.propagate(numpy.sqrt(detector_intensity).astype(complex)) + + return ProbeSequence( + array=self.normalize(array), + opr_weights=None, + pixel_geometry=geometry.get_pixel_geometry(), + ) diff --git a/src/ptychodus/model/product/probe/builder.py b/src/ptychodus/model/product/probe/builder.py index b8fec604..4957c27d 100644 --- a/src/ptychodus/model/product/probe/builder.py +++ b/src/ptychodus/model/product/probe/builder.py @@ -9,11 +9,11 @@ from ptychodus.api.parametric import ParameterGroup from ptychodus.api.probe import ( - Probe, + ProbeSequence, ProbeFileReader, ProbeGeometry, ProbeGeometryProvider, - WavefieldArrayType, + ComplexArrayType, ) from ptychodus.api.typing import RealArrayType @@ -24,55 +24,55 @@ @dataclass(frozen=True) class ProbeTransverseCoordinates: - positionXInMeters: RealArrayType - positionYInMeters: RealArrayType + position_x_m: RealArrayType + position_y_m: RealArrayType @property - def positionRInMeters(self) -> RealArrayType: - return numpy.hypot(self.positionXInMeters, self.positionYInMeters) + def position_r_m(self) -> RealArrayType: + return numpy.hypot(self.position_x_m, self.position_y_m) -class ProbeBuilder(ParameterGroup): +class ProbeSequenceBuilder(ParameterGroup): def __init__(self, settings: ProbeSettings, name: str) -> None: super().__init__() self._name = settings.builder.copy() - self._name.setValue(name) - self._addParameter('name', self._name) + self._name.set_value(name) + self._add_parameter('name', self._name) - def getTransverseCoordinates(self, geometry: ProbeGeometry) -> ProbeTransverseCoordinates: - Y, X = numpy.mgrid[: geometry.heightInPixels, : geometry.widthInPixels] - positionXInPixels = X - (geometry.widthInPixels - 1) / 2 - positionYInPixels = Y - (geometry.heightInPixels - 1) / 2 + def get_transverse_coordinates(self, geometry: ProbeGeometry) -> ProbeTransverseCoordinates: + Y, X = numpy.mgrid[: geometry.height_px, : geometry.width_px] # noqa: N806 + position_x_px = X - (geometry.width_px - 1) / 2 + position_y_px = Y - (geometry.height_px - 1) / 2 - positionXInMeters = positionXInPixels * geometry.pixelWidthInMeters - positionYInMeters = positionYInPixels * geometry.pixelHeightInMeters + position_x_m = position_x_px * geometry.pixel_width_m + position_y_m = position_y_px * geometry.pixel_height_m return ProbeTransverseCoordinates( - positionXInMeters=positionXInMeters, - positionYInMeters=positionYInMeters, + position_x_m=position_x_m, + position_y_m=position_y_m, ) - def normalize(self, array: WavefieldArrayType) -> WavefieldArrayType: - return array / numpy.sqrt(numpy.sum(numpy.abs(array) ** 2)) + def normalize(self, array: ComplexArrayType) -> ComplexArrayType: + return array / numpy.sqrt(numpy.sum(numpy.square(numpy.abs(array)))) - def getName(self) -> str: - return self._name.getValue() + def get_name(self) -> str: + return self._name.get_value() - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() @abstractmethod - def copy(self) -> ProbeBuilder: + def copy(self) -> ProbeSequenceBuilder: pass @abstractmethod - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: pass -class FromMemoryProbeBuilder(ProbeBuilder): - def __init__(self, settings: ProbeSettings, probe: Probe) -> None: +class FromMemoryProbeBuilder(ProbeSequenceBuilder): + def __init__(self, settings: ProbeSettings, probe: ProbeSequence) -> None: super().__init__(settings, 'from_memory') self._settings = settings self._probe = probe.copy() @@ -80,37 +80,57 @@ def __init__(self, settings: ProbeSettings, probe: Probe) -> None: def copy(self) -> FromMemoryProbeBuilder: return FromMemoryProbeBuilder(self._settings, self._probe) - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: return self._probe -class FromFileProbeBuilder(ProbeBuilder): +class FromFileProbeBuilder(ProbeSequenceBuilder): def __init__( - self, settings: ProbeSettings, filePath: Path, fileType: str, fileReader: ProbeFileReader + self, settings: ProbeSettings, file_path: Path, file_type: str, file_reader: ProbeFileReader ) -> None: super().__init__(settings, 'from_file') self._settings = settings - self.filePath = settings.filePath.copy() - self.filePath.setValue(filePath) - self._addParameter('file_path', self.filePath) - self.fileType = settings.fileType.copy() - self.fileType.setValue(fileType) - self._addParameter('file_type', self.fileType) - self._fileReader = fileReader + self.file_path = settings.file_path.copy() + self.file_path.set_value(file_path) + self._add_parameter('file_path', self.file_path) + self.file_type = settings.file_type.copy() + self.file_type.set_value(file_type) + self._add_parameter('file_type', self.file_type) + self._file_reader = file_reader def copy(self) -> FromFileProbeBuilder: return FromFileProbeBuilder( - self._settings, self.filePath.getValue(), self.fileType.getValue(), self._fileReader + self._settings, + self.file_path.get_value(), + self.file_type.get_value(), + self._file_reader, ) - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: - filePath = self.filePath.getValue() - fileType = self.fileType.getValue() - logger.debug(f'Reading "{filePath}" as "{fileType}"') + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: + file_path = self.file_path.get_value() + file_type = self.file_type.get_value() + logger.debug(f'Reading "{file_path}" as "{file_type}"') try: - probe = self._fileReader.read(filePath) + probe_from_file = self._file_reader.read(file_path) except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc + raise RuntimeError(f'Failed to read "{file_path}"') from exc - return probe + probe_geometry = geometry_provider.get_probe_geometry() + + try: + pixel_geometry = probe_from_file.get_pixel_geometry() + except ValueError: + pixel_geometry = probe_geometry.get_pixel_geometry() + + try: + opr_weights = probe_from_file.get_opr_weights() + except ValueError: + opr_weights = None + + # TODO regrid probe as needed based on probe geometry from file/provider + return ProbeSequence( + probe_from_file.get_array(), + opr_weights, + pixel_geometry, + ) diff --git a/src/ptychodus/model/product/probe/builderFactory.py b/src/ptychodus/model/product/probe/builderFactory.py deleted file mode 100644 index a3843102..00000000 --- a/src/ptychodus/model/product/probe/builderFactory.py +++ /dev/null @@ -1,107 +0,0 @@ -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence -from pathlib import Path -import logging - -from ptychodus.api.plugins import PluginChooser -from ptychodus.api.probe import ( - FresnelZonePlate, - Probe, - ProbeFileReader, - ProbeFileWriter, -) - -from ...patterns import ActiveDiffractionDataset, Detector -from .averagePattern import AveragePatternProbeBuilder -from .builder import FromFileProbeBuilder, ProbeBuilder -from .disk import DiskProbeBuilder -from .fzp import FresnelZonePlateProbeBuilder -from .rect import RectangularProbeBuilder -from .settings import ProbeSettings -from .superGaussian import SuperGaussianProbeBuilder -from .zernike import ZernikeProbeBuilder - -logger = logging.getLogger(__name__) - - -class ProbeBuilderFactory(Iterable[str]): - def __init__( - self, - settings: ProbeSettings, - detector: Detector, - patterns: ActiveDiffractionDataset, - fresnelZonePlateChooser: PluginChooser[FresnelZonePlate], - fileReaderChooser: PluginChooser[ProbeFileReader], - fileWriterChooser: PluginChooser[ProbeFileWriter], - ) -> None: - super().__init__() - self._settings = settings - self._detector = detector - self._patterns = patterns - self._fresnelZonePlateChooser = fresnelZonePlateChooser - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser - self._builders: Mapping[str, Callable[[], ProbeBuilder]] = { - 'disk': lambda: DiskProbeBuilder(settings), - 'average_pattern': self._createAveragePatternBuilder, - 'fresnel_zone_plate': self._createFresnelZonePlateBuilder, - 'rectangular': lambda: RectangularProbeBuilder(settings), - 'super_gaussian': lambda: SuperGaussianProbeBuilder(settings), - 'zernike': lambda: ZernikeProbeBuilder(settings), - } - - def __iter__(self) -> Iterator[str]: - return iter(self._builders) - - def create(self, name: str) -> ProbeBuilder: - try: - factory = self._builders[name] - except KeyError as exc: - raise KeyError(f'Unknown probe builder "{name}"!') from exc - - return factory() - - def createDefault(self) -> ProbeBuilder: - return next(iter(self._builders.values()))() - - def createFromSettings(self) -> ProbeBuilder: - name = self._settings.builder.getValue() - nameRepaired = name.casefold() - - if nameRepaired == 'from_file': - return self.createProbeFromFile( - self._settings.filePath.getValue(), - self._settings.fileType.getValue(), - ) - - return self.create(nameRepaired) - - def _createAveragePatternBuilder(self) -> ProbeBuilder: - return AveragePatternProbeBuilder(self._settings, self._detector, self._patterns) - - def _createFresnelZonePlateBuilder(self) -> ProbeBuilder: - return FresnelZonePlateProbeBuilder(self._settings, self._fresnelZonePlateChooser) - - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() - - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName - - def createProbeFromFile(self, filePath: Path, fileFilter: str) -> ProbeBuilder: - self._fileReaderChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileReaderChooser.currentPlugin.simpleName - fileReader = self._fileReaderChooser.currentPlugin.strategy - return FromFileProbeBuilder(self._settings, filePath, fileType, fileReader) - - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() - - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName - - def saveProbe(self, filePath: Path, fileFilter: str, probe: Probe) -> None: - self._fileWriterChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - fileWriter = self._fileWriterChooser.currentPlugin.strategy - fileWriter.write(filePath, probe) diff --git a/src/ptychodus/model/product/probe/builder_factory.py b/src/ptychodus/model/product/probe/builder_factory.py new file mode 100644 index 00000000..9f338125 --- /dev/null +++ b/src/ptychodus/model/product/probe/builder_factory.py @@ -0,0 +1,107 @@ +from collections.abc import Callable, Iterable, Iterator, Mapping +from pathlib import Path +import logging + +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.probe import ( + FresnelZonePlate, + ProbeSequence, + ProbeFileReader, + ProbeFileWriter, +) + +from ...patterns import AssembledDiffractionDataset +from .average_pattern import AveragePatternProbeBuilder +from .builder import FromFileProbeBuilder, ProbeSequenceBuilder +from .disk import DiskProbeBuilder +from .fzp import FresnelZonePlateProbeBuilder +from .rect import RectangularProbeBuilder +from .settings import ProbeSettings +from .super_gaussian import SuperGaussianProbeBuilder +from .zernike import ZernikeProbeBuilder + +logger = logging.getLogger(__name__) + + +class ProbeBuilderFactory(Iterable[str]): + def __init__( + self, + settings: ProbeSettings, + dataset: AssembledDiffractionDataset, + fresnel_zone_plate_chooser: PluginChooser[FresnelZonePlate], + file_reader_chooser: PluginChooser[ProbeFileReader], + file_writer_chooser: PluginChooser[ProbeFileWriter], + ) -> None: + super().__init__() + self._settings = settings + self._dataset = dataset + self._fresnel_zone_plate_chooser = fresnel_zone_plate_chooser + self._file_reader_chooser = file_reader_chooser + self._file_writer_chooser = file_writer_chooser + self._builders: Mapping[str, Callable[[], ProbeSequenceBuilder]] = { + 'disk': lambda: DiskProbeBuilder(settings), + 'average_pattern': self._create_average_pattern_builder, + 'fresnel_zone_plate': self._create_fresnel_zone_plate_builder, + 'rectangular': lambda: RectangularProbeBuilder(settings), + 'super_gaussian': lambda: SuperGaussianProbeBuilder(settings), + 'zernike': lambda: ZernikeProbeBuilder(settings), + } + + def __iter__(self) -> Iterator[str]: + return iter(self._builders) + + def create(self, name: str) -> ProbeSequenceBuilder: + try: + factory = self._builders[name] + except KeyError as exc: + raise KeyError(f'Unknown probe builder "{name}"!') from exc + + return factory() + + def create_default(self) -> ProbeSequenceBuilder: + return next(iter(self._builders.values()))() + + def create_from_settings(self) -> ProbeSequenceBuilder: + name = self._settings.builder.get_value() + name_repaired = name.casefold() + + if name_repaired == 'from_file': + return self.create_probe_from_file( + self._settings.file_path.get_value(), + self._settings.file_type.get_value(), + ) + + return self.create(name_repaired) + + def _create_average_pattern_builder(self) -> ProbeSequenceBuilder: + return AveragePatternProbeBuilder(self._settings, self._dataset) + + def _create_fresnel_zone_plate_builder(self) -> ProbeSequenceBuilder: + return FresnelZonePlateProbeBuilder(self._settings, self._fresnel_zone_plate_chooser) + + def get_open_file_filters(self) -> Iterator[str]: + for plugin in self._file_reader_chooser: + yield plugin.display_name + + def get_open_file_filter(self) -> str: + return self._file_reader_chooser.get_current_plugin().display_name + + def create_probe_from_file(self, file_path: Path, file_filter: str) -> ProbeSequenceBuilder: + self._file_reader_chooser.set_current_plugin(file_filter) + file_type = self._file_reader_chooser.get_current_plugin().simple_name + file_reader = self._file_reader_chooser.get_current_plugin().strategy + return FromFileProbeBuilder(self._settings, file_path, file_type, file_reader) + + def get_save_file_filters(self) -> Iterator[str]: + for plugin in self._file_writer_chooser: + yield plugin.display_name + + def get_save_file_filter(self) -> str: + return self._file_writer_chooser.get_current_plugin().display_name + + def save_probe(self, file_path: Path, file_filter: str, probe: ProbeSequence) -> None: + self._file_writer_chooser.set_current_plugin(file_filter) + file_type = self._file_writer_chooser.get_current_plugin().simple_name + logger.debug(f'Writing "{file_path}" as "{file_type}"') + file_writer = self._file_writer_chooser.get_current_plugin().strategy + file_writer.write(file_path, probe) diff --git a/src/ptychodus/model/product/probe/disk.py b/src/ptychodus/model/product/probe/disk.py index d4630b8f..be574d72 100644 --- a/src/ptychodus/model/product/probe/disk.py +++ b/src/ptychodus/model/product/probe/disk.py @@ -2,54 +2,54 @@ import numpy -from ptychodus.api.probe import Probe, ProbeGeometryProvider +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider from ptychodus.api.propagator import AngularSpectrumPropagator, PropagatorParameters -from .builder import ProbeBuilder +from .builder import ProbeSequenceBuilder from .settings import ProbeSettings -class DiskProbeBuilder(ProbeBuilder): +class DiskProbeBuilder(ProbeSequenceBuilder): def __init__(self, settings: ProbeSettings) -> None: super().__init__(settings, 'disk') self._settings = settings - self.diameterInMeters = settings.diskDiameterInMeters.copy() - self._addParameter('diameter_m', self.diameterInMeters) + self.diameter_m = settings.disk_diameter_m.copy() + self._add_parameter('diameter_m', self.diameter_m) # from sample to the focal plane - self.defocusDistanceInMeters = settings.defocusDistanceInMeters.copy() - self._addParameter('defocus_distance_m', self.defocusDistanceInMeters) + self.defocus_distance_m = settings.defocus_distance_m.copy() + self._add_parameter('defocus_distance_m', self.defocus_distance_m) def copy(self) -> DiskProbeBuilder: builder = DiskProbeBuilder(self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: - geometry = geometryProvider.getProbeGeometry() - coords = self.getTransverseCoordinates(geometry) + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: + geometry = geometry_provider.get_probe_geometry() + coords = self.get_transverse_coordinates(geometry) - R_m = coords.positionRInMeters - r_m = self.diameterInMeters.getValue() / 2.0 + R_m = coords.position_r_m # noqa: N806 + r_m = self.diameter_m.get_value() / 2.0 disk = numpy.where(R_m < r_m, 1 + 0j, 0j) - propagatorParameters = PropagatorParameters( - wavelength_m=geometryProvider.probeWavelengthInMeters, + propagator_parameters = PropagatorParameters( + wavelength_m=geometry_provider.probe_wavelength_m, width_px=disk.shape[-1], height_px=disk.shape[-2], - pixel_width_m=geometry.pixelWidthInMeters, - pixel_height_m=geometry.pixelHeightInMeters, - propagation_distance_m=self.defocusDistanceInMeters.getValue(), + pixel_width_m=geometry.pixel_width_m, + pixel_height_m=geometry.pixel_height_m, + propagation_distance_m=self.defocus_distance_m.get_value(), ) - propagator = AngularSpectrumPropagator(propagatorParameters) + propagator = AngularSpectrumPropagator(propagator_parameters) array = propagator.propagate(disk) - return Probe( + return ProbeSequence( array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + opr_weights=None, + pixel_geometry=geometry.get_pixel_geometry(), ) diff --git a/src/ptychodus/model/product/probe/fzp.py b/src/ptychodus/model/product/probe/fzp.py index adaf919f..bc5c3ef5 100644 --- a/src/ptychodus/model/product/probe/fzp.py +++ b/src/ptychodus/model/product/probe/fzp.py @@ -6,100 +6,101 @@ from ptychodus.api.geometry import PixelGeometry from ptychodus.api.plugins import PluginChooser -from ptychodus.api.probe import FresnelZonePlate, Probe, ProbeGeometryProvider +from ptychodus.api.probe import FresnelZonePlate, ProbeSequence, ProbeGeometryProvider from ptychodus.api.propagator import FresnelTransformPropagator, PropagatorParameters -from .builder import ProbeBuilder +from .builder import ProbeSequenceBuilder from .settings import ProbeSettings -class FresnelZonePlateProbeBuilder(ProbeBuilder): +class FresnelZonePlateProbeBuilder(ProbeSequenceBuilder): def __init__( self, settings: ProbeSettings, - fresnelZonePlateChooser: PluginChooser[FresnelZonePlate], + fresnel_zone_plate_chooser: PluginChooser[FresnelZonePlate], ) -> None: super().__init__(settings, 'fresnel_zone_plate') self._settings = settings - self._fresnelZonePlateChooser = fresnelZonePlateChooser + self._fresnel_zone_plate_chooser = fresnel_zone_plate_chooser - self.zonePlateDiameterInMeters = settings.zonePlateDiameterInMeters.copy() - self._addParameter('zone_plate_diameter_m', self.zonePlateDiameterInMeters) + self.zone_plate_diameter_m = settings.zone_plate_diameter_m.copy() + self._add_parameter('zone_plate_diameter_m', self.zone_plate_diameter_m) - self.outermostZoneWidthInMeters = settings.outermostZoneWidthInMeters.copy() - self._addParameter('outermost_zone_width_m', self.outermostZoneWidthInMeters) + self.outermost_zone_width_m = settings.outermost_zone_width_m.copy() + self._add_parameter('outermost_zone_width_m', self.outermost_zone_width_m) - self.centralBeamstopDiameterInMeters = settings.centralBeamstopDiameterInMeters.copy() - self._addParameter('central_beamstop_diameter_m', self.centralBeamstopDiameterInMeters) + self.central_beamstop_diameter_m = settings.central_beamstop_diameter_m.copy() + self._add_parameter('central_beamstop_diameter_m', self.central_beamstop_diameter_m) # from sample to the focal plane - self.defocusDistanceInMeters = settings.defocusDistanceInMeters.copy() - self._addParameter('defocus_distance_m', self.defocusDistanceInMeters) + self.defocus_distance_m = settings.defocus_distance_m.copy() + self._add_parameter('defocus_distance_m', self.defocus_distance_m) def copy(self) -> FresnelZonePlateProbeBuilder: - builder = FresnelZonePlateProbeBuilder(self._settings, self._fresnelZonePlateChooser) + builder = FresnelZonePlateProbeBuilder(self._settings, self._fresnel_zone_plate_chooser) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder - def labelsForPresets(self) -> Iterator[str]: - for entry in self._fresnelZonePlateChooser: - yield entry.displayName - - def applyPresets(self, index: int) -> None: - fzp = self._fresnelZonePlateChooser[index].strategy - self.zonePlateDiameterInMeters.setValue(fzp.zonePlateDiameterInMeters) - self.outermostZoneWidthInMeters.setValue(fzp.outermostZoneWidthInMeters) - self.centralBeamstopDiameterInMeters.setValue(fzp.centralBeamstopDiameterInMeters) - - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: - wavelengthInMeters = geometryProvider.probeWavelengthInMeters - zonePlate = FresnelZonePlate( - zonePlateDiameterInMeters=self.zonePlateDiameterInMeters.getValue(), - outermostZoneWidthInMeters=self.outermostZoneWidthInMeters.getValue(), - centralBeamstopDiameterInMeters=self.centralBeamstopDiameterInMeters.getValue(), + def labels_for_presets(self) -> Iterator[str]: + for plugin in self._fresnel_zone_plate_chooser: + yield plugin.display_name + + def apply_presets(self, display_name: str) -> None: + self._fresnel_zone_plate_chooser.set_current_plugin(display_name) + fzp = self._fresnel_zone_plate_chooser.get_current_plugin().strategy + self.zone_plate_diameter_m.set_value(fzp.zone_plate_diameter_m) + self.outermost_zone_width_m.set_value(fzp.outermost_zone_width_m) + self.central_beamstop_diameter_m.set_value(fzp.central_beamstop_diameter_m) + + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: + wavelength_m = geometry_provider.probe_wavelength_m + zone_plate = FresnelZonePlate( + zone_plate_diameter_m=self.zone_plate_diameter_m.get_value(), + outermost_zone_width_m=self.outermost_zone_width_m.get_value(), + central_beamstop_diameter_m=self.central_beamstop_diameter_m.get_value(), ) - focalLengthInMeters = zonePlate.getFocalLengthInMeters(wavelengthInMeters) - distanceInMeters = focalLengthInMeters + self.defocusDistanceInMeters.getValue() - samplePlaneGeometry = geometryProvider.getProbeGeometry() - fzpHalfWidth = (samplePlaneGeometry.widthInPixels + 1) // 2 - fzpHalfHeight = (samplePlaneGeometry.heightInPixels + 1) // 2 - fzpPlanePixelSizeNumerator = wavelengthInMeters * distanceInMeters - fzpPixelGeometry = PixelGeometry( - widthInMeters=fzpPlanePixelSizeNumerator / samplePlaneGeometry.widthInMeters, - heightInMeters=fzpPlanePixelSizeNumerator / samplePlaneGeometry.heightInMeters, + focal_length_m = zone_plate.get_focal_length_m(wavelength_m) + distance_m = focal_length_m + self.defocus_distance_m.get_value() + sample_plane_geometry = geometry_provider.get_probe_geometry() + fzp_half_width = (sample_plane_geometry.width_px + 1) // 2 + fzp_half_height = (sample_plane_geometry.height_px + 1) // 2 + fzp_plane_pixel_size_numerator = wavelength_m * distance_m + fzp_pixel_geometry = PixelGeometry( + width_m=fzp_plane_pixel_size_numerator / sample_plane_geometry.width_m, + height_m=fzp_plane_pixel_size_numerator / sample_plane_geometry.height_m, ) # coordinate on FZP plane - lx_fzp = -fzpPixelGeometry.widthInMeters * numpy.arange(-fzpHalfWidth, fzpHalfWidth) - ly_fzp = -fzpPixelGeometry.heightInMeters * numpy.arange(-fzpHalfHeight, fzpHalfHeight) + lx_fzp = -fzp_pixel_geometry.width_m * numpy.arange(-fzp_half_width, fzp_half_width) + ly_fzp = -fzp_pixel_geometry.height_m * numpy.arange(-fzp_half_height, fzp_half_height) - YY_FZP, XX_FZP = numpy.meshgrid(ly_fzp, lx_fzp) - RR_FZP = numpy.hypot(XX_FZP, YY_FZP) + YY_FZP, XX_FZP = numpy.meshgrid(ly_fzp, lx_fzp) # noqa: N806 + RR_FZP = numpy.hypot(XX_FZP, YY_FZP) # noqa: N806 # transmission function of FZP - T = numpy.exp( - -2j * numpy.pi / wavelengthInMeters * (XX_FZP**2 + YY_FZP**2) / 2 / focalLengthInMeters + T = numpy.exp( # noqa: N806 + -2j * numpy.pi / wavelength_m * (XX_FZP**2 + YY_FZP**2) / 2 / focal_length_m ) - C = RR_FZP <= zonePlate.zonePlateDiameterInMeters / 2 - H = RR_FZP >= zonePlate.centralBeamstopDiameterInMeters / 2 - fzpTransmissionFunction = T * C * H - - propagatorParameters = PropagatorParameters( - wavelength_m=wavelengthInMeters, - width_px=fzpTransmissionFunction.shape[-1], - height_px=fzpTransmissionFunction.shape[-2], - pixel_width_m=fzpPixelGeometry.widthInMeters, - pixel_height_m=fzpPixelGeometry.heightInMeters, - propagation_distance_m=distanceInMeters, + C = RR_FZP <= zone_plate.zone_plate_diameter_m / 2 # noqa: N806 + H = RR_FZP >= zone_plate.central_beamstop_diameter_m / 2 # noqa: N806 + fzp_transmission_function = T * C * H + + propagator_parameters = PropagatorParameters( + wavelength_m=wavelength_m, + width_px=fzp_transmission_function.shape[-1], + height_px=fzp_transmission_function.shape[-2], + pixel_width_m=fzp_pixel_geometry.width_m, + pixel_height_m=fzp_pixel_geometry.height_m, + propagation_distance_m=distance_m, ) - propagator = FresnelTransformPropagator(propagatorParameters) - array = propagator.propagate(fzpTransmissionFunction) + propagator = FresnelTransformPropagator(propagator_parameters) + array = propagator.propagate(fzp_transmission_function) - return Probe( + return ProbeSequence( array=self.normalize(array), - pixelWidthInMeters=samplePlaneGeometry.pixelWidthInMeters, - pixelHeightInMeters=samplePlaneGeometry.pixelHeightInMeters, + opr_weights=None, + pixel_geometry=sample_plane_geometry.get_pixel_geometry(), ) diff --git a/src/ptychodus/model/product/probe/item.py b/src/ptychodus/model/product/probe/item.py index b3f58615..c0da516c 100644 --- a/src/ptychodus/model/product/probe/item.py +++ b/src/ptychodus/model/product/probe/item.py @@ -3,9 +3,9 @@ from ptychodus.api.observer import Observable from ptychodus.api.parametric import ParameterGroup -from ptychodus.api.probe import Probe, ProbeGeometryProvider +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider -from .builder import FromMemoryProbeBuilder, ProbeBuilder +from .builder import FromMemoryProbeBuilder, ProbeSequenceBuilder from .multimodal import MultimodalProbeBuilder from .settings import ProbeSettings @@ -15,74 +15,81 @@ class ProbeRepositoryItem(ParameterGroup): def __init__( self, - geometryProvider: ProbeGeometryProvider, + geometry_provider: ProbeGeometryProvider, settings: ProbeSettings, - builder: ProbeBuilder, - additionalModesBuilder: MultimodalProbeBuilder, + builder: ProbeSequenceBuilder, + additional_modes_builder: MultimodalProbeBuilder, ) -> None: super().__init__() - self._geometryProvider = geometryProvider + self._geometry_provider = geometry_provider self._settings = settings self._builder = builder - self._additionalModesBuilder = additionalModesBuilder - self._probe = Probe() + self._additional_modes_builder = additional_modes_builder + self._probe_seq = ProbeSequence(array=None, opr_weights=None, pixel_geometry=None) - self._addGroup('builder', builder, observe=True) - self._addGroup('additional_modes', additionalModesBuilder, observe=True) + self._add_group('builder', builder, observe=True) + self._add_group('additional_modes', additional_modes_builder, observe=True) self._rebuild() - def assignItem(self, item: ProbeRepositoryItem) -> None: - self._removeGroup('additional_modes') - self._additionalModesBuilder.removeObserver(self) - self._additionalModesBuilder = item.getAdditionalModesBuilder().copy() - self._additionalModesBuilder.addObserver(self) - self._addGroup('additional_modes', self._additionalModesBuilder, observe=True) + def assign_item(self, item: ProbeRepositoryItem) -> None: + group = 'additional_modes' - self.setBuilder(item.getBuilder().copy()) + self._remove_group(group) + self._additional_modes_builder.remove_observer(self) + + additional_modes_builder = item.get_additional_modes_builder() + + self._additional_modes_builder = additional_modes_builder.copy() + self._additional_modes_builder.add_observer(self) + self._add_group(group, self._additional_modes_builder, observe=True) + + self.set_builder(item.get_builder().copy()) self._rebuild() - def assign(self, probe: Probe) -> None: + def assign(self, probe: ProbeSequence) -> None: builder = FromMemoryProbeBuilder(self._settings, probe) - self.setBuilder(builder) + self.set_builder(builder) - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() - self._builder.syncToSettings() - self._additionalModesBuilder.syncToSettings() + self._builder.sync_to_settings() + self._additional_modes_builder.sync_to_settings() - def getProbe(self) -> Probe: - return self._probe + def get_probes(self) -> ProbeSequence: + return self._probe_seq - def getBuilder(self) -> ProbeBuilder: + def get_builder(self) -> ProbeSequenceBuilder: return self._builder - def setBuilder(self, builder: ProbeBuilder) -> None: - self._removeGroup('builder') - self._builder.removeObserver(self) + def set_builder(self, builder: ProbeSequenceBuilder) -> None: + group = 'builder' + self._remove_group(group) + self._builder.remove_observer(self) self._builder = builder - self._builder.addObserver(self) - self._addGroup('builder', self._builder, observe=True) + self._builder.add_observer(self) + self._add_group(group, self._builder, observe=True) + self._rebuild() def _rebuild(self) -> None: try: - probe = self._builder.build(self._geometryProvider) + probe = self._builder.build(self._geometry_provider) except Exception as exc: - logger.error(''.join(exc.args)) + logger.exception('Failed to rebuild probe!') return - self._probe = self._additionalModesBuilder.build(probe) - self.notifyObservers() + self._probe_seq = self._additional_modes_builder.build(probe, self._geometry_provider) + self.notify_observers() - def getAdditionalModesBuilder(self) -> MultimodalProbeBuilder: - return self._additionalModesBuilder + def get_additional_modes_builder(self) -> MultimodalProbeBuilder: + return self._additional_modes_builder - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._builder: self._rebuild() - elif observable is self._additionalModesBuilder: + elif observable is self._additional_modes_builder: self._rebuild() else: - super().update(observable) + super()._update(observable) diff --git a/src/ptychodus/model/product/probe/itemFactory.py b/src/ptychodus/model/product/probe/itemFactory.py deleted file mode 100644 index bf49109a..00000000 --- a/src/ptychodus/model/product/probe/itemFactory.py +++ /dev/null @@ -1,46 +0,0 @@ -import logging - -import numpy - -from ptychodus.api.probe import Probe, ProbeGeometryProvider - -from .builder import FromMemoryProbeBuilder -from .builderFactory import ProbeBuilderFactory -from .item import ProbeRepositoryItem -from .multimodal import MultimodalProbeBuilder -from .settings import ProbeSettings - -logger = logging.getLogger(__name__) - - -class ProbeRepositoryItemFactory: - def __init__( - self, - rng: numpy.random.Generator, - settings: ProbeSettings, - builderFactory: ProbeBuilderFactory, - ) -> None: - self._rng = rng - self._settings = settings - self._builderFactory = builderFactory - - def create( - self, geometryProvider: ProbeGeometryProvider, probe: Probe | None = None - ) -> ProbeRepositoryItem: - builder = ( - self._builderFactory.createDefault() - if probe is None - else FromMemoryProbeBuilder(self._settings, probe) - ) - multimodalBuilder = MultimodalProbeBuilder(self._rng, self._settings) - return ProbeRepositoryItem(geometryProvider, self._settings, builder, multimodalBuilder) - - def createFromSettings(self, geometryProvider: ProbeGeometryProvider) -> ProbeRepositoryItem: - try: - builder = self._builderFactory.createFromSettings() - except Exception as exc: - logger.error(''.join(exc.args)) - builder = self._builderFactory.createDefault() - - multimodalBuilder = MultimodalProbeBuilder(self._rng, self._settings) - return ProbeRepositoryItem(geometryProvider, self._settings, builder, multimodalBuilder) diff --git a/src/ptychodus/model/product/probe/item_factory.py b/src/ptychodus/model/product/probe/item_factory.py new file mode 100644 index 00000000..900243eb --- /dev/null +++ b/src/ptychodus/model/product/probe/item_factory.py @@ -0,0 +1,48 @@ +import logging + +import numpy.random + +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider + +from .builder import FromMemoryProbeBuilder +from .builder_factory import ProbeBuilderFactory +from .item import ProbeRepositoryItem +from .multimodal import MultimodalProbeBuilder +from .settings import ProbeSettings + +logger = logging.getLogger(__name__) + + +class ProbeRepositoryItemFactory: + def __init__( + self, + rng: numpy.random.Generator, + settings: ProbeSettings, + builder_factory: ProbeBuilderFactory, + ) -> None: + self._rng = rng + self._settings = settings + self._builder_factory = builder_factory + + def create( + self, geometry_provider: ProbeGeometryProvider, probe: ProbeSequence | None = None + ) -> ProbeRepositoryItem: + multimodal_builder = MultimodalProbeBuilder(self._rng, self._settings) + + if probe is None: + builder = self._builder_factory.create_default() + else: + builder = FromMemoryProbeBuilder(self._settings, probe) + multimodal_builder.set_identity() + + return ProbeRepositoryItem(geometry_provider, self._settings, builder, multimodal_builder) + + def create_from_settings(self, geometry_provider: ProbeGeometryProvider) -> ProbeRepositoryItem: + try: + builder = self._builder_factory.create_from_settings() + except Exception as exc: + logger.error(''.join(exc.args)) + builder = self._builder_factory.create_default() + + multimodal_builder = MultimodalProbeBuilder(self._rng, self._settings) + return ProbeRepositoryItem(geometry_provider, self._settings, builder, multimodal_builder) diff --git a/src/ptychodus/model/product/probe/multimodal.py b/src/ptychodus/model/product/probe/multimodal.py index 1563ad6e..c6c0ee70 100644 --- a/src/ptychodus/model/product/probe/multimodal.py +++ b/src/ptychodus/model/product/probe/multimodal.py @@ -3,13 +3,13 @@ from enum import auto, IntEnum import logging -import numpy +import numpy.random import scipy.linalg -from ptychodus.api.parametric import ( - ParameterGroup, -) -from ptychodus.api.probe import Probe, WavefieldArrayType +from ptychodus.api.parametric import ParameterGroup +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider +from ptychodus.api.propagator import intensity +from ptychodus.api.typing import ComplexArrayType, RealArrayType from .settings import ProbeSettings @@ -17,9 +17,21 @@ class ProbeModeDecayType(IntEnum): + NONE = auto() POLYNOMIAL = auto() EXPONENTIAL = auto() + def get_weights(self, num_modes: int, decay_ratio: float) -> Sequence[float]: + match self.value: + case ProbeModeDecayType.EXPONENTIAL: + b = 1.0 + (1.0 - decay_ratio) / decay_ratio + return [b**-n for n in range(num_modes)] + case ProbeModeDecayType.POLYNOMIAL: + b = numpy.log(decay_ratio) / numpy.log(2.0) + return [(n + 1) ** b for n in range(num_modes)] + case _: + return [1.0] + [0.0] * (num_modes - 1) + class MultimodalProbeBuilder(ParameterGroup): def __init__(self, rng: numpy.random.Generator, settings: ProbeSettings) -> None: @@ -27,108 +39,171 @@ def __init__(self, rng: numpy.random.Generator, settings: ProbeSettings) -> None self._rng = rng self._settings = settings - self.numberOfModes = settings.numberOfModes.copy() - self._addParameter('number_of_modes', self.numberOfModes) + self.num_incoherent_modes = settings.num_incoherent_modes.copy() + self._add_parameter('num_incoherent_modes', self.num_incoherent_modes) + + self.orthogonalize_incoherent_modes = settings.orthogonalize_incoherent_modes.copy() + self._add_parameter('orthogonalize_incoherent_modes', self.orthogonalize_incoherent_modes) - self.modeDecayType = settings.modeDecayType.copy() - self._addParameter('mode_decay_type', self.modeDecayType) + self.incoherent_mode_decay_type = settings.incoherent_mode_decay_type.copy() + self._add_parameter('incoherent_mode_decay_type', self.incoherent_mode_decay_type) - self.modeDecayRatio = settings.modeDecayRatio.copy() - self._addParameter('mode_decay_ratio', self.modeDecayRatio) + self.incoherent_mode_decay_ratio = settings.incoherent_mode_decay_ratio.copy() + self._add_parameter('incoherent_mode_decay_ratio', self.incoherent_mode_decay_ratio) - self.isOrthogonalizeModesEnabled = settings.isOrthogonalizeModesEnabled.copy() - self._addParameter('orthogonalize_modes', self.isOrthogonalizeModesEnabled) + self.num_coherent_modes = settings.num_coherent_modes.copy() + self._add_parameter('num_coherent_modes', self.num_coherent_modes) - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() def copy(self) -> MultimodalProbeBuilder: builder = MultimodalProbeBuilder(self._rng, self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder - def _initializeModes(self, probe: WavefieldArrayType) -> WavefieldArrayType: - modeList: list[WavefieldArrayType] = list() - - if probe.ndim == 2: - modeList.append(probe) - elif probe.ndim >= 3: - probe3D = probe + def _orthogonalize_incoherent_modes(self, array_in: ComplexArrayType) -> ComplexArrayType: + array_out = array_in.copy() - while probe3D.ndim > 3: - probe3D = probe3D[0] + if array_in.shape[-3] > 1: + imodes_as_rows = array_in[0].reshape(array_in.shape[-3], -1) + imodes_as_cols = imodes_as_rows.T - for mode in probe3D: - modeList.append(mode) - else: - raise ValueError('Probe array must contain at least two dimensions.') - - for mode in range(self.numberOfModes.getValue() - 1): - # randomly shift the first mode - pw = probe.shape[-1] # TODO clean up - variate1 = self._rng.uniform(size=(2, 1)) - 0.5 - variate2 = (numpy.arange(0, pw) + 0.5) / pw - 0.5 - ps = numpy.exp(-2j * numpy.pi * variate1 * variate2) - phaseShift = ps[0][numpy.newaxis] * ps[1][:, numpy.newaxis] - mode = modeList[0] * phaseShift - modeList.append(mode) - - return numpy.stack(modeList) - - def _orthogonalizeModes(self, probe: WavefieldArrayType) -> WavefieldArrayType: - probeModesAsRows = probe.reshape(probe.shape[-3], -1) - probeModesAsCols = probeModesAsRows.T - probeModesAsOrthoCols = scipy.linalg.orth(probeModesAsCols) - probeModesAsOrthoRows = probeModesAsOrthoCols.T - return probeModesAsOrthoRows.reshape(*probe.shape) - - def _getModeWeights(self, totalNumberOfModes: int) -> Sequence[float]: - modeDecayTypeText = self.modeDecayType.getValue() - modeDecayRatio = self.modeDecayRatio.getValue() - - if modeDecayRatio > 0.0: try: - modeDecayType = ProbeModeDecayType[modeDecayTypeText.upper()] - except KeyError: - modeDecayType = ProbeModeDecayType.POLYNOMIAL - - if modeDecayType == ProbeModeDecayType.EXPONENTIAL: - b = 1.0 + (1.0 - modeDecayRatio) / modeDecayRatio - return [b**-n for n in range(totalNumberOfModes)] - else: - b = numpy.log(modeDecayRatio) / numpy.log(2.0) - return [(n + 1) ** b for n in range(totalNumberOfModes)] + imodes_as_ortho_cols = scipy.linalg.orth(imodes_as_cols) + except ValueError as ex: + logger.exception(ex) + return array_in.copy() - return [1.0] + [0.0] * (totalNumberOfModes - 1) + imodes_as_ortho_rows = imodes_as_ortho_cols.T + imodes_ortho = imodes_as_ortho_rows.reshape(*array_in.shape) - def _adjustRelativePower(self, probe: WavefieldArrayType) -> WavefieldArrayType: - modeWeights = self._getModeWeights(probe.shape[-3]) - power0 = numpy.sum(numpy.square(numpy.abs(probe[0, ...]))) - adjustedProbe = probe.copy() + array_out[0, :, :, :] = imodes_ortho - for modeIndex, weight in enumerate(modeWeights): - powerN = numpy.sum(numpy.square(numpy.abs(adjustedProbe[modeIndex, ...]))) - adjustedProbe[modeIndex, ...] *= numpy.sqrt(weight * power0 / powerN) + return array_out - return adjustedProbe + def _get_imode_weights(self, num_imodes: int) -> Sequence[float]: + imode_decay_type_text = self.incoherent_mode_decay_type.get_value() + imode_decay_ratio = self.incoherent_mode_decay_ratio.get_value() + imode_decay_type = ProbeModeDecayType.NONE - def build(self, probe: Probe) -> Probe: - if self.numberOfModes.getValue() <= 1: - return probe + if imode_decay_ratio > 0.0: + try: + imode_decay_type = ProbeModeDecayType[imode_decay_type_text.upper()] + except KeyError: + logger.debug(f'Unknown probe mode decay type "{imode_decay_type_text}"') + + return imode_decay_type.get_weights(num_imodes, imode_decay_ratio) + + def _adjust_imode_power(self, array_in: ComplexArrayType, power: float) -> ComplexArrayType: + imode_weights = self._get_imode_weights(array_in.shape[-3]) + array_out = array_in.copy() + it = iter(array_out[0]) # iterate incoherent modes + + for weight in imode_weights: + imode = next(it) + ipower = numpy.sum(numpy.square(numpy.abs(imode))) + imode *= numpy.sqrt(weight * power / ipower) + + return array_out + + def _random_phase_shift_axis(self, size: int) -> ComplexArrayType: + a = self._rng.uniform() - 0.5 + b = (size - 1 - 2 * numpy.arange(size)) / size + return numpy.exp(1j * numpy.pi * a * b) + + def _init_modes( + self, + geometry_provider: ProbeGeometryProvider, + array_in: ComplexArrayType, + normalize_cmodes: bool = True, + ) -> ComplexArrayType: + assert array_in.ndim == 4 + num_cmodes = self.num_coherent_modes.get_value() + num_imodes = self.num_incoherent_modes.get_value() + height = array_in.shape[-2] + width = array_in.shape[-1] + + array_out = numpy.zeros((num_cmodes, num_imodes, height, width), array_in.dtype) + + for cmode in range(num_cmodes): + if cmode < array_in.shape[0]: + # copy existing cmode + values = array_in[cmode, 0, :, :] + else: + # randomize new cmode + real = self._rng.normal(0.0, 1.0, size=(height, width)) + imag = self._rng.normal(0.0, 1.0, size=(height, width)) + values = real + 1j * imag - array = self._initializeModes(probe.array) + if normalize_cmodes: + values /= numpy.sqrt(numpy.mean(intensity(values))) - if self.isOrthogonalizeModesEnabled.getValue(): - array = self._orthogonalizeModes(array) + array_out[cmode, 0, :, :] = values - array = self._adjustRelativePower(array) + for imode in range(num_imodes): + if imode < array_in.shape[1]: + # copy existing imode + values = array_in[0, imode, :, :] + else: + # apply random phase shift to first imode + first_imode = array_in[0, 0, :, :] + phase_shift_y = self._random_phase_shift_axis(height) + phase_shift_x = self._random_phase_shift_axis(width) + values = first_imode * numpy.outer(phase_shift_y, phase_shift_x) + + array_out[0, imode, :, :] = values + + if self.orthogonalize_incoherent_modes.get_value(): + array_out = self._orthogonalize_incoherent_modes(array_out) + + if geometry_provider.probe_photon_count > 0.0: + array_out = self._adjust_imode_power(array_out, geometry_provider.probe_photon_count) + + return array_out + + def _init_opr_weights( + self, geometry_provider: ProbeGeometryProvider, small_value: float = 1.0e-6 + ) -> RealArrayType | None: + num_scan_points = geometry_provider.num_scan_points + num_cmodes = self.num_coherent_modes.get_value() + opr_weights: RealArrayType | None = None + + if self.num_coherent_modes.get_value() > 1: + opr_weights = small_value * self._rng.normal( + 0.0, 1.0, size=(num_scan_points, num_cmodes) + ) + assert opr_weights is not None # unnecessary but makes pylance less annoying + opr_weights[:, 0] = 1.0 + + return opr_weights + + def set_identity(self) -> None: + self.num_coherent_modes.set_value(1) + self.num_incoherent_modes.set_value(1) + + def build( + self, probes: ProbeSequence, geometry_provider: ProbeGeometryProvider + ) -> ProbeSequence: + if self.num_coherent_modes.get_value() <= 1 and self.num_incoherent_modes.get_value() <= 1: + return probes + + array = self._init_modes(geometry_provider, probes.get_array()) + + try: + opr_weights: RealArrayType | None = probes.get_opr_weights() + except ValueError: + opr_weights = self._init_opr_weights(geometry_provider) + else: + # TODO if opr_weights.shape[0] != num_scan_points: raise ValueError() + pass - return Probe( - array, - pixelWidthInMeters=probe.pixelWidthInMeters, - pixelHeightInMeters=probe.pixelHeightInMeters, + return ProbeSequence( + array=array, + opr_weights=opr_weights, + pixel_geometry=probes.get_pixel_geometry(), ) diff --git a/src/ptychodus/model/product/probe/rect.py b/src/ptychodus/model/product/probe/rect.py index 5715e834..473ec5be 100644 --- a/src/ptychodus/model/product/probe/rect.py +++ b/src/ptychodus/model/product/probe/rect.py @@ -2,61 +2,61 @@ import numpy -from ptychodus.api.probe import Probe, ProbeGeometryProvider +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider from ptychodus.api.propagator import AngularSpectrumPropagator, PropagatorParameters -from .builder import ProbeBuilder +from .builder import ProbeSequenceBuilder from .settings import ProbeSettings -class RectangularProbeBuilder(ProbeBuilder): +class RectangularProbeBuilder(ProbeSequenceBuilder): def __init__(self, settings: ProbeSettings) -> None: super().__init__(settings, 'rectangular') self._settings = settings - self.widthInMeters = settings.rectangleWidthInMeters.copy() - self._addParameter('width_m', self.widthInMeters) + self.width_m = settings.rectangle_width_m.copy() + self._add_parameter('width_m', self.width_m) - self.heightInMeters = settings.rectangleHeightInMeters.copy() - self._addParameter('height_m', self.heightInMeters) + self.height_m = settings.rectangle_height_m.copy() + self._add_parameter('height_m', self.height_m) # from sample to the focal plane - self.defocusDistanceInMeters = settings.defocusDistanceInMeters.copy() - self._addParameter('defocus_distance_m', self.defocusDistanceInMeters) + self.defocus_distance_m = settings.defocus_distance_m.copy() + self._add_parameter('defocus_distance_m', self.defocus_distance_m) def copy(self) -> RectangularProbeBuilder: builder = RectangularProbeBuilder(self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: - geometry = geometryProvider.getProbeGeometry() - coords = self.getTransverseCoordinates(geometry) + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: + geometry = geometry_provider.get_probe_geometry() + coords = self.get_transverse_coordinates(geometry) - aX_m = numpy.fabs(coords.positionXInMeters) - rx_m = self.widthInMeters.getValue() / 2.0 - aY_m = numpy.fabs(coords.positionYInMeters) - ry_m = self.heightInMeters.getValue() / 2.0 + aX_m = numpy.fabs(coords.position_x_m) # noqa: N806 + rx_m = self.width_m.get_value() / 2.0 + aY_m = numpy.fabs(coords.position_y_m) # noqa: N806 + ry_m = self.height_m.get_value() / 2.0 - isInside = numpy.logical_and(aX_m < rx_m, aY_m < ry_m) - rect = numpy.where(isInside, 1 + 0j, 0j) + is_inside = numpy.logical_and(aX_m < rx_m, aY_m < ry_m) + rect = numpy.where(is_inside, 1 + 0j, 0j) - propagatorParameters = PropagatorParameters( - wavelength_m=geometryProvider.probeWavelengthInMeters, + propagator_parameters = PropagatorParameters( + wavelength_m=geometry_provider.probe_wavelength_m, width_px=rect.shape[-1], height_px=rect.shape[-2], - pixel_width_m=geometry.pixelWidthInMeters, - pixel_height_m=geometry.pixelHeightInMeters, - propagation_distance_m=self.defocusDistanceInMeters.getValue(), + pixel_width_m=geometry.pixel_width_m, + pixel_height_m=geometry.pixel_height_m, + propagation_distance_m=self.defocus_distance_m.get_value(), ) - propagator = AngularSpectrumPropagator(propagatorParameters) + propagator = AngularSpectrumPropagator(propagator_parameters) array = propagator.propagate(rect) - return Probe( + return ProbeSequence( array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + opr_weights=None, + pixel_geometry=geometry.get_pixel_geometry(), ) diff --git a/src/ptychodus/model/product/probe/settings.py b/src/ptychodus/model/product/probe/settings.py index fed8ee7a..8a10aa67 100644 --- a/src/ptychodus/model/product/probe/settings.py +++ b/src/ptychodus/model/product/probe/settings.py @@ -7,61 +7,60 @@ class ProbeSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Probe') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Probe') + self._group.add_observer(self) - self.builder = self._settingsGroup.createStringParameter('Builder', 'Disk') - self.filePath = self._settingsGroup.createPathParameter( - 'FilePath', Path('/path/to/probe.npy') - ) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'NPY') + self.builder = self._group.create_string_parameter('Builder', 'Disk') + self.file_path = self._group.create_path_parameter('FilePath', Path('/path/to/probe.npy')) + self.file_type = self._group.create_string_parameter('FileType', 'NPY') - self.numberOfModes = self._settingsGroup.createIntegerParameter( - 'NumberOfModes', 1, minimum=1 + self.num_incoherent_modes = self._group.create_integer_parameter( + 'NumberOfIncoherentModes', 1, minimum=1 + ) + self.orthogonalize_incoherent_modes = self._group.create_boolean_parameter( + 'OrthogonalizeIncoherentModes', True ) - self.isOrthogonalizeModesEnabled = self._settingsGroup.createBooleanParameter( - 'OrthogonalizeModesEnabled', True + self.incoherent_mode_decay_type = self._group.create_string_parameter( + 'IncoherentModeDecayType', 'Polynomial' ) - self.modeDecayType = self._settingsGroup.createStringParameter( - 'ModeDecayType', 'Polynomial' + self.incoherent_mode_decay_ratio = self._group.create_real_parameter( + 'IncoherentModeDecayRatio', 1.0, minimum=0.0, maximum=1.0 ) - self.modeDecayRatio = self._settingsGroup.createRealParameter( - 'ModeDecayRatio', 1.0, minimum=0.0, maximum=1.0 + self.num_coherent_modes = self._group.create_integer_parameter( + 'NumberOfCoherentModes', 1, minimum=1 ) - self.diskDiameterInMeters = self._settingsGroup.createRealParameter( + self.disk_diameter_m = self._group.create_real_parameter( 'DiskDiameterInMeters', 1e-6, minimum=0.0 ) - self.rectangleWidthInMeters = self._settingsGroup.createRealParameter( + self.rectangle_width_m = self._group.create_real_parameter( 'RectangleWidthInMeters', 1e-6, minimum=0.0 ) - self.rectangleHeightInMeters = self._settingsGroup.createRealParameter( + self.rectangle_height_m = self._group.create_real_parameter( 'RectangleHeightInMeters', 1e-6, minimum=0.0 ) - self.superGaussianAnnularRadiusInMeters = self._settingsGroup.createRealParameter( + self.super_gaussian_annular_radius_m = self._group.create_real_parameter( 'SuperGaussianAnnularRadiusInMeters', 0, minimum=0.0 ) - self.superGaussianWidthInMeters = self._settingsGroup.createRealParameter( + self.super_gaussian_width_m = self._group.create_real_parameter( 'SuperGaussianWidthInMeters', 400e-6, minimum=0.0 ) - self.superGaussianOrderParameter = self._settingsGroup.createRealParameter( + self.super_gaussian_order_parameter = self._group.create_real_parameter( 'SuperGaussianOrderParameter', 1, minimum=1.0 ) - self.zonePlateDiameterInMeters = self._settingsGroup.createRealParameter( + self.zone_plate_diameter_m = self._group.create_real_parameter( 'ZonePlateDiameterInMeters', 180e-6, minimum=0.0 ) - self.outermostZoneWidthInMeters = self._settingsGroup.createRealParameter( + self.outermost_zone_width_m = self._group.create_real_parameter( 'OutermostZoneWidthInMeters', 50e-9, minimum=0.0 ) - self.centralBeamstopDiameterInMeters = self._settingsGroup.createRealParameter( + self.central_beamstop_diameter_m = self._group.create_real_parameter( 'CentralBeamstopDiameterInMeters', 60e-6, minimum=0.0 ) - self.defocusDistanceInMeters = self._settingsGroup.createRealParameter( - 'DefocusDistanceInMeters', 0.0 - ) + self.defocus_distance_m = self._group.create_real_parameter('DefocusDistanceInMeters', 0.0) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/product/probe/superGaussian.py b/src/ptychodus/model/product/probe/superGaussian.py deleted file mode 100644 index 9df3b494..00000000 --- a/src/ptychodus/model/product/probe/superGaussian.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -import numpy - -from ptychodus.api.probe import Probe, ProbeGeometryProvider - -from .builder import ProbeBuilder -from .settings import ProbeSettings - - -class SuperGaussianProbeBuilder(ProbeBuilder): - def __init__(self, settings: ProbeSettings) -> None: - super().__init__(settings, 'super_gaussian') - self._settings = settings - - self.annularRadiusInMeters = settings.superGaussianAnnularRadiusInMeters.copy() - self._addParameter('annular_radius_m', self.annularRadiusInMeters) - - self.fwhmInMeters = settings.superGaussianWidthInMeters.copy() - self._addParameter('full_width_at_half_maximum_m', self.fwhmInMeters) - - self.orderParameter = settings.superGaussianOrderParameter.copy() - self._addParameter('order_parameter', self.orderParameter) - - def copy(self) -> SuperGaussianProbeBuilder: - builder = SuperGaussianProbeBuilder(self._settings) - - for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) - - return builder - - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: - geometry = geometryProvider.getProbeGeometry() - coords = self.getTransverseCoordinates(geometry) - - Z = ( - coords.positionRInMeters - self.annularRadiusInMeters.getValue() - ) / self.fwhmInMeters.getValue() - ZP = numpy.power(2 * Z, 2 * self.orderParameter.getValue()) - - return Probe( - array=self.normalize(numpy.exp(-numpy.log(2) * ZP) + 0j), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, - ) diff --git a/src/ptychodus/model/product/probe/super_gaussian.py b/src/ptychodus/model/product/probe/super_gaussian.py new file mode 100644 index 00000000..1e8d6439 --- /dev/null +++ b/src/ptychodus/model/product/probe/super_gaussian.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import numpy + +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider + +from .builder import ProbeSequenceBuilder +from .settings import ProbeSettings + + +class SuperGaussianProbeBuilder(ProbeSequenceBuilder): + def __init__(self, settings: ProbeSettings) -> None: + super().__init__(settings, 'super_gaussian') + self._settings = settings + + self.annular_radius_m = settings.super_gaussian_annular_radius_m.copy() + self._add_parameter('annular_radius_m', self.annular_radius_m) + + self.fwhm_m = settings.super_gaussian_width_m.copy() + self._add_parameter('full_width_at_half_maximum_m', self.fwhm_m) + + self.order_parameter = settings.super_gaussian_order_parameter.copy() + self._add_parameter('order_parameter', self.order_parameter) + + def copy(self) -> SuperGaussianProbeBuilder: + builder = SuperGaussianProbeBuilder(self._settings) + + for key, value in self.parameters().items(): + builder.parameters()[key].set_value(value.get_value()) + + return builder + + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: + geometry = geometry_provider.get_probe_geometry() + coords = self.get_transverse_coordinates(geometry) + + Z = ( # noqa: N806 + coords.position_r_m - self.annular_radius_m.get_value() + ) / self.fwhm_m.get_value() + ZP = numpy.power(2 * Z, 2 * self.order_parameter.get_value()) # noqa: N806 + + return ProbeSequence( + array=self.normalize(numpy.exp(-numpy.log(2) * ZP) + 0j), + opr_weights=None, + pixel_geometry=geometry.get_pixel_geometry(), + ) diff --git a/src/ptychodus/model/product/probe/zernike.py b/src/ptychodus/model/product/probe/zernike.py index f2c03c4c..4c5afa2f 100644 --- a/src/ptychodus/model/product/probe/zernike.py +++ b/src/ptychodus/model/product/probe/zernike.py @@ -6,10 +6,10 @@ import numpy.typing import scipy.special -from ptychodus.api.probe import Probe, ProbeGeometryProvider +from ptychodus.api.probe import ProbeSequence, ProbeGeometryProvider from ptychodus.api.typing import RealArrayType -from .builder import ProbeBuilder +from .builder import ProbeSequenceBuilder from .settings import ProbeSettings logger = logging.getLogger(__name__) @@ -73,29 +73,29 @@ def __str__(self) -> str: return f'$Z_{{{self.radial_degree}}}^{{{self.angular_frequency:+d}}}$' -class ZernikeProbeBuilder(ProbeBuilder): +class ZernikeProbeBuilder(ProbeSequenceBuilder): def __init__(self, settings: ProbeSettings) -> None: super().__init__(settings, 'zernike') self._settings = settings self._polynomials: list[ZernikePolynomial] = list() self._order = 0 - self.diameterInMeters = settings.diskDiameterInMeters.copy() - self._addParameter('diameter_m', self.diameterInMeters) + self.diameter_m = settings.disk_diameter_m.copy() + self._add_parameter('diameter_m', self.diameter_m) # TODO init zernike coefficients from settings - self.coefficients = self.createComplexArrayParameter('coefficients', [1 + 0j]) + self.coefficients = self.create_complex_sequence_parameter('coefficients', [1 + 0j]) - self.setOrder(1) + self.set_order(1) def copy(self) -> ZernikeProbeBuilder: builder = ZernikeProbeBuilder(self._settings) - builder.diameterInMeters.setValue(self.diameterInMeters.getValue()) - builder.coefficients.setValue(self.coefficients.getValue()) - builder.setOrder(self.getOrder()) + builder.diameter_m.set_value(self.diameter_m.get_value()) + builder.coefficients.set_value(self.coefficients.get_value()) + builder.set_order(self.get_order()) return builder - def setOrder(self, order: int) -> None: + def set_order(self, order: int) -> None: if order < 1: logger.warning('Order must be strictly positive!') return @@ -114,42 +114,42 @@ def setOrder(self, order: int) -> None: ncoef = len(self.coefficients) if ncoef < npoly: - coef = list(self.coefficients.getValue()) + coef = list(self.coefficients.get_value()) coef += [0j] * (npoly - ncoef) - self.coefficients.setValue(coef) + self.coefficients.set_value(coef) self._order = order - self.notifyObservers() + self.notify_observers() - def getOrder(self) -> int: + def get_order(self) -> int: return self._order - def setCoefficient(self, idx: int, value: complex) -> None: + def set_coefficient(self, idx: int, value: complex) -> None: self.coefficients[idx] = value - def getCoefficient(self, idx: int) -> complex: + def get_coefficient(self, idx: int) -> complex: return self.coefficients[idx] - def getPolynomial(self, idx: int) -> ZernikePolynomial: + def get_polynomial(self, idx: int) -> ZernikePolynomial: return self._polynomials[idx] def __len__(self) -> int: return min(len(self.coefficients), len(self._polynomials)) - def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: - geometry = geometryProvider.getProbeGeometry() - coords = self.getTransverseCoordinates(geometry) + def build(self, geometry_provider: ProbeGeometryProvider) -> ProbeSequence: + geometry = geometry_provider.get_probe_geometry() + coords = self.get_transverse_coordinates(geometry) - radius = self.diameterInMeters.getValue() / 2.0 - distance = numpy.hypot(coords.positionYInMeters, coords.positionXInMeters) / radius - angle = numpy.arctan2(coords.positionYInMeters, coords.positionXInMeters) + radius = self.diameter_m.get_value() / 2.0 + distance = numpy.hypot(coords.position_y_m, coords.position_x_m) / radius + angle = numpy.arctan2(coords.position_y_m, coords.position_x_m) array = numpy.zeros_like(distance, dtype=complex) for coef, poly in zip(self.coefficients, self._polynomials): array += numpy.multiply(coef, poly(distance, angle)) - return Probe( + return ProbeSequence( array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + opr_weights=None, + pixel_geometry=geometry.get_pixel_geometry(), ) diff --git a/src/ptychodus/model/product/probeRepository.py b/src/ptychodus/model/product/probeRepository.py deleted file mode 100644 index 172f04cf..00000000 --- a/src/ptychodus/model/product/probeRepository.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Sequence -from typing import overload -import logging - -from ptychodus.api.observer import ObservableSequence - -from .item import ProductRepositoryItem, ProductRepositoryObserver -from .metadata import MetadataRepositoryItem -from .object import ObjectRepositoryItem -from .probe import ProbeRepositoryItem -from .productRepository import ProductRepository -from .scan import ScanRepositoryItem - -logger = logging.getLogger(__name__) - - -class ProbeRepository(ObservableSequence[ProbeRepositoryItem], ProductRepositoryObserver): - def __init__(self, repository: ProductRepository) -> None: - super().__init__() - self._repository = repository - self._repository.addObserver(self) - - def getName(self, index: int) -> str: - return self._repository[index].getName() - - def setName(self, index: int, name: str) -> None: - self._repository[index].setName(name) - - @overload - def __getitem__(self, index: int) -> ProbeRepositoryItem: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[ProbeRepositoryItem]: ... - - def __getitem__( - self, index: int | slice - ) -> ProbeRepositoryItem | Sequence[ProbeRepositoryItem]: - if isinstance(index, slice): - return [item.getProbe() for item in self._repository[index]] - else: - return self._repository[index].getProbe() - - def __len__(self) -> int: - return len(self._repository) - - def handleItemInserted(self, index: int, item: ProductRepositoryItem) -> None: - self.notifyObserversItemInserted(index, item.getProbe()) - - def handleMetadataChanged(self, index: int, item: MetadataRepositoryItem) -> None: - pass - - def handleScanChanged(self, index: int, item: ScanRepositoryItem) -> None: - pass - - def handleProbeChanged(self, index: int, item: ProbeRepositoryItem) -> None: - self.notifyObserversItemChanged(index, item) - - def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: - pass - - def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: - pass - - def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: - self.notifyObserversItemRemoved(index, item.getProbe()) diff --git a/src/ptychodus/model/product/probe_repository.py b/src/ptychodus/model/product/probe_repository.py new file mode 100644 index 00000000..62e66885 --- /dev/null +++ b/src/ptychodus/model/product/probe_repository.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +from typing import overload +import logging + +from ptychodus.api.observer import ObservableSequence + +from .item import ProductRepositoryItem, ProductRepositoryObserver +from .metadata import MetadataRepositoryItem +from .object import ObjectRepositoryItem +from .probe import ProbeRepositoryItem +from .repository import ProductRepository +from .scan import ScanRepositoryItem + +logger = logging.getLogger(__name__) + + +class ProbeRepository(ObservableSequence[ProbeRepositoryItem], ProductRepositoryObserver): + def __init__(self, repository: ProductRepository) -> None: + super().__init__() + self._repository = repository + self._repository.add_observer(self) + + def get_name(self, index: int) -> str: + return self._repository[index].get_name() + + def set_name(self, index: int, name: str) -> None: + self._repository[index].set_name(name) + + @overload + def __getitem__(self, index: int) -> ProbeRepositoryItem: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[ProbeRepositoryItem]: ... + + def __getitem__( + self, index: int | slice + ) -> ProbeRepositoryItem | Sequence[ProbeRepositoryItem]: + if isinstance(index, slice): + return [item.get_probe_item() for item in self._repository[index]] + else: + return self._repository[index].get_probe_item() + + def __len__(self) -> int: + return len(self._repository) + + def handle_item_inserted(self, index: int, item: ProductRepositoryItem) -> None: + self.notify_observers_item_inserted(index, item.get_probe_item()) + + def handle_metadata_changed(self, index: int, item: MetadataRepositoryItem) -> None: + pass + + def handle_scan_changed(self, index: int, item: ScanRepositoryItem) -> None: + pass + + def handle_probe_changed(self, index: int, item: ProbeRepositoryItem) -> None: + self.notify_observers_item_changed(index, item) + + def handle_object_changed(self, index: int, item: ObjectRepositoryItem) -> None: + pass + + def handle_costs_changed(self, index: int, costs: Sequence[float]) -> None: + pass + + def handle_item_removed(self, index: int, item: ProductRepositoryItem) -> None: + self.notify_observers_item_removed(index, item.get_probe_item()) diff --git a/src/ptychodus/model/product/productGeometry.py b/src/ptychodus/model/product/productGeometry.py deleted file mode 100644 index 04a744cb..00000000 --- a/src/ptychodus/model/product/productGeometry.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy - -from ptychodus.api.constants import ( - ELECTRON_VOLT_J, - LIGHT_SPEED_M_PER_S, - PLANCK_CONSTANT_J_PER_HZ, -) -from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.object import ObjectGeometry, ObjectGeometryProvider -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.probe import ProbeGeometry, ProbeGeometryProvider - -from ..patterns import PatternSizer -from .metadata import MetadataRepositoryItem -from .scan import ScanRepositoryItem - - -class ProductGeometry(ProbeGeometryProvider, ObjectGeometryProvider, Observable, Observer): - def __init__( - self, - patternSizer: PatternSizer, - metadata: MetadataRepositoryItem, - scan: ScanRepositoryItem, - ) -> None: - super().__init__() - self._patternSizer = patternSizer - self._metadata = metadata - self._scan = scan - - self._patternSizer.addObserver(self) - self._metadata.addObserver(self) - self._scan.addObserver(self) - - @property - def probeEnergyInJoules(self) -> float: - return self._metadata.probeEnergyInElectronVolts.getValue() * ELECTRON_VOLT_J - - @property - def probeWavelengthInMeters(self) -> float: - hc_Jm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S - - try: - return hc_Jm / self.probeEnergyInJoules - except ZeroDivisionError: - return 0.0 - - @property - def detectorDistanceInMeters(self) -> float: - return self._metadata.detectorDistanceInMeters.getValue() - - @property - def probePowerInWatts(self) -> float: - return self.probeEnergyInJoules * self._metadata.probePhotonsPerSecond.getValue() - - @property - def _lambdaZInSquareMeters(self) -> float: - return self.probeWavelengthInMeters * self.detectorDistanceInMeters - - @property - def objectPlanePixelWidthInMeters(self) -> float: - return self._lambdaZInSquareMeters / self._patternSizer.getWidthInMeters() - - @property - def objectPlanePixelHeightInMeters(self) -> float: - return self._lambdaZInSquareMeters / self._patternSizer.getHeightInMeters() - - def getPixelGeometry(self) -> PixelGeometry: - return PixelGeometry( - widthInMeters=self.objectPlanePixelWidthInMeters, - heightInMeters=self.objectPlanePixelHeightInMeters, - ) - - @property - def fresnelNumber(self) -> float: - widthInMeters = self._patternSizer.getWidthInMeters() - heightInMeters = self._patternSizer.getHeightInMeters() - sizeInMeters = max(widthInMeters, heightInMeters) - return sizeInMeters**2 / self._lambdaZInSquareMeters - - def getProbeGeometry(self) -> ProbeGeometry: - extent = self._patternSizer.getImageExtent() - return ProbeGeometry( - widthInPixels=extent.widthInPixels, - heightInPixels=extent.heightInPixels, - pixelWidthInMeters=self.objectPlanePixelWidthInMeters, - pixelHeightInMeters=self.objectPlanePixelHeightInMeters, - ) - - def isProbeGeometryValid(self, geometry: ProbeGeometry) -> bool: - expected = self.getProbeGeometry() - widthIsValid = ( - geometry.pixelWidthInMeters > 0.0 and geometry.widthInMeters == expected.widthInMeters - ) - heightIsValid = ( - geometry.pixelHeightInMeters > 0.0 - and geometry.heightInMeters == expected.heightInMeters - ) - return widthIsValid and heightIsValid - - def getObjectGeometry(self) -> ObjectGeometry: - probeGeometry = self.getProbeGeometry() - widthInMeters = probeGeometry.widthInMeters - heightInMeters = probeGeometry.heightInMeters - centerXInMeters = 0.0 - centerYInMeters = 0.0 - - scanBoundingBox = self._scan.getBoundingBox() - - if scanBoundingBox is not None: - widthInMeters += scanBoundingBox.widthInMeters - heightInMeters += scanBoundingBox.heightInMeters - centerXInMeters = scanBoundingBox.centerXInMeters - centerYInMeters = scanBoundingBox.centerYInMeters - - widthInPixels = widthInMeters / self.objectPlanePixelWidthInMeters - heightInPixels = heightInMeters / self.objectPlanePixelHeightInMeters - - return ObjectGeometry( - widthInPixels=int(numpy.ceil(widthInPixels)), - heightInPixels=int(numpy.ceil(heightInPixels)), - pixelWidthInMeters=self.objectPlanePixelWidthInMeters, - pixelHeightInMeters=self.objectPlanePixelHeightInMeters, - centerXInMeters=centerXInMeters, - centerYInMeters=centerYInMeters, - ) - - def isObjectGeometryValid(self, geometry: ObjectGeometry) -> bool: - expectedGeometry = self.getObjectGeometry() - pixelSizeIsValid = geometry.pixelWidthInMeters > 0.0 and geometry.pixelHeightInMeters > 0.0 - return pixelSizeIsValid and geometry.contains(expectedGeometry) - - def update(self, observable: Observable) -> None: - if observable is self._metadata: - self.notifyObservers() - elif observable is self._scan: - self.notifyObservers() - elif observable is self._patternSizer: - self.notifyObservers() diff --git a/src/ptychodus/model/product/productRepository.py b/src/ptychodus/model/product/productRepository.py deleted file mode 100644 index 71e7751c..00000000 --- a/src/ptychodus/model/product/productRepository.py +++ /dev/null @@ -1,233 +0,0 @@ -from __future__ import annotations -from collections.abc import Sequence -from typing import overload -import logging -import sys - -from ptychodus.api.product import Product - -from ..patterns import ActiveDiffractionDataset, PatternSizer, ProductSettings -from .item import ( - ProductRepositoryItem, - ProductRepositoryItemObserver, - ProductRepositoryObserver, -) -from .metadataFactory import MetadataRepositoryItemFactory -from .object import ObjectRepositoryItemFactory -from .probe import ProbeRepositoryItemFactory -from .productGeometry import ProductGeometry -from .productValidator import ProductValidator -from .scan import ScanRepositoryItemFactory - -logger = logging.getLogger(__name__) - - -class ProductRepository(Sequence[ProductRepositoryItem], ProductRepositoryItemObserver): - def __init__( - self, - settings: ProductSettings, - patternSizer: PatternSizer, - patterns: ActiveDiffractionDataset, - scanRepositoryItemFactory: ScanRepositoryItemFactory, - probeRepositoryItemFactory: ProbeRepositoryItemFactory, - objectRepositoryItemFactory: ObjectRepositoryItemFactory, - ) -> None: - super().__init__() - self._settings = settings - self._patternSizer = patternSizer - self._patterns = patterns - self._scanRepositoryItemFactory = scanRepositoryItemFactory - self._probeRepositoryItemFactory = probeRepositoryItemFactory - self._objectRepositoryItemFactory = objectRepositoryItemFactory - self._itemList: list[ProductRepositoryItem] = list() - self._metadataRepositoryItemFactory = MetadataRepositoryItemFactory(self, settings) - self._observerList: list[ProductRepositoryObserver] = [ - self._metadataRepositoryItemFactory, # NOTE must be first! - ] - - @overload - def __getitem__(self, index: int) -> ProductRepositoryItem: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[ProductRepositoryItem]: ... - - def __getitem__( - self, index: int | slice - ) -> ProductRepositoryItem | Sequence[ProductRepositoryItem]: - return self._itemList[index] - - def __len__(self) -> int: - return len(self._itemList) - - def _insertProduct(self, item: ProductRepositoryItem) -> int: - index = len(self._itemList) - self._itemList.append(item) - - for observer in self._observerList: - observer.handleItemInserted(index, item) - - return index - - def insertNewProduct( - self, - *, - name: str = '', - comments: str = '', - detectorDistanceInMeters: float | None = None, - probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, - exposureTimeInSeconds: float | None = None, - likeIndex: int, - ) -> int: - metadataItem = self._metadataRepositoryItemFactory.createDefault( - name=name, - comments=comments, - detectorDistanceInMeters=detectorDistanceInMeters, - probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, - exposureTimeInSeconds=exposureTimeInSeconds, - ) - scanItem = self._scanRepositoryItemFactory.create() - geometry = ProductGeometry(self._patternSizer, metadataItem, scanItem) - probeItem = self._probeRepositoryItemFactory.create(geometry) - objectItem = self._objectRepositoryItemFactory.create(geometry) - - item = ProductRepositoryItem( - parent=self, - metadata=metadataItem, - scan=scanItem, - geometry=geometry, - probe=probeItem, - object_=objectItem, - validator=ProductValidator(self._patterns, scanItem, geometry, probeItem, objectItem), - costs=list(), - ) - - if likeIndex >= 0: - item.assignItem(self._itemList[likeIndex], notify=False) - - return self._insertProduct(item) - - def insertProductFromSettings(self) -> int: - # TODO add mechanism to sync product state to settings - metadataItem = self._metadataRepositoryItemFactory.createDefault() - scanItem = self._scanRepositoryItemFactory.createFromSettings() - geometry = ProductGeometry(self._patternSizer, metadataItem, scanItem) - probeItem = self._probeRepositoryItemFactory.createFromSettings(geometry) - objectItem = self._objectRepositoryItemFactory.createFromSettings(geometry) - - item = ProductRepositoryItem( - parent=self, - metadata=metadataItem, - scan=scanItem, - geometry=geometry, - probe=probeItem, - object_=objectItem, - validator=ProductValidator(self._patterns, scanItem, geometry, probeItem, objectItem), - costs=list(), - ) - - return self._insertProduct(item) - - def insertProduct(self, product: Product) -> int: - metadataItem = self._metadataRepositoryItemFactory.create(product.metadata) - scanItem = self._scanRepositoryItemFactory.create(product.scan) - geometry = ProductGeometry(self._patternSizer, metadataItem, scanItem) - probeItem = self._probeRepositoryItemFactory.create(geometry, product.probe) - objectItem = self._objectRepositoryItemFactory.create(geometry, product.object_) - - item = ProductRepositoryItem( - parent=self, - metadata=metadataItem, - scan=scanItem, - geometry=geometry, - probe=probeItem, - object_=objectItem, - validator=ProductValidator(self._patterns, scanItem, geometry, probeItem, objectItem), - costs=product.costs, - ) - - return self._insertProduct(item) - - def removeProduct(self, index: int) -> None: - try: - item = self._itemList.pop(index) - except IndexError: - logger.debug(f'Failed to remove product item {index}!') - return - - for observer in self._observerList: - observer.handleItemRemoved(index, item) - - def getInfoText(self) -> str: - sizeInMB = sum(sys.getsizeof(prod) for prod in self._itemList) / (1024 * 1024) - return f'Total: {len(self)} [{sizeInMB:.2f}MB]' - - def addObserver(self, observer: ProductRepositoryObserver) -> None: - if observer not in self._observerList: - self._observerList.append(observer) - - def removeObserver(self, observer: ProductRepositoryObserver) -> None: - try: - self._observerList.remove(observer) - except ValueError: - pass - - def handleMetadataChanged(self, item: ProductRepositoryItem) -> None: - metadata = item.getMetadata() - index = metadata.getIndex() - - if index < 0: - logger.warning(f'Failed to look up index for "{item.getName()}"!') - return - - for observer in self._observerList: - observer.handleMetadataChanged(index, metadata) - - def handleScanChanged(self, item: ProductRepositoryItem) -> None: - metadata = item.getMetadata() - index = metadata.getIndex() - scan = item.getScan() - - if index < 0: - logger.warning(f'Failed to look up index for "{item.getName()}"!') - return - - for observer in self._observerList: - observer.handleScanChanged(index, scan) - - def handleProbeChanged(self, item: ProductRepositoryItem) -> None: - metadata = item.getMetadata() - index = metadata.getIndex() - probe = item.getProbe() - - if index < 0: - logger.warning(f'Failed to look up index for "{item.getName()}"!') - return - - for observer in self._observerList: - observer.handleProbeChanged(index, probe) - - def handleObjectChanged(self, item: ProductRepositoryItem) -> None: - metadata = item.getMetadata() - index = metadata.getIndex() - object_ = item.getObject() - - if index < 0: - logger.warning(f'Failed to look up index for "{item.getName()}"!') - return - - for observer in self._observerList: - observer.handleObjectChanged(index, object_) - - def handleCostsChanged(self, item: ProductRepositoryItem) -> None: - metadata = item.getMetadata() - index = metadata.getIndex() - costs = item.getCosts() - - if index < 0: - logger.warning(f'Failed to look up index for "{item.getName()}"!') - return - - for observer in self._observerList: - observer.handleCostsChanged(index, costs) diff --git a/src/ptychodus/model/product/productValidator.py b/src/ptychodus/model/product/productValidator.py deleted file mode 100644 index 21701fdf..00000000 --- a/src/ptychodus/model/product/productValidator.py +++ /dev/null @@ -1,78 +0,0 @@ -from ptychodus.api.observer import Observable, Observer - -from ..patterns import ActiveDiffractionDataset -from .object import ObjectRepositoryItem -from .probe import ProbeRepositoryItem -from .productGeometry import ProductGeometry -from .scan import ScanRepositoryItem - - -class ProductValidator(Observable, Observer): - def __init__( - self, - patterns: ActiveDiffractionDataset, - scan: ScanRepositoryItem, - geometry: ProductGeometry, - probe: ProbeRepositoryItem, - object_: ObjectRepositoryItem, - ) -> None: - super().__init__() - self._patterns = patterns - self._scan = scan - self._geometry = geometry - self._probe = probe - self._object = object_ - self._isScanValid = False - self._isProbeValid = False - self._isObjectValid = False - - def isScanValid(self) -> bool: - return self._isScanValid - - def _validateScan(self) -> None: - scan = self._scan.getScan() - scanIndexes = set(point.index for point in scan) - patternIndexes = set(self._patterns.getAssembledIndexes()) - isScanValidNow = not scanIndexes.isdisjoint(patternIndexes) - - if self._isScanValid != isScanValidNow: - self._isScanValid = isScanValidNow - self.notifyObservers() - - def isProbeValid(self) -> bool: - return self._isProbeValid - - def isObjectValid(self) -> bool: - return self._isObjectValid - - def _validateProbeAndObject(self) -> None: - hasValidityChanged = False - - probe = self._probe.getProbe() - isProbeValidNow = self._geometry.isProbeGeometryValid(probe.getGeometry()) - - if self._isProbeValid != isProbeValidNow: - self._isProbeValid = isProbeValidNow - hasValidityChanged = True - - object_ = self._object.getObject() - isObjectValidNow = self._geometry.isObjectGeometryValid(object_.getGeometry()) - - if self._isObjectValid != isObjectValidNow: - self._isObjectValid = isObjectValidNow - hasValidityChanged = True - - if hasValidityChanged: - self.notifyObservers() - - def update(self, observable: Observable) -> None: - if observable is self._patterns: - self._validateScan() - elif observable is self._scan: - self._validateScan() - elif observable is self._geometry: - self._validateProbeAndObject() - elif observable is self._probe: - self._validateProbeAndObject() - elif observable is self._object: - self._validateProbeAndObject() diff --git a/src/ptychodus/model/product/repository.py b/src/ptychodus/model/product/repository.py new file mode 100644 index 00000000..9137c0c4 --- /dev/null +++ b/src/ptychodus/model/product/repository.py @@ -0,0 +1,138 @@ +from collections.abc import Sequence +from typing import overload +import logging +import sys + +from ptychodus.api.units import BYTES_PER_MEGABYTE + +from .item import ProductRepositoryItem, ProductRepositoryItemObserver, ProductRepositoryObserver + +logger = logging.getLogger(__name__) + + +class ProductRepository(Sequence[ProductRepositoryItem], ProductRepositoryItemObserver): + def __init__(self) -> None: + super().__init__() + self._item_list: list[ProductRepositoryItem] = [] + self._observer_list: list[ProductRepositoryObserver] = [] + + @overload + def __getitem__(self, index: int) -> ProductRepositoryItem: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[ProductRepositoryItem]: ... + + def __getitem__( + self, index: int | slice + ) -> ProductRepositoryItem | Sequence[ProductRepositoryItem]: + return self._item_list[index] + + def __len__(self) -> int: + return len(self._item_list) + + def create_unique_name(self, candidate_name: str) -> str: + reserved_names = set([item.get_name() for item in self._item_list]) + name = candidate_name or 'Unnamed' + match = 0 + + while name in reserved_names: + match += 1 + name = f'{candidate_name}-{match}' + + return name + + def _update_indexes(self) -> None: + for index, item in enumerate(self._item_list): + item._index = index + + def insert_product(self, item: ProductRepositoryItem) -> int: + index = len(self._item_list) + self._item_list.append(item) + + self._update_indexes() + + for observer in self._observer_list: + observer.handle_item_inserted(index, item) + + return index + + def remove_product(self, index: int) -> None: + try: + item = self._item_list.pop(index) + except IndexError: + logger.debug(f'Failed to remove product item {index}!') + return + + self._update_indexes() + + for observer in self._observer_list: + observer.handle_item_removed(index, item) + + def get_info_text(self) -> str: + size_MB = sum(sys.getsizeof(prod) for prod in self._item_list) / BYTES_PER_MEGABYTE # noqa: N806 + return f'Total: {len(self)} [{size_MB:.2f}MB]' + + def add_observer(self, observer: ProductRepositoryObserver) -> None: + if observer not in self._observer_list: + self._observer_list.append(observer) + + def remove_observer(self, observer: ProductRepositoryObserver) -> None: + try: + self._observer_list.remove(observer) + except ValueError: + pass + + def handle_metadata_changed(self, item: ProductRepositoryItem) -> None: + metadata = item.get_metadata_item() + index = item._index + + if index < 0: + logger.warning(f'Failed to look up index for "{item.get_name()}"!') + return + + for observer in self._observer_list: + observer.handle_metadata_changed(index, metadata) + + def handle_scan_changed(self, item: ProductRepositoryItem) -> None: + index = item._index + scan = item.get_scan_item() + + if index < 0: + logger.warning(f'Failed to look up index for "{item.get_name()}"!') + return + + for observer in self._observer_list: + observer.handle_scan_changed(index, scan) + + def handle_probe_changed(self, item: ProductRepositoryItem) -> None: + index = item._index + probe = item.get_probe_item() + + if index < 0: + logger.warning(f'Failed to look up index for "{item.get_name()}"!') + return + + for observer in self._observer_list: + observer.handle_probe_changed(index, probe) + + def handle_object_changed(self, item: ProductRepositoryItem) -> None: + index = item._index + object_ = item.get_object_item() + + if index < 0: + logger.warning(f'Failed to look up index for "{item.get_name()}"!') + return + + for observer in self._observer_list: + observer.handle_object_changed(index, object_) + + def handle_costs_changed(self, item: ProductRepositoryItem) -> None: + index = item._index + costs = item.get_costs() + + if index < 0: + logger.warning(f'Failed to look up index for "{item.get_name()}"!') + return + + for observer in self._observer_list: + observer.handle_costs_changed(index, costs) diff --git a/src/ptychodus/model/product/scan/__init__.py b/src/ptychodus/model/product/scan/__init__.py index 1f35bc99..4d4eeb1c 100644 --- a/src/ptychodus/model/product/scan/__init__.py +++ b/src/ptychodus/model/product/scan/__init__.py @@ -1,9 +1,9 @@ from .builder import FromFileScanBuilder, FromMemoryScanBuilder -from .builderFactory import ScanBuilderFactory +from .builder_factory import ScanBuilderFactory from .cartesian import CartesianScanBuilder from .concentric import ConcentricScanBuilder from .item import ScanRepositoryItem -from .itemFactory import ScanRepositoryItemFactory +from .item_factory import ScanRepositoryItemFactory from .lissajous import LissajousScanBuilder from .settings import ScanSettings from .spiral import SpiralScanBuilder diff --git a/src/ptychodus/model/product/scan/boundingBox.py b/src/ptychodus/model/product/scan/boundingBox.py deleted file mode 100644 index 11b092b1..00000000 --- a/src/ptychodus/model/product/scan/boundingBox.py +++ /dev/null @@ -1,38 +0,0 @@ -import numpy - -from ptychodus.api.scan import ScanBoundingBox, ScanPoint - - -class ScanBoundingBoxBuilder: - def __init__(self) -> None: - self._minimumXInMeters = +numpy.inf - self._maximumXInMeters = -numpy.inf - self._minimumYInMeters = +numpy.inf - self._maximumYInMeters = -numpy.inf - - def hull(self, point: ScanPoint) -> None: - if point.positionXInMeters < self._minimumXInMeters: - self._minimumXInMeters = point.positionXInMeters - - if self._maximumXInMeters < point.positionXInMeters: - self._maximumXInMeters = point.positionXInMeters - - if point.positionYInMeters < self._minimumYInMeters: - self._minimumYInMeters = point.positionYInMeters - - if self._maximumYInMeters < point.positionYInMeters: - self._maximumYInMeters = point.positionYInMeters - - def getBoundingBox(self) -> ScanBoundingBox | None: - isEmptyX = self._maximumXInMeters < self._minimumXInMeters - isEmptyY = self._maximumYInMeters < self._minimumYInMeters - - if isEmptyX or isEmptyY: - return None - - return ScanBoundingBox( - minimumXInMeters=self._minimumXInMeters, - maximumXInMeters=self._maximumXInMeters, - minimumYInMeters=self._minimumYInMeters, - maximumYInMeters=self._maximumYInMeters, - ) diff --git a/src/ptychodus/model/product/scan/bounding_box.py b/src/ptychodus/model/product/scan/bounding_box.py new file mode 100644 index 00000000..9ef037f8 --- /dev/null +++ b/src/ptychodus/model/product/scan/bounding_box.py @@ -0,0 +1,38 @@ +import numpy + +from ptychodus.api.scan import ScanBoundingBox, ScanPoint + + +class ScanBoundingBoxBuilder: + def __init__(self) -> None: + self._xmin_m = +numpy.inf + self._xmax_m = -numpy.inf + self._ymin_m = +numpy.inf + self._ymax_m = -numpy.inf + + def hull(self, point: ScanPoint) -> None: + if point.position_x_m < self._xmin_m: + self._xmin_m = point.position_x_m + + if self._xmax_m < point.position_x_m: + self._xmax_m = point.position_x_m + + if point.position_y_m < self._ymin_m: + self._ymin_m = point.position_y_m + + if self._ymax_m < point.position_y_m: + self._ymax_m = point.position_y_m + + def get_bounding_box(self) -> ScanBoundingBox | None: + is_empty_x = self._xmax_m < self._xmin_m + is_empty_y = self._ymax_m < self._ymin_m + + if is_empty_x or is_empty_y: + return None + + return ScanBoundingBox( + minimum_x_m=self._xmin_m, + maximum_x_m=self._xmax_m, + minimum_y_m=self._ymin_m, + maximum_y_m=self._ymax_m, + ) diff --git a/src/ptychodus/model/product/scan/builder.py b/src/ptychodus/model/product/scan/builder.py index 1be11dfe..e463e2c1 100644 --- a/src/ptychodus/model/product/scan/builder.py +++ b/src/ptychodus/model/product/scan/builder.py @@ -5,7 +5,7 @@ import logging from ptychodus.api.parametric import ParameterGroup -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint from .settings import ScanSettings @@ -16,22 +16,22 @@ class ScanBuilder(ParameterGroup): def __init__(self, settings: ScanSettings, name: str) -> None: super().__init__() self._name = settings.builder.copy() - self._name.setValue(name) - self._addParameter('name', self._name) + self._name.set_value(name) + self._add_parameter('name', self._name) - def getName(self) -> str: - return self._name.getValue() + def get_name(self) -> str: + return self._name.get_value() - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() @abstractmethod def copy(self) -> ScanBuilder: pass @abstractmethod - def build(self) -> Scan: + def build(self) -> PositionSequence: pass @@ -39,42 +39,49 @@ class FromMemoryScanBuilder(ScanBuilder): def __init__(self, settings: ScanSettings, points: Sequence[ScanPoint]) -> None: super().__init__(settings, 'from_memory') self._settings = settings - self._scan = Scan(points) + self._scan = PositionSequence(points) def copy(self) -> FromMemoryScanBuilder: return FromMemoryScanBuilder(self._settings, self._scan) - def build(self) -> Scan: + def build(self) -> PositionSequence: return self._scan class FromFileScanBuilder(ScanBuilder): def __init__( - self, settings: ScanSettings, filePath: Path, fileType: str, fileReader: ScanFileReader + self, + settings: ScanSettings, + file_path: Path, + file_type: str, + file_reader: PositionFileReader, ) -> None: super().__init__(settings, 'from_file') self._settings = settings - self.filePath = settings.filePath.copy() - self.filePath.setValue(filePath) - self._addParameter('file_path', self.filePath) - self.fileType = settings.fileType.copy() - self.fileType.setValue(fileType) - self._addParameter('file_type', self.fileType) - self._fileReader = fileReader + self.file_path = settings.file_path.copy() + self.file_path.set_value(file_path) + self._add_parameter('file_path', self.file_path) + self.file_type = settings.file_type.copy() + self.file_type.set_value(file_type) + self._add_parameter('file_type', self.file_type) + self._file_reader = file_reader def copy(self) -> FromFileScanBuilder: return FromFileScanBuilder( - self._settings, self.filePath.getValue(), self.fileType.getValue(), self._fileReader + self._settings, + self.file_path.get_value(), + self.file_type.get_value(), + self._file_reader, ) - def build(self) -> Scan: - filePath = self.filePath.getValue() - fileType = self.fileType.getValue() - logger.debug(f'Reading "{filePath}" as "{fileType}"') + def build(self) -> PositionSequence: + file_path = self.file_path.get_value() + file_type = self.file_type.get_value() + logger.debug(f'Reading "{file_path}" as "{file_type}"') try: - scan = self._fileReader.read(filePath) + scan = self._file_reader.read(file_path) except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc + raise RuntimeError(f'Failed to read "{file_path}"') from exc return scan diff --git a/src/ptychodus/model/product/scan/builderFactory.py b/src/ptychodus/model/product/scan/builderFactory.py deleted file mode 100644 index 69cbad29..00000000 --- a/src/ptychodus/model/product/scan/builderFactory.py +++ /dev/null @@ -1,89 +0,0 @@ -from collections.abc import Callable, Iterable, Iterator, Sequence -from pathlib import Path -import logging - -from ptychodus.api.plugins import PluginChooser -from ptychodus.api.scan import Scan, ScanFileReader, ScanFileWriter - -from .builder import FromFileScanBuilder, ScanBuilder -from .cartesian import CartesianScanBuilder, CartesianScanVariant -from .concentric import ConcentricScanBuilder -from .lissajous import LissajousScanBuilder -from .settings import ScanSettings -from .spiral import SpiralScanBuilder - -logger = logging.getLogger(__name__) - - -class ScanBuilderFactory(Iterable[str]): - def __init__( - self, - settings: ScanSettings, - fileReaderChooser: PluginChooser[ScanFileReader], - fileWriterChooser: PluginChooser[ScanFileWriter], - ) -> None: - self._settings = settings - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser - self._builders: dict[str, Callable[[], ScanBuilder]] = { - variant.name.lower(): lambda variant=variant: CartesianScanBuilder(variant, settings) # type: ignore - for variant in CartesianScanVariant - } - self._builders.update( - { - 'concentric': lambda: ConcentricScanBuilder(settings), - 'spiral': lambda: SpiralScanBuilder(settings), - 'lissajous': lambda: LissajousScanBuilder(settings), - } - ) - - def __iter__(self) -> Iterator[str]: - return iter(self._builders) - - def create(self, name: str) -> ScanBuilder: - try: - factory = self._builders[name] - except KeyError as exc: - raise KeyError(f'Unknown scan builder "{name}"!') from exc - - return factory() - - def createDefault(self) -> ScanBuilder: - return next(iter(self._builders.values()))() - - def createFromSettings(self) -> ScanBuilder: - name = self._settings.builder.getValue() - nameRepaired = name.casefold() - - if nameRepaired == 'from_file': - return self.createScanFromFile( - self._settings.filePath.getValue(), - self._settings.fileType.getValue(), - ) - - return self.create(nameRepaired) - - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() - - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName - - def createScanFromFile(self, filePath: Path, fileType: str) -> ScanBuilder: - self._fileReaderChooser.setCurrentPluginByName(fileType) - fileType = self._fileReaderChooser.currentPlugin.simpleName - fileReader = self._fileReaderChooser.currentPlugin.strategy - return FromFileScanBuilder(self._settings, filePath, fileType, fileReader) - - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() - - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName - - def saveScan(self, filePath: Path, fileType: str, scan: Scan) -> None: - self._fileWriterChooser.setCurrentPluginByName(fileType) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - fileWriter = self._fileWriterChooser.currentPlugin.strategy - fileWriter.write(filePath, scan) diff --git a/src/ptychodus/model/product/scan/builder_factory.py b/src/ptychodus/model/product/scan/builder_factory.py new file mode 100644 index 00000000..5c46324f --- /dev/null +++ b/src/ptychodus/model/product/scan/builder_factory.py @@ -0,0 +1,91 @@ +from collections.abc import Callable, Iterable, Iterator +from pathlib import Path +import logging + +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.scan import PositionSequence, PositionFileReader, PositionFileWriter + +from .builder import FromFileScanBuilder, ScanBuilder +from .cartesian import CartesianScanBuilder, CartesianScanVariant +from .concentric import ConcentricScanBuilder +from .lissajous import LissajousScanBuilder +from .settings import ScanSettings +from .spiral import SpiralScanBuilder + +logger = logging.getLogger(__name__) + + +class ScanBuilderFactory(Iterable[str]): + def __init__( + self, + settings: ScanSettings, + file_reader_chooser: PluginChooser[PositionFileReader], + file_writer_chooser: PluginChooser[PositionFileWriter], + ) -> None: + self._settings = settings + self._file_reader_chooser = file_reader_chooser + self._file_writer_chooser = file_writer_chooser + self._builders: dict[str, Callable[[], ScanBuilder]] = { + variant.name.lower(): lambda variant=variant: CartesianScanBuilder(variant, settings) # type: ignore + for variant in CartesianScanVariant + } + self._builders.update( + { + 'concentric': lambda: ConcentricScanBuilder(settings), + 'spiral': lambda: SpiralScanBuilder(settings), + 'lissajous': lambda: LissajousScanBuilder(settings), + } + ) + + def __iter__(self) -> Iterator[str]: + return iter(self._builders) + + def create(self, name: str) -> ScanBuilder: + try: + factory = self._builders[name] + except KeyError as exc: + raise KeyError(f'Unknown scan builder "{name}"!') from exc + + return factory() + + def create_default(self) -> ScanBuilder: + return next(iter(self._builders.values()))() + + def create_from_settings(self) -> ScanBuilder: + name = self._settings.builder.get_value() + name_repaired = name.casefold() + + if name_repaired == 'from_file': + return self.create_scan_from_file( + self._settings.file_path.get_value(), + self._settings.file_type.get_value(), + ) + + return self.create(name_repaired) + + def get_open_file_filters(self) -> Iterator[str]: + for plugin in self._file_reader_chooser: + yield plugin.display_name + + def get_open_file_filter(self) -> str: + return self._file_reader_chooser.get_current_plugin().display_name + + def create_scan_from_file(self, file_path: Path, file_type: str) -> ScanBuilder: + self._file_reader_chooser.set_current_plugin(file_type) + file_type = self._file_reader_chooser.get_current_plugin().simple_name + file_reader = self._file_reader_chooser.get_current_plugin().strategy + return FromFileScanBuilder(self._settings, file_path, file_type, file_reader) + + def get_save_file_filters(self) -> Iterator[str]: + for plugin in self._file_writer_chooser: + yield plugin.display_name + + def get_save_file_filter(self) -> str: + return self._file_writer_chooser.get_current_plugin().display_name + + def save_scan(self, file_path: Path, file_type: str, scan: PositionSequence) -> None: + self._file_writer_chooser.set_current_plugin(file_type) + file_type = self._file_writer_chooser.get_current_plugin().simple_name + logger.debug(f'Writing "{file_path}" as "{file_type}"') + file_writer = self._file_writer_chooser.get_current_plugin().strategy + file_writer.write(file_path, scan) diff --git a/src/ptychodus/model/product/scan/cartesian.py b/src/ptychodus/model/product/scan/cartesian.py index 53b318e9..26429470 100644 --- a/src/ptychodus/model/product/scan/cartesian.py +++ b/src/ptychodus/model/product/scan/cartesian.py @@ -3,7 +3,7 @@ import numpy -from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanPoint from .builder import ScanBuilder from .settings import ScanSettings @@ -20,15 +20,15 @@ class CartesianScanVariant(IntEnum): HEXAGONAL_SNAKE = 0x7 @property - def isSnaked(self) -> bool: + def is_snaked(self) -> bool: return self.value & 1 != 0 @property - def isTriangular(self) -> bool: + def is_triangular(self) -> bool: return self.value & 2 != 0 @property - def isEquilateral(self) -> bool: + def is_equilateral(self) -> bool: return self.value & 4 != 0 @@ -38,49 +38,49 @@ def __init__(self, variant: CartesianScanVariant, settings: ScanSettings) -> Non self._variant = variant self._settings = settings - self.numberOfPointsX = settings.numberOfPointsX.copy() - self._addParameter('number_of_points_x', self.numberOfPointsX) + self.num_points_x = settings.num_points_x.copy() + self._add_parameter('num_points_x', self.num_points_x) - self.numberOfPointsY = settings.numberOfPointsY.copy() - self._addParameter('number_of_points_y', self.numberOfPointsY) + self.num_points_y = settings.num_points_y.copy() + self._add_parameter('num_points_y', self.num_points_y) - self.stepSizeXInMeters = settings.stepSizeXInMeters.copy() - self._addParameter('step_size_x_m', self.stepSizeXInMeters) + self.step_size_x_m = settings.step_size_x_m.copy() + self._add_parameter('step_size_x_m', self.step_size_x_m) - self.stepSizeYInMeters = settings.stepSizeYInMeters.copy() - self._addParameter('step_size_y_m', self.stepSizeYInMeters) + self.step_size_y_m = settings.step_size_y_m.copy() + self._add_parameter('step_size_y_m', self.step_size_y_m) def copy(self) -> CartesianScanBuilder: builder = CartesianScanBuilder(self._variant, self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder @property - def isEquilateral(self) -> bool: - return self._variant.isEquilateral + def is_equilateral(self) -> bool: + return self._variant.is_equilateral - def build(self) -> Scan: - nx = self.numberOfPointsX.getValue() - ny = self.numberOfPointsY.getValue() - dx = self.stepSizeXInMeters.getValue() + def build(self) -> PositionSequence: + nx = self.num_points_x.get_value() + ny = self.num_points_y.get_value() + dx = self.step_size_x_m.get_value() - if self._variant.isEquilateral: + if self._variant.is_equilateral: dy = dx - if self._variant.isTriangular: + if self._variant.is_triangular: dy *= numpy.sqrt(0.75) else: - dy = self.stepSizeYInMeters.getValue() + dy = self.step_size_y_m.get_value() - pointList: list[ScanPoint] = list() + point_list: list[ScanPoint] = list() for index in range(nx * ny): y, x = divmod(index, nx) - if self._variant.isSnaked: + if self._variant.is_snaked: if y & 1: x = nx - 1 - x @@ -90,7 +90,7 @@ def build(self) -> Scan: xf = (x - cx) * dx yf = (y - cy) * dy - if self._variant.isTriangular: + if self._variant.is_triangular: if y & 1: xf += dx / 4 else: @@ -98,9 +98,9 @@ def build(self) -> Scan: point = ScanPoint( index=index, - positionXInMeters=xf, - positionYInMeters=yf, + position_x_m=xf, + position_y_m=yf, ) - pointList.append(point) + point_list.append(point) - return Scan(pointList) + return PositionSequence(point_list) diff --git a/src/ptychodus/model/product/scan/concentric.py b/src/ptychodus/model/product/scan/concentric.py index f615cd1a..73bcdea1 100644 --- a/src/ptychodus/model/product/scan/concentric.py +++ b/src/ptychodus/model/product/scan/concentric.py @@ -2,7 +2,7 @@ import numpy -from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanPoint from .builder import ScanBuilder from .settings import ScanSettings @@ -15,48 +15,48 @@ def __init__(self, settings: ScanSettings) -> None: super().__init__(settings, 'concentric') self._settings = settings - self.radialStepSizeInMeters = settings.radialStepSizeInMeters.copy() - self._addParameter('radial_step_size_m', self.radialStepSizeInMeters) + self.radial_step_size_m = settings.radial_step_size_m.copy() + self._add_parameter('radial_step_size_m', self.radial_step_size_m) - self.numberOfShells = settings.numberOfShells.copy() - self._addParameter('number_of_shells', self.numberOfShells) + self.num_shells = settings.num_shells.copy() + self._add_parameter('num_shells', self.num_shells) - self.numberOfPointsInFirstShell = settings.numberOfPointsInFirstShell.copy() - self._addParameter('number_of_points_1st_shell', self.numberOfPointsInFirstShell) + self.num_points_1st_shell = settings.num_points_in_first_shell.copy() + self._add_parameter('num_points_1st_shell', self.num_points_1st_shell) def copy(self) -> ConcentricScanBuilder: builder = ConcentricScanBuilder(self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder @property - def _numberOfPoints(self) -> int: - numberOfShells = self.numberOfShells.getValue() - triangle = (numberOfShells * (numberOfShells + 1)) // 2 - return triangle * self.numberOfPointsInFirstShell.getValue() + def _num_points(self) -> int: + num_shells = self.num_shells.get_value() + triangle = (num_shells * (num_shells + 1)) // 2 + return triangle * self.num_points_1st_shell.get_value() - def build(self) -> Scan: - pointList: list[ScanPoint] = list() + def build(self) -> PositionSequence: + point_list: list[ScanPoint] = list() - for index in range(self._numberOfPoints): - triangle = index // self.numberOfPointsInFirstShell.getValue() - shellIndex = int((1 + numpy.sqrt(1 + 8 * triangle)) / 2) - 1 # see OEIS A002024 - shellTriangle = (shellIndex * (shellIndex + 1)) // 2 - firstIndexInShell = self.numberOfPointsInFirstShell.getValue() * shellTriangle - pointIndexInShell = index - firstIndexInShell + for index in range(self._num_points): + triangle = index // self.num_points_1st_shell.get_value() + shell_index = int((1 + numpy.sqrt(1 + 8 * triangle)) / 2) - 1 # see OEIS A002024 + shell_triangle = (shell_index * (shell_index + 1)) // 2 + first_index_in_shell = self.num_points_1st_shell.get_value() * shell_triangle + point_index_in_shell = index - first_index_in_shell - radiusInMeters = self.radialStepSizeInMeters.getValue() * (shellIndex + 1) - numberOfPointsInShell = self.numberOfPointsInFirstShell.getValue() * (shellIndex + 1) - thetaInRadians = 2 * numpy.pi * pointIndexInShell / numberOfPointsInShell + radius_m = self.radial_step_size_m.get_value() * (shell_index + 1) + num_points_in_shell = self.num_points_1st_shell.get_value() * (shell_index + 1) + theta_rad = 2 * numpy.pi * point_index_in_shell / num_points_in_shell point = ScanPoint( index=index, - positionXInMeters=radiusInMeters * numpy.cos(thetaInRadians), - positionYInMeters=radiusInMeters * numpy.sin(thetaInRadians), + position_x_m=radius_m * numpy.cos(theta_rad), + position_y_m=radius_m * numpy.sin(theta_rad), ) - pointList.append(point) + point_list.append(point) - return Scan(pointList) + return PositionSequence(point_list) diff --git a/src/ptychodus/model/product/scan/item.py b/src/ptychodus/model/product/scan/item.py index 3585df04..0748c398 100644 --- a/src/ptychodus/model/product/scan/item.py +++ b/src/ptychodus/model/product/scan/item.py @@ -6,9 +6,9 @@ from ptychodus.api.observer import Observable from ptychodus.api.parametric import ParameterGroup -from ptychodus.api.scan import Scan, ScanBoundingBox, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanBoundingBox, ScanPoint -from .boundingBox import ScanBoundingBoxBuilder +from .bounding_box import ScanBoundingBoxBuilder from .builder import FromMemoryScanBuilder, ScanBuilder from .settings import ScanSettings from .transform import ScanPointTransform @@ -27,127 +27,126 @@ def __init__( self._settings = settings self._builder = builder self._transform = transform + self._untransformed_scan = PositionSequence() + self._transformed_scan = PositionSequence() + self._bbox_builder = ScanBoundingBoxBuilder() + self._length_m = 0.0 - self._untransformedScan = Scan() - self._transformedScan = Scan() - self._boundingBoxBuilder = ScanBoundingBoxBuilder() - self._lengthInMeters = 0.0 + self._add_group('builder', builder, observe=True) + self._add_group('transform', transform, observe=True) - self._addGroup('builder', builder, observe=True) - self._addGroup('transform', transform, observe=True) + self.expand_bbox = settings.expand_bbox.copy() + self._add_parameter('expand_bbox', self.expand_bbox) - self.expandBoundingBox = settings.expandBoundingBox.copy() - self._addParameter('expand_bbox', self.expandBoundingBox) + self.expand_bbox_xmin_m = settings.expand_bbox_xmin_m.copy() + self._add_parameter('expand_bbox_xmin_m', self.expand_bbox_xmin_m) - self.expandedBoundingBoxMinimumXInMeters = ( - settings.expandedBoundingBoxMinimumXInMeters.copy() - ) - self._addParameter('expanded_bbox_xmin_m', self.expandedBoundingBoxMinimumXInMeters) + self.expand_bbox_xmax_m = settings.expand_bbox_xmax_m.copy() + self._add_parameter('expand_bbox_xmax_m', self.expand_bbox_xmax_m) - self.expandedBoundingBoxMaximumXInMeters = ( - settings.expandedBoundingBoxMaximumXInMeters.copy() - ) - self._addParameter('expanded_bbox_xmax_m', self.expandedBoundingBoxMaximumXInMeters) + self.expand_bbox_ymin_m = settings.expand_bbox_ymin_m.copy() + self._add_parameter('expand_bbox_ymin_m', self.expand_bbox_ymin_m) - self.expandedBoundingBoxMinimumYInMeters = ( - settings.expandedBoundingBoxMinimumYInMeters.copy() - ) - self._addParameter('expanded_bbox_ymin_m', self.expandedBoundingBoxMinimumYInMeters) - - self.expandedBoundingBoxMaximumYInMeters = ( - settings.expandedBoundingBoxMaximumYInMeters.copy() - ) - self._addParameter('expanded_bbox_ymax_m', self.expandedBoundingBoxMaximumYInMeters) + self.expand_bbox_ymax_m = settings.expand_bbox_ymax_m.copy() + self._add_parameter('expand_bbox_ymax_m', self.expand_bbox_ymax_m) self._rebuild() - def assignItem(self, item: ScanRepositoryItem) -> None: - self._removeGroup('transform') - self._transform.removeObserver(self) - self._transform = item.getTransform().copy() - self._transform.addObserver(self) - self._addGroup('transform', self._transform) + def assign_item(self, item: ScanRepositoryItem) -> None: + group = 'transform' + + self._remove_group(group) + self._transform.remove_observer(self) - self.setBuilder(item.getBuilder().copy()) + transform = item.get_transform() - def assign(self, scan: Scan) -> None: + self._transform = transform.copy() + self._transform.add_observer(self) + self._add_group(group, self._transform, observe=True) + + self.set_builder(item.get_builder().copy()) + self._rebuild() + + def assign(self, scan: PositionSequence) -> None: builder = FromMemoryScanBuilder(self._settings, scan) - self.setBuilder(builder) + self.set_builder(builder) - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() - self._builder.syncToSettings() - self._transform.syncToSettings() + self._builder.sync_to_settings() + self._transform.sync_to_settings() - def getScan(self) -> Scan: - return self._transformedScan + def get_scan(self) -> PositionSequence: + return self._transformed_scan - def getBuilder(self) -> ScanBuilder: + def get_builder(self) -> ScanBuilder: return self._builder - def setBuilder(self, builder: ScanBuilder) -> None: - self._removeGroup('builder') - self._builder.removeObserver(self) + def set_builder(self, builder: ScanBuilder) -> None: + group = 'builder' + self._remove_group(group) + self._builder.remove_observer(self) self._builder = builder - self._builder.addObserver(self) - self._addGroup('builder', self._builder) + self._builder.add_observer(self) + self._add_group(group, self._builder, observe=True) self._rebuild() - def getBoundingBox(self) -> ScanBoundingBox | None: - bbox = self._boundingBoxBuilder.getBoundingBox() + def get_bounding_box(self) -> ScanBoundingBox | None: + bbox = self._bbox_builder.get_bounding_box() - if self.expandBoundingBox.getValue(): - expandedBoundingBox = ScanBoundingBox( - minimumXInMeters=self.expandedBoundingBoxMinimumXInMeters.getValue(), - maximumXInMeters=self.expandedBoundingBoxMaximumXInMeters.getValue(), - minimumYInMeters=self.expandedBoundingBoxMinimumYInMeters.getValue(), - maximumYInMeters=self.expandedBoundingBoxMaximumYInMeters.getValue(), + if self.expand_bbox.get_value(): + expanded_bbox = ScanBoundingBox( + minimum_x_m=self.expand_bbox_xmin_m.get_value(), + maximum_x_m=self.expand_bbox_xmax_m.get_value(), + minimum_y_m=self.expand_bbox_ymin_m.get_value(), + maximum_y_m=self.expand_bbox_ymax_m.get_value(), ) - bbox = expandedBoundingBox if bbox is None else bbox.hull(expandedBoundingBox) + bbox = expanded_bbox if bbox is None else bbox.hull(expanded_bbox) return bbox - def getLengthInMeters(self) -> float: - return self._lengthInMeters + def get_length_m(self) -> float: + return self._length_m - def _transformScan(self) -> None: - transformedPoints: list[ScanPoint] = list() - boundingBoxBuilder = ScanBoundingBoxBuilder() - lengthInMeters = 0.0 + def _transform_scan(self) -> None: + transformed_points: list[ScanPoint] = list() + bbox_builder = ScanBoundingBoxBuilder() + length_m = 0.0 - for untransformedPoint in self._untransformedScan: - point = self._transform(untransformedPoint) - transformedPoints.append(point) - boundingBoxBuilder.hull(point) + for untransformed_point in self._untransformed_scan: + transformed_point = self._transform(untransformed_point) + transformed_points.append(transformed_point) + bbox_builder.hull(transformed_point) - for pointL, pointR in pairwise(transformedPoints): - dx = pointR.positionXInMeters - pointL.positionXInMeters - dy = pointR.positionYInMeters - pointL.positionYInMeters - lengthInMeters += numpy.hypot(dx, dy) + for point_l, point_r in pairwise(transformed_points): + dx = point_r.position_x_m - point_l.position_x_m + dy = point_r.position_y_m - point_l.position_y_m + length_m += numpy.hypot(dx, dy) - self._transformedScan = Scan(transformedPoints) - self._boundingBoxBuilder = boundingBoxBuilder - self._lengthInMeters = lengthInMeters - self.notifyObservers() + self._transformed_scan = PositionSequence(transformed_points) + self._bbox_builder = bbox_builder + self._length_m = length_m + self.notify_observers() def _rebuild(self) -> None: try: scan = self._builder.build() except Exception as exc: - logger.error(''.join(exc.args)) - else: - self._untransformedScan = scan - self._transformScan() + logger.exception('Failed to rebuild scan!') + return + + self._untransformed_scan = scan + self._transform_scan() - def getTransform(self) -> ScanPointTransform: + def get_transform(self) -> ScanPointTransform: return self._transform - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._builder: self._rebuild() elif observable is self._transform: - self._transformScan() + self._transform_scan() else: - super().update(observable) + super()._update(observable) diff --git a/src/ptychodus/model/product/scan/itemFactory.py b/src/ptychodus/model/product/scan/item_factory.py similarity index 50% rename from src/ptychodus/model/product/scan/itemFactory.py rename to src/ptychodus/model/product/scan/item_factory.py index ce02bc44..91037bfa 100644 --- a/src/ptychodus/model/product/scan/itemFactory.py +++ b/src/ptychodus/model/product/scan/item_factory.py @@ -1,11 +1,11 @@ import logging -import numpy +import numpy.random -from ptychodus.api.scan import Scan +from ptychodus.api.scan import PositionSequence from .builder import FromMemoryScanBuilder -from .builderFactory import ScanBuilderFactory +from .builder_factory import ScanBuilderFactory from .item import ScanRepositoryItem from .settings import ScanSettings from .transform import ScanPointTransform @@ -18,27 +18,29 @@ def __init__( self, rng: numpy.random.Generator, settings: ScanSettings, - builderFactory: ScanBuilderFactory, + builder_factory: ScanBuilderFactory, ) -> None: self._rng = rng self._settings = settings - self._builderFactory = builderFactory - - def create(self, scan: Scan | None = None) -> ScanRepositoryItem: - builder = ( - self._builderFactory.createDefault() - if scan is None - else FromMemoryScanBuilder(self._settings, scan) - ) + self._builder_factory = builder_factory + + def create(self, scan: PositionSequence | None = None) -> ScanRepositoryItem: transform = ScanPointTransform(self._rng, self._settings) + + if scan is None: + builder = self._builder_factory.create_default() + else: + builder = FromMemoryScanBuilder(self._settings, scan) + transform.set_identity() + return ScanRepositoryItem(self._settings, builder, transform) - def createFromSettings(self) -> ScanRepositoryItem: + def create_from_settings(self) -> ScanRepositoryItem: try: - builder = self._builderFactory.createFromSettings() + builder = self._builder_factory.create_from_settings() except Exception as exc: - logger.error(''.join(exc.args)) - builder = self._builderFactory.createDefault() + logger.exception(''.join(exc.args)) + builder = self._builder_factory.create_default() transform = ScanPointTransform(self._rng, self._settings) return ScanRepositoryItem(self._settings, builder, transform) diff --git a/src/ptychodus/model/product/scan/lissajous.py b/src/ptychodus/model/product/scan/lissajous.py index ad6a0240..c8c8172b 100644 --- a/src/ptychodus/model/product/scan/lissajous.py +++ b/src/ptychodus/model/product/scan/lissajous.py @@ -2,7 +2,7 @@ import numpy -from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanPoint from .builder import ScanBuilder from .settings import ScanSettings @@ -13,55 +13,55 @@ def __init__(self, settings: ScanSettings) -> None: super().__init__(settings, 'lissajous') self._settings = settings - self.numberOfPoints = settings.numberOfPointsX.copy() - self.numberOfPoints.setValue( - settings.numberOfPointsX.getValue() * settings.numberOfPointsY.getValue() + self.num_points = settings.num_points_x.copy() + self.num_points.set_value( + settings.num_points_x.get_value() * settings.num_points_y.get_value() ) - self._addParameter('number_of_points', self.numberOfPoints) + self._add_parameter('num_points', self.num_points) - self._numberOfPoints = settings.numberOfPointsY.copy() - self._numberOfPoints.setValue(1) - self._addParameter('_number_of_points', self._numberOfPoints) + self._num_points = settings.num_points_y.copy() + self._num_points.set_value(1) + self._add_parameter('_num_points', self._num_points) - self.amplitudeXInMeters = settings.amplitudeXInMeters.copy() - self._addParameter('amplitude_x_m', self.amplitudeXInMeters) + self.amplitude_x_m = settings.amplitude_x_m.copy() + self._add_parameter('amplitude_x_m', self.amplitude_x_m) - self.amplitudeYInMeters = settings.amplitudeYInMeters.copy() - self._addParameter('amplitude_y_m', self.amplitudeYInMeters) + self.amplitude_y_m = settings.amplitude_y_m.copy() + self._add_parameter('amplitude_y_m', self.amplitude_y_m) - self.angularStepXInTurns = settings.angularStepXInTurns.copy() - self._addParameter('angular_step_x_tr', self.angularStepXInTurns) + self.angular_step_x_turns = settings.angular_step_x_turns.copy() + self._add_parameter('angular_step_x_tr', self.angular_step_x_turns) - self.angularStepYInTurns = settings.angularStepYInTurns.copy() - self._addParameter('angular_step_y_tr', self.angularStepYInTurns) + self.angular_step_y_turns = settings.angular_step_y_turns.copy() + self._add_parameter('angular_step_y_tr', self.angular_step_y_turns) - self.angularShiftInTurns = settings.angularShiftInTurns.copy() - self._addParameter('angular_shift_tr', self.angularShiftInTurns) + self.angular_shift_turns = settings.angular_shift_turns.copy() + self._add_parameter('angular_shift_tr', self.angular_shift_turns) def copy(self) -> LissajousScanBuilder: builder = LissajousScanBuilder(self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder - def build(self) -> Scan: - pointList: list[ScanPoint] = list() + def build(self) -> PositionSequence: + point_list: list[ScanPoint] = list() - for index in range(self.numberOfPoints.getValue()): - twoPi = 2 * numpy.pi - thetaX = ( - twoPi * self.angularStepXInTurns.getValue() * index - + self.angularShiftInTurns.getValue() + for index in range(self.num_points.get_value()): + two_pi = 2 * numpy.pi + theta_x = ( + two_pi * self.angular_step_x_turns.get_value() * index + + self.angular_shift_turns.get_value() ) - thetaY = twoPi * self.angularStepYInTurns.getValue() * index + theta_y = two_pi * self.angular_step_y_turns.get_value() * index point = ScanPoint( index=index, - positionXInMeters=self.amplitudeXInMeters.getValue() * numpy.sin(thetaX), - positionYInMeters=self.amplitudeYInMeters.getValue() * numpy.sin(thetaY), + position_x_m=self.amplitude_x_m.get_value() * numpy.sin(theta_x), + position_y_m=self.amplitude_y_m.get_value() * numpy.sin(theta_y), ) - pointList.append(point) + point_list.append(point) - return Scan(pointList) + return PositionSequence(point_list) diff --git a/src/ptychodus/model/product/scan/settings.py b/src/ptychodus/model/product/scan/settings.py index 20bf882c..7de44849 100644 --- a/src/ptychodus/model/product/scan/settings.py +++ b/src/ptychodus/model/product/scan/settings.py @@ -7,93 +7,68 @@ class ScanSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Scan') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Scan') + self._group.add_observer(self) - self.builder = self._settingsGroup.createStringParameter('Builder', 'rectangular_raster') - self.filePath = self._settingsGroup.createPathParameter( - 'FilePath', Path('/path/to/scan.csv') - ) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'CSV') - - self.affineTransformAX = self._settingsGroup.createRealParameter('AffineTransformAX', 1.0) - self.affineTransformAY = self._settingsGroup.createRealParameter('AffineTransformAY', 0.0) - self.affineTransformATInMeters = self._settingsGroup.createRealParameter( - 'AffineTransformATInMeters', 0.0 - ) + self.builder = self._group.create_string_parameter('Builder', 'rectangular_raster') + self.file_path = self._group.create_path_parameter('FilePath', Path('/path/to/scan.csv')) + self.file_type = self._group.create_string_parameter('FileType', 'CSV') - self.affineTransformBX = self._settingsGroup.createRealParameter('AffineTransformBX', 0.0) - self.affineTransformBY = self._settingsGroup.createRealParameter('AffineTransformBY', 1.0) - self.affineTransformBTInMeters = self._settingsGroup.createRealParameter( - 'AffineTransformBTInMeters', 0.0 - ) - self.jitterRadiusInMeters = self._settingsGroup.createRealParameter( + self.affine00 = self._group.create_real_parameter('Affine00', 1.0) + self.affine01 = self._group.create_real_parameter('Affine01', 0.0) + self.affine02 = self._group.create_real_parameter('Affine02', 0.0) + self.affine10 = self._group.create_real_parameter('Affine10', 0.0) + self.affine11 = self._group.create_real_parameter('Affine11', 1.0) + self.affine12 = self._group.create_real_parameter('Affine12', 0.0) + self.jitter_radius_m = self._group.create_real_parameter( 'JitterRadiusInMeters', 0.0, minimum=0.0 ) - self.expandBoundingBox = self._settingsGroup.createBooleanParameter( - 'ExpandBoundingBox', False - ) - self.expandedBoundingBoxMinimumXInMeters = self._settingsGroup.createRealParameter( + self.expand_bbox = self._group.create_boolean_parameter('ExpandBoundingBox', False) + self.expand_bbox_xmin_m = self._group.create_real_parameter( 'ExpandedBoundingBoxMinimumXInMeters', -5e-7 ) - self.expandedBoundingBoxMaximumXInMeters = self._settingsGroup.createRealParameter( + self.expand_bbox_xmax_m = self._group.create_real_parameter( 'ExpandedBoundingBoxMaximumXInMeters', +5e-7 ) - self.expandedBoundingBoxMinimumYInMeters = self._settingsGroup.createRealParameter( + self.expand_bbox_ymin_m = self._group.create_real_parameter( 'ExpandedBoundingBoxMinimumYInMeters', -5e-7 ) - self.expandedBoundingBoxMaximumYInMeters = self._settingsGroup.createRealParameter( + self.expand_bbox_ymax_m = self._group.create_real_parameter( 'ExpandedBoundingBoxMaximumYInMeters', +5e-7 ) - self.numberOfPointsX = self._settingsGroup.createIntegerParameter( - 'NumberOfPointsX', 10, minimum=0 - ) - self.numberOfPointsY = self._settingsGroup.createIntegerParameter( - 'NumberOfPointsY', 10, minimum=0 - ) - self.stepSizeXInMeters = self._settingsGroup.createRealParameter( + self.num_points_x = self._group.create_integer_parameter('NumberOfPointsX', 10, minimum=0) + self.num_points_y = self._group.create_integer_parameter('NumberOfPointsY', 10, minimum=0) + self.step_size_x_m = self._group.create_real_parameter( 'StepSizeXInMeters', 1e-6, minimum=0.0 ) - self.stepSizeYInMeters = self._settingsGroup.createRealParameter( + self.step_size_y_m = self._group.create_real_parameter( 'StepSizeYInMeters', 1e-6, minimum=0.0 ) - self.radialStepSizeInMeters = self._settingsGroup.createRealParameter( + self.radial_step_size_m = self._group.create_real_parameter( 'RadialStepSizeInMeters', 1e-6, minimum=0.0 ) - self.numberOfShells = self._settingsGroup.createIntegerParameter( - 'NumberOfShells', 5, minimum=0 - ) - self.numberOfPointsInFirstShell = self._settingsGroup.createIntegerParameter( + self.num_shells = self._group.create_integer_parameter('NumberOfShells', 5, minimum=0) + self.num_points_in_first_shell = self._group.create_integer_parameter( 'NumberOfPointsInFirstShell', 10, minimum=0 ) - self.amplitudeXInMeters = self._settingsGroup.createRealParameter( + self.amplitude_x_m = self._group.create_real_parameter( 'AmplitudeXInMeters', 4.5e-6, minimum=0.0 ) - self.amplitudeYInMeters = self._settingsGroup.createRealParameter( + self.amplitude_y_m = self._group.create_real_parameter( 'AmplitudeYInMeters', 4.5e-6, minimum=0.0 ) - self.angularStepXInTurns = self._settingsGroup.createRealParameter( - 'AngularStepXInTurns', 0.03 - ) - self.angularStepYInTurns = self._settingsGroup.createRealParameter( - 'AngularStepYInTurns', 0.04 - ) - self.angularShiftInTurns = self._settingsGroup.createRealParameter( - 'AngularShiftInTurns', 0.25 - ) + self.angular_step_x_turns = self._group.create_real_parameter('AngularStepXInTurns', 0.03) + self.angular_step_y_turns = self._group.create_real_parameter('AngularStepYInTurns', 0.04) + self.angular_shift_turns = self._group.create_real_parameter('AngularShiftInTurns', 0.25) - self.radiusScalarInMeters = self._settingsGroup.createRealParameter( + self.radius_scalar_m = self._group.create_real_parameter( 'RadiusScalarInMeters', 0.5e-6, minimum=0.0 ) - @property - def numberOfPoints(self) -> int: - return self.numberOfPointsX.getValue() * self.numberOfPointsY.getValue() - - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/product/scan/spiral.py b/src/ptychodus/model/product/scan/spiral.py index ef1fd704..f8b77c31 100644 --- a/src/ptychodus/model/product/scan/spiral.py +++ b/src/ptychodus/model/product/scan/spiral.py @@ -2,7 +2,7 @@ import numpy -from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanPoint from .builder import ScanBuilder from .settings import ScanSettings @@ -15,40 +15,40 @@ def __init__(self, settings: ScanSettings) -> None: super().__init__(settings, 'spiral') self._settings = settings - self.numberOfPoints = settings.numberOfPointsX.copy() - self.numberOfPoints.setValue( - settings.numberOfPointsX.getValue() * settings.numberOfPointsY.getValue() + self.num_points = settings.num_points_x.copy() + self.num_points.set_value( + settings.num_points_x.get_value() * settings.num_points_y.get_value() ) - self._addParameter('number_of_points', self.numberOfPoints) + self._add_parameter('num_points', self.num_points) - self._numberOfPoints = settings.numberOfPointsY.copy() - self._numberOfPoints.setValue(1) - self._addParameter('_number_of_points', self._numberOfPoints) + self._num_points = settings.num_points_y.copy() + self._num_points.set_value(1) + self._add_parameter('_num_points', self._num_points) - self.radiusScalarInMeters = settings.radiusScalarInMeters.copy() - self._addParameter('radius_scalar_m', self.radiusScalarInMeters) + self.radius_scalar_m = settings.radius_scalar_m.copy() + self._add_parameter('radius_scalar_m', self.radius_scalar_m) def copy(self) -> SpiralScanBuilder: builder = SpiralScanBuilder(self._settings) for key, value in self.parameters().items(): - builder.parameters()[key].setValue(value.getValue()) + builder.parameters()[key].set_value(value.get_value()) return builder - def build(self) -> Scan: - pointList: list[ScanPoint] = list() + def build(self) -> PositionSequence: + point_list: list[ScanPoint] = list() - for index in range(self.numberOfPoints.getValue()): - radiusInMeters = self.radiusScalarInMeters.getValue() * numpy.sqrt(index) - divergenceAngleInRadians = (3.0 - numpy.sqrt(5)) * numpy.pi - thetaInRadians = divergenceAngleInRadians * index + for index in range(self.num_points.get_value()): + radius_m = self.radius_scalar_m.get_value() * numpy.sqrt(index) + divergence_angle_rad = (3.0 - numpy.sqrt(5)) * numpy.pi + theta_rad = divergence_angle_rad * index point = ScanPoint( index=index, - positionXInMeters=radiusInMeters * numpy.cos(thetaInRadians), - positionYInMeters=radiusInMeters * numpy.sin(thetaInRadians), + position_x_m=radius_m * numpy.cos(theta_rad), + position_y_m=radius_m * numpy.sin(theta_rad), ) - pointList.append(point) + point_list.append(point) - return Scan(pointList) + return PositionSequence(point_list) diff --git a/src/ptychodus/model/product/scan/streaming.py b/src/ptychodus/model/product/scan/streaming.py index dc1caf7b..cfe45c1d 100644 --- a/src/ptychodus/model/product/scan/streaming.py +++ b/src/ptychodus/model/product/scan/streaming.py @@ -2,25 +2,25 @@ # TODO from pvaccess import Channel, PvObjectQueue -from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanPoint from .builder import ScanBuilder from .settings import ScanSettings class StreamingScanBuilder(ScanBuilder): - def __init__(self, settings: ScanSettings, pointSeq: Sequence[ScanPoint]) -> None: + def __init__(self, settings: ScanSettings, point_seq: Sequence[ScanPoint]) -> None: super().__init__(settings, 'Streaming') - self._pointList = list(pointSeq) + self._point_list = list(point_seq) def append(self, point: ScanPoint) -> None: - self._pointList.append(point) + self._point_list.append(point) - def extend(self, pointSeq: Sequence[ScanPoint]) -> None: - self._pointList.extend(pointSeq) + def extend(self, point_seq: Sequence[ScanPoint]) -> None: + self._point_list.extend(point_seq) - def build(self) -> Scan: - return Scan(self._pointList) + def build(self) -> PositionSequence: + return PositionSequence(self._point_list) # TODO def echo(self, value: int = 125) -> None: diff --git a/src/ptychodus/model/product/scan/transform.py b/src/ptychodus/model/product/scan/transform.py index 8c1ce653..8a540ec0 100644 --- a/src/ptychodus/model/product/scan/transform.py +++ b/src/ptychodus/model/product/scan/transform.py @@ -3,6 +3,7 @@ import numpy +from ptychodus.api.geometry import AffineTransform from ptychodus.api.parametric import ParameterGroup from ptychodus.api.scan import ScanPoint @@ -15,92 +16,100 @@ def __init__(self, rng: numpy.random.Generator, settings: ScanSettings) -> None: self._rng = rng self._settings = settings - self.affineAX = settings.affineTransformAX.copy() - self._addParameter('affine_ax', self.affineAX) + self.affine00 = settings.affine00.copy() + self._add_parameter('affine00', self.affine00) - self.affineAY = settings.affineTransformAY.copy() - self._addParameter('affine_ay', self.affineAY) + self.affine01 = settings.affine01.copy() + self._add_parameter('affine01', self.affine01) - self.affineATInMeters = settings.affineTransformATInMeters.copy() - self._addParameter('affine_at_m', self.affineATInMeters) + self.affine02 = settings.affine02.copy() + self._add_parameter('affine02', self.affine02) - self.affineBX = settings.affineTransformBX.copy() - self._addParameter('affine_bx', self.affineBX) + self.affine10 = settings.affine10.copy() + self._add_parameter('affine10', self.affine10) - self.affineBY = settings.affineTransformBY.copy() - self._addParameter('affine_by', self.affineBY) + self.affine11 = settings.affine11.copy() + self._add_parameter('affine11', self.affine11) - self.affineBTInMeters = settings.affineTransformBTInMeters.copy() - self._addParameter('affine_bt_m', self.affineBTInMeters) + self.affine12 = settings.affine12.copy() + self._add_parameter('affine12', self.affine12) - self.jitterRadiusInMeters = settings.jitterRadiusInMeters.copy() - self._addParameter('jitter_radius_m', self.jitterRadiusInMeters) + self.jitter_radius_m = settings.jitter_radius_m.copy() + self._add_parameter('jitter_radius_m', self.jitter_radius_m) - def syncToSettings(self) -> None: + def sync_to_settings(self) -> None: for parameter in self.parameters().values(): - parameter.syncValueToParent() + parameter.sync_value_to_parent() def copy(self) -> ScanPointTransform: transform = ScanPointTransform(self._rng, self._settings) for key, value in self.parameters().items(): - transform.parameters()[key].setValue(value.getValue()) + transform.parameters()[key].set_value(value.get_value()) return transform @staticmethod - def negateX(preset: int) -> bool: + def negate_x(preset: int) -> bool: return preset & 0x1 != 0x0 @staticmethod - def negateY(preset: int) -> bool: + def negate_y(preset: int) -> bool: return preset & 0x2 != 0x0 @staticmethod - def swapXY(preset: int) -> bool: + def swap_xy(preset: int) -> bool: return preset & 0x4 != 0x0 - def labelsForPresets(self) -> Iterator[str]: + def labels_for_presets(self) -> Iterator[str]: for index in range(8): - xp = '\u2212x' if self.negateX(index) else '\u002bx' - yp = '\u2212y' if self.negateY(index) else '\u002by' - fxy = f'{yp}, {xp}' if self.swapXY(index) else f'{xp}, {yp}' - yield f'(x, y) \u2192 ({fxy})' - - def applyPresets(self, index: int) -> None: - if self.swapXY(index): - self.affineAY.setValue(-1 if self.negateY(index) else +1) - self.affineBX.setValue(-1 if self.negateX(index) else +1) - self.affineAX.setValue(0) - self.affineBY.setValue(0) + yp = '\u2212y' if self.negate_y(index) else '\u002by' + xp = '\u2212x' if self.negate_x(index) else '\u002bx' + fyx = f'{xp}, {yp}' if self.swap_xy(index) else f'{yp}, {xp}' + yield f'(y, x) \u2192 ({fyx})' + + def apply_presets(self, index: int) -> None: + self.block_notifications(True) + + if self.swap_xy(index): + self.affine00.set_value(0) + self.affine01.set_value(-1 if self.negate_x(index) else +1) + self.affine10.set_value(-1 if self.negate_y(index) else +1) + self.affine11.set_value(0) else: - self.affineAX.setValue(-1 if self.negateX(index) else +1) - self.affineBY.setValue(-1 if self.negateY(index) else +1) - self.affineAY.setValue(0) - self.affineBX.setValue(0) + self.affine00.set_value(-1 if self.negate_y(index) else +1) + self.affine01.set_value(0) + self.affine10.set_value(0) + self.affine11.set_value(-1 if self.negate_x(index) else +1) + + self.block_notifications(False) + + def get_transform(self) -> AffineTransform: + return AffineTransform( + a00=self.affine00.get_value(), + a01=self.affine01.get_value(), + a02=self.affine02.get_value(), + a10=self.affine10.get_value(), + a11=self.affine11.get_value(), + a12=self.affine12.get_value(), + ) + + def set_identity(self) -> None: + self.apply_presets(0) def __call__(self, point: ScanPoint) -> ScanPoint: - ax = self.affineAX.getValue() - ay = self.affineAY.getValue() - at_m = self.affineATInMeters.getValue() - - bx = self.affineBX.getValue() - by = self.affineBY.getValue() - bt_m = self.affineBTInMeters.getValue() - - posX = ax * point.positionXInMeters + ay * point.positionYInMeters + at_m - posY = bx * point.positionXInMeters + by * point.positionYInMeters + bt_m - - rad = self.jitterRadiusInMeters.getValue() + transform = self.get_transform() + pos_y, pos_x = transform(point.position_y_m, point.position_x_m) + rad = self.jitter_radius_m.get_value() if rad > 0.0: while True: - dX = self._rng.uniform() - dY = self._rng.uniform() + dx = self._rng.uniform() + dy = self._rng.uniform() - if dX * dX + dY * dY < 1.0: - posX += dX * rad - posY += dY * rad + if dx * dx + dy * dy < 1.0: + pos_x += dx * rad + pos_y += dy * rad break - return ScanPoint(point.index, posX, posY) + return ScanPoint(point.index, pos_x, pos_y) diff --git a/src/ptychodus/model/product/scanRepository.py b/src/ptychodus/model/product/scanRepository.py deleted file mode 100644 index 42bbe047..00000000 --- a/src/ptychodus/model/product/scanRepository.py +++ /dev/null @@ -1,63 +0,0 @@ -from collections.abc import Sequence -from typing import overload -import logging - -from ptychodus.api.observer import ObservableSequence - -from .item import ProductRepositoryItem, ProductRepositoryObserver -from .metadata import MetadataRepositoryItem -from .object import ObjectRepositoryItem -from .probe import ProbeRepositoryItem -from .productRepository import ProductRepository -from .scan import ScanRepositoryItem - -logger = logging.getLogger(__name__) - - -class ScanRepository(ObservableSequence[ScanRepositoryItem], ProductRepositoryObserver): - def __init__(self, repository: ProductRepository) -> None: - super().__init__() - self._repository = repository - self._repository.addObserver(self) - - def getName(self, index: int) -> str: - return self._repository[index].getName() - - def setName(self, index: int, name: str) -> None: - self._repository[index].setName(name) - - @overload - def __getitem__(self, index: int) -> ScanRepositoryItem: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[ScanRepositoryItem]: ... - - def __getitem__(self, index: int | slice) -> ScanRepositoryItem | Sequence[ScanRepositoryItem]: - if isinstance(index, slice): - return [item.getScan() for item in self._repository[index]] - else: - return self._repository[index].getScan() - - def __len__(self) -> int: - return len(self._repository) - - def handleItemInserted(self, index: int, item: ProductRepositoryItem) -> None: - self.notifyObserversItemInserted(index, item.getScan()) - - def handleMetadataChanged(self, index: int, item: MetadataRepositoryItem) -> None: - pass - - def handleScanChanged(self, index: int, item: ScanRepositoryItem) -> None: - self.notifyObserversItemChanged(index, item) - - def handleProbeChanged(self, index: int, item: ProbeRepositoryItem) -> None: - pass - - def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: - pass - - def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: - pass - - def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: - self.notifyObserversItemRemoved(index, item.getScan()) diff --git a/src/ptychodus/model/product/scan_repository.py b/src/ptychodus/model/product/scan_repository.py new file mode 100644 index 00000000..8b867cf0 --- /dev/null +++ b/src/ptychodus/model/product/scan_repository.py @@ -0,0 +1,63 @@ +from collections.abc import Sequence +from typing import overload +import logging + +from ptychodus.api.observer import ObservableSequence + +from .item import ProductRepositoryItem, ProductRepositoryObserver +from .metadata import MetadataRepositoryItem +from .object import ObjectRepositoryItem +from .probe import ProbeRepositoryItem +from .repository import ProductRepository +from .scan import ScanRepositoryItem + +logger = logging.getLogger(__name__) + + +class ScanRepository(ObservableSequence[ScanRepositoryItem], ProductRepositoryObserver): + def __init__(self, repository: ProductRepository) -> None: + super().__init__() + self._repository = repository + self._repository.add_observer(self) + + def get_name(self, index: int) -> str: + return self._repository[index].get_name() + + def set_name(self, index: int, name: str) -> None: + self._repository[index].set_name(name) + + @overload + def __getitem__(self, index: int) -> ScanRepositoryItem: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[ScanRepositoryItem]: ... + + def __getitem__(self, index: int | slice) -> ScanRepositoryItem | Sequence[ScanRepositoryItem]: + if isinstance(index, slice): + return [item.get_scan_item() for item in self._repository[index]] + else: + return self._repository[index].get_scan_item() + + def __len__(self) -> int: + return len(self._repository) + + def handle_item_inserted(self, index: int, item: ProductRepositoryItem) -> None: + self.notify_observers_item_inserted(index, item.get_scan_item()) + + def handle_metadata_changed(self, index: int, item: MetadataRepositoryItem) -> None: + pass + + def handle_scan_changed(self, index: int, item: ScanRepositoryItem) -> None: + self.notify_observers_item_changed(index, item) + + def handle_probe_changed(self, index: int, item: ProbeRepositoryItem) -> None: + pass + + def handle_object_changed(self, index: int, item: ObjectRepositoryItem) -> None: + pass + + def handle_costs_changed(self, index: int, costs: Sequence[float]) -> None: + pass + + def handle_item_removed(self, index: int, item: ProductRepositoryItem) -> None: + self.notify_observers_item_removed(index, item.get_scan_item()) diff --git a/src/ptychodus/model/product/settings.py b/src/ptychodus/model/product/settings.py new file mode 100644 index 00000000..a1ddb94c --- /dev/null +++ b/src/ptychodus/model/product/settings.py @@ -0,0 +1,37 @@ +from pathlib import Path + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.settings import SettingsRegistry + + +class ProductSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('Products') + self._group.add_observer(self) + + self.name = self._group.create_string_parameter('Name', 'Unnamed') + self.file_path = self._group.create_path_parameter('FilePath', Path('/path/to/product.h5')) + self.file_type = self._group.create_string_parameter('FileType', 'HDF5') + self.detector_distance_m = self._group.create_real_parameter( + 'DetectorDistanceInMeters', 1.0, minimum=0.0 + ) + self.probe_energy_eV = self._group.create_real_parameter( + 'ProbeEnergyInElectronVolts', 10000.0, minimum=0.0 + ) + self.probe_photon_count = self._group.create_real_parameter( + 'ProbePhotonCount', 0.0, minimum=0.0 + ) + self.exposure_time_s = self._group.create_real_parameter( + 'ExposureTimeInSeconds', 0.0, minimum=0.0 + ) + self.mass_attenuation_m2_kg = self._group.create_real_parameter( + 'MassAttenuationSquareMetersPerKilogram', 0.0, minimum=0.0 + ) + self.tomography_angle_deg = self._group.create_real_parameter( + 'TomographyAngleInDegrees', 0.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/product/validator.py b/src/ptychodus/model/product/validator.py new file mode 100644 index 00000000..8ceb5d97 --- /dev/null +++ b/src/ptychodus/model/product/validator.py @@ -0,0 +1,79 @@ +from ptychodus.api.observer import Observable, Observer + +from ..patterns import AssembledDiffractionDataset +from .geometry import ProductGeometry +from .object import ObjectRepositoryItem +from .probe import ProbeRepositoryItem +from .scan import ScanRepositoryItem + + +class ProductValidator(Observable, Observer): # TODO display + def __init__( + self, + dataset: AssembledDiffractionDataset, + scan: ScanRepositoryItem, + geometry: ProductGeometry, + probe: ProbeRepositoryItem, + object_: ObjectRepositoryItem, + ) -> None: + super().__init__() + self._dataset = dataset + self._scan = scan + self._geometry = geometry + self._probe = probe + self._object = object_ + self._are_positions_valid = False + self._are_probes_valid = False + self._is_object_valid = False + + def are_positions_valid(self) -> bool: + return self._are_positions_valid + + def are_probes_valid(self) -> bool: + return self._are_probes_valid + + def is_object_valid(self) -> bool: + return self._is_object_valid + + def _validate_scan(self) -> None: + scan = self._scan.get_scan() + scan_indexes = set(point.index for point in scan) + pattern_indexes = set(self._dataset.get_assembled_indexes()) + are_positions_valid_now = not scan_indexes.isdisjoint(pattern_indexes) + + if self._are_positions_valid != are_positions_valid_now: + self._are_positions_valid = are_positions_valid_now + self.notify_observers() + + def _validate_probes_and_object(self) -> None: + has_validity_changed = False + + probes = self._probe.get_probes() + are_probes_valid_now = self._geometry.is_probe_geometry_valid(probes.get_geometry()) + + if self._are_probes_valid != are_probes_valid_now: + self._are_probes_valid = are_probes_valid_now + has_validity_changed = True + + object_ = self._object.get_object() + is_object_valid_now = self._geometry.is_object_geometry_valid(object_.get_geometry()) + + if self._is_object_valid != is_object_valid_now: + self._is_object_valid = is_object_valid_now + has_validity_changed = True + + if has_validity_changed: + self.notify_observers() + + def _update(self, observable: Observable) -> None: + if observable is self._dataset: + self._validate_scan() + self._validate_probes_and_object() + elif observable is self._scan: + self._validate_scan() + elif observable is self._geometry: + self._validate_probes_and_object() + elif observable is self._probe: + self._validate_probes_and_object() + elif observable is self._object: + self._validate_probes_and_object() diff --git a/src/ptychodus/model/ptychi/__init__.py b/src/ptychodus/model/ptychi/__init__.py new file mode 100644 index 00000000..22075cea --- /dev/null +++ b/src/ptychodus/model/ptychi/__init__.py @@ -0,0 +1,31 @@ +from .affine import PtyChiAffineDegreesOfFreedomBitField +from .core import PtyChiReconstructorLibrary +from .device import PtyChiDeviceRepository +from .enums import PtyChiEnumerators +from .settings import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiLSQMLSettings, + PtyChiOPRSettings, + PtyChiObjectSettings, + PtyChiPIESettings, + PtyChiProbePositionSettings, + PtyChiProbeSettings, + PtyChiSettings, +) + +__all__ = [ + 'PtyChiAffineDegreesOfFreedomBitField', + 'PtyChiAutodiffSettings', + 'PtyChiDMSettings', + 'PtyChiDeviceRepository', + 'PtyChiEnumerators', + 'PtyChiLSQMLSettings', + 'PtyChiOPRSettings', + 'PtyChiObjectSettings', + 'PtyChiPIESettings', + 'PtyChiProbePositionSettings', + 'PtyChiProbeSettings', + 'PtyChiReconstructorLibrary', + 'PtyChiSettings', +] diff --git a/src/ptychodus/model/ptychi/affine.py b/src/ptychodus/model/ptychi/affine.py new file mode 100644 index 00000000..c001b542 --- /dev/null +++ b/src/ptychodus/model/ptychi/affine.py @@ -0,0 +1,56 @@ +from collections.abc import Sequence +from enum import IntEnum +from typing import overload + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import IntegerParameter + + +class PtyChiAffineDegreesOfFreedom(IntEnum): + TRANSLATION = 0 + ROTATION = 1 + SCALING = 2 + SHEARING = 3 + ASYMMETRY = 4 + + +class PtyChiAffineDegreesOfFreedomBitField(Sequence[str], Observable, Observer): + def __init__(self, parameter: IntegerParameter) -> None: + super().__init__() + self._parameter = parameter + parameter.add_observer(self) + + def is_bit_set(self, bit: int) -> bool: + value = self._parameter.get_value() + mask = 1 << bit + return value & mask != 0 + + def set_bit(self, bit: int, is_set: bool) -> None: + value = self._parameter.get_value() + mask = 1 << bit + + if is_set: + value |= mask + else: + value &= ~mask + + self._parameter.set_value(value) + + @overload + def __getitem__(self, index: int) -> str: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[str]: ... + + def __getitem__(self, index: int | slice) -> str | Sequence[str]: + if isinstance(index, slice): + return [self[idx] for idx in range(index.start, index.stop, index.step)] + + return PtyChiAffineDegreesOfFreedom(index).name.title() + + def __len__(self) -> int: + return len(PtyChiAffineDegreesOfFreedom) + + def _update(self, observable: Observable) -> None: + if observable is self._parameter: + self.notify_observers() diff --git a/src/ptychodus/model/ptychi/autodiff.py b/src/ptychodus/model/ptychi/autodiff.py new file mode 100644 index 00000000..b2a4f6e7 --- /dev/null +++ b/src/ptychodus/model/ptychi/autodiff.py @@ -0,0 +1,201 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + AutodiffPtychographyOPRModeWeightsOptions, + AutodiffPtychographyObjectOptions, + AutodiffPtychographyOptions, + AutodiffPtychographyProbeOptions, + AutodiffPtychographyProbePositionOptions, + AutodiffPtychographyReconstructorOptions, + ForwardModels, + LossFunctions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import PositionSequence + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiAutodiffSettings + +logger = logging.getLogger(__name__) + + +class AutodiffReconstructor(Reconstructor): + def __init__( + self, options_helper: PtyChiOptionsHelper, settings: PtyChiAutodiffSettings + ) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'Autodiff' + + def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOptions: + helper = self._options_helper.reconstructor_helper + + #### + + loss_function_str = self._settings.loss_function.get_value() + + try: + loss_function = LossFunctions[loss_function_str.upper()] + except KeyError: + logger.warning('Failed to parse loss function "{loss_function_str}"!') + loss_function = LossFunctions.MSE_SQRT + + #### + + forward_model_class_str = self._settings.forward_model_class.get_value() + + try: + forward_model_class = ForwardModels[forward_model_class_str.upper()] + except KeyError: + logger.warning('Failed to parse forward model class "{forward_model_class_str}"!') + forward_model_class = ForwardModels.PLANAR_PTYCHOGRAPHY + + #### + + return AutodiffPtychographyReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + use_double_precision_for_fft=helper.use_double_precision_for_fft, + allow_nondeterministic_algorithms=helper.allow_nondeterministic_algorithms, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + forward_model_options=helper.forward_model_options, + loss_function=loss_function, + forward_model_class=forward_model_class, + forward_model_params=None, + ) + + def _create_object_options(self, object_: Object) -> AutodiffPtychographyObjectOptions: + helper = self._options_helper.object_helper + return AutodiffPtychographyObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + slice_spacing_options=helper.slice_spacing_options, + pixel_size_m=helper.get_pixel_size_m(object_), + pixel_size_aspect_ratio=helper.get_pixel_aspect_ratio(object_), + l1_norm_constraint=helper.l1_norm_constraint, + l2_norm_constraint=helper.l2_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + remove_object_probe_ambiguity=helper.remove_object_probe_ambiguity, + build_preconditioner_with_all_modes=helper.build_preconditioner_with_all_modes, + determine_position_origin_coords_by=helper.determine_position_origin_coords_by, + position_origin_coords=helper.get_position_origin_coords(object_), + ) + + def _create_probe_options( + self, probes: ProbeSequence, metadata: ProductMetadata + ) -> AutodiffPtychographyProbeOptions: + helper = self._options_helper.probe_helper + return AutodiffPtychographyProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probes), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + ) + + def _create_probe_position_options( + self, scan: PositionSequence, object_geometry: ObjectGeometry + ) -> AutodiffPtychographyProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return AutodiffPtychographyProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + affine_transform_constraint=helper.affine_transform_constraint, + ) + + def _create_opr_mode_weight_options( + self, probes: ProbeSequence + ) -> AutodiffPtychographyOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return AutodiffPtychographyOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(probes), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> AutodiffPtychographyOptions: + product = parameters.product + return AutodiffPtychographyOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probes, product.metadata), + probe_position_options=self._create_probe_position_options( + product.positions, product.object_.get_geometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(product.probes), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/core.py b/src/ptychodus/model/ptychi/core.py new file mode 100644 index 00000000..0f4f3ed0 --- /dev/null +++ b/src/ptychodus/model/ptychi/core.py @@ -0,0 +1,98 @@ +from collections.abc import Iterator +from importlib.metadata import version +import logging + +from ptychodus.api.reconstructor import ( + NullReconstructor, + Reconstructor, + ReconstructorLibrary, +) +from ptychodus.api.settings import SettingsRegistry + +from ..patterns import PatternSizer +from .device import PtyChiDeviceRepository +from .enums import PtyChiEnumerators +from .settings import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiLSQMLSettings, + PtyChiOPRSettings, + PtyChiObjectSettings, + PtyChiPIESettings, + PtyChiProbePositionSettings, + PtyChiProbeSettings, + PtyChiSettings, +) + +logger = logging.getLogger(__name__) + + +class PtyChiReconstructorLibrary(ReconstructorLibrary): + def __init__( + self, + settings_registry: SettingsRegistry, + pattern_sizer: PatternSizer, + is_developer_mode_enabled: bool, + ) -> None: + super().__init__() + self.autodiff_settings = PtyChiAutodiffSettings(settings_registry) + self.dm_settings = PtyChiDMSettings(settings_registry) + self.lsqml_settings = PtyChiLSQMLSettings(settings_registry) + self.object_settings = PtyChiObjectSettings(settings_registry) + self.opr_settings = PtyChiOPRSettings(settings_registry) + self.pie_settings = PtyChiPIESettings(settings_registry) + self.probe_position_settings = PtyChiProbePositionSettings(settings_registry) + self.probe_settings = PtyChiProbeSettings(settings_registry) + self.settings = PtyChiSettings(settings_registry) + + self.enumerators = PtyChiEnumerators() + self.device_repository = PtyChiDeviceRepository( + is_developer_mode_enabled=is_developer_mode_enabled + ) + self.reconstructor_list: list[Reconstructor] = list() + + try: + from .autodiff import AutodiffReconstructor + from .dm import DMReconstructor + from .epie import EPIEReconstructor + from .helper import PtyChiOptionsHelper + from .lsqml import LSQMLReconstructor + from .pie import PIEReconstructor + from .rpie import RPIEReconstructor + except ModuleNotFoundError: + logger.info('pty-chi not found.') + + if is_developer_mode_enabled: + for reconstructor in ('DM', 'PIE', 'ePIE', 'rPIE', 'LSQML', 'Autodiff'): + self.reconstructor_list.append(NullReconstructor(reconstructor)) + else: + ptychi_version = version('ptychi') + logger.info(f'Pty-Chi {ptychi_version}') + + options_helper = PtyChiOptionsHelper( + self.settings, + self.object_settings, + self.probe_settings, + self.probe_position_settings, + self.opr_settings, + pattern_sizer, + ) + self.reconstructor_list.append(DMReconstructor(options_helper, self.dm_settings)) + self.reconstructor_list.append(PIEReconstructor(options_helper, self.pie_settings)) + self.reconstructor_list.append(EPIEReconstructor(options_helper, self.pie_settings)) + self.reconstructor_list.append(RPIEReconstructor(options_helper, self.pie_settings)) + self.reconstructor_list.append(LSQMLReconstructor(options_helper, self.lsqml_settings)) + self.reconstructor_list.append( + AutodiffReconstructor(options_helper, self.autodiff_settings) + ) + + @property + def name(self) -> str: + return 'pty-chi' + + @property + def logger_name(self) -> str: + return 'ptychi' + + def __iter__(self) -> Iterator[Reconstructor]: + return iter(self.reconstructor_list) diff --git a/src/ptychodus/model/ptychi/device.py b/src/ptychodus/model/ptychi/device.py new file mode 100644 index 00000000..1518c80c --- /dev/null +++ b/src/ptychodus/model/ptychi/device.py @@ -0,0 +1,35 @@ +from collections.abc import Sequence +from typing import overload +import logging + +logger = logging.getLogger(__name__) + + +class PtyChiDeviceRepository(Sequence[str]): + def __init__(self, *, is_developer_mode_enabled: bool) -> None: + self._devices: list[str] = list() + + try: + import ptychi + except ModuleNotFoundError: + if is_developer_mode_enabled: + self._devices.extend(f'gpu:{n}' for n in range(4)) + else: + for device in ptychi.list_available_devices(): + logger.info(device) + self._devices.append(f'{device.name} ({device.torch_device})') + + if not self._devices: + logger.info('No devices found!') + + @overload + def __getitem__(self, index: int) -> str: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[str]: ... + + def __getitem__(self, index: int | slice) -> str | Sequence[str]: + return self._devices[index] + + def __len__(self) -> int: + return len(self._devices) diff --git a/src/ptychodus/model/ptychi/dm.py b/src/ptychodus/model/ptychi/dm.py new file mode 100644 index 00000000..1366d9fa --- /dev/null +++ b/src/ptychodus/model/ptychi/dm.py @@ -0,0 +1,174 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + DMOPRModeWeightsOptions, + DMObjectOptions, + DMOptions, + DMProbeOptions, + DMProbePositionOptions, + DMReconstructorOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import PositionSequence + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiDMSettings + +logger = logging.getLogger(__name__) + + +class DMReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiDMSettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'DM' + + def _create_reconstructor_options(self) -> DMReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return DMReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + use_double_precision_for_fft=helper.use_double_precision_for_fft, + allow_nondeterministic_algorithms=helper.allow_nondeterministic_algorithms, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + forward_model_options=helper.forward_model_options, + exit_wave_update_relaxation=self._settings.exit_wave_update_relaxation.get_value(), + chunk_length=self._settings.chunk_length.get_value(), + ) + + def _create_object_options(self, object_: Object) -> DMObjectOptions: + helper = self._options_helper.object_helper + return DMObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + slice_spacing_options=helper.slice_spacing_options, + pixel_size_m=helper.get_pixel_size_m(object_), + pixel_size_aspect_ratio=helper.get_pixel_aspect_ratio(object_), + l1_norm_constraint=helper.l1_norm_constraint, + l2_norm_constraint=helper.l2_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + remove_object_probe_ambiguity=helper.remove_object_probe_ambiguity, + build_preconditioner_with_all_modes=helper.build_preconditioner_with_all_modes, + determine_position_origin_coords_by=helper.determine_position_origin_coords_by, + position_origin_coords=helper.get_position_origin_coords(object_), + amplitude_clamp_limit=self._settings.object_amplitude_clamp_limit.get_value(), + inertia=self._settings.object_inertia.get_value(), + ) + + def _create_probe_options( + self, probes: ProbeSequence, metadata: ProductMetadata + ) -> DMProbeOptions: + helper = self._options_helper.probe_helper + return DMProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probes), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + inertia=self._settings.probe_inertia.get_value(), + ) + + def _create_probe_position_options( + self, scan: PositionSequence, object_geometry: ObjectGeometry + ) -> DMProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return DMProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + affine_transform_constraint=helper.affine_transform_constraint, + ) + + def _create_opr_mode_weight_options(self, probes: ProbeSequence) -> DMOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return DMOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(probes), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> DMOptions: + product = parameters.product + return DMOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probes, product.metadata), + probe_position_options=self._create_probe_position_options( + product.positions, product.object_.get_geometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(product.probes), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/enums.py b/src/ptychodus/model/ptychi/enums.py new file mode 100644 index 00000000..f2c761b4 --- /dev/null +++ b/src/ptychodus/model/ptychi/enums.py @@ -0,0 +1,86 @@ +from collections.abc import Iterator, Sequence + + +class PtyChiEnumerators: + def __init__(self) -> None: + try: + from ptychi.api import ( + BatchingModes, + Directions, + ForwardModels, + ImageGradientMethods, + ImageIntegrationMethods, + LossFunctions, + NoiseModels, + OPRWeightSmoothingMethods, + Optimizers, + OrthogonalizationMethods, + PatchInterpolationMethods, + PositionCorrectionTypes, + ) + except ModuleNotFoundError: + self._batching_modes: Sequence[str] = list() + self._directions: Sequence[str] = list() + self._forward_models: Sequence[str] = list() + self._image_gradient_methods: Sequence[str] = list() + self._image_integration_methods: Sequence[str] = list() + self._loss_functions: Sequence[str] = list() + self._noise_models: Sequence[str] = list() + self._opr_weight_smoothing_methods: Sequence[str] = list() + self._optimizers: Sequence[str] = list() + self._orthogonalization_methods: Sequence[str] = list() + self._patch_interpolation_methods: Sequence[str] = list() + self._position_correction_types: Sequence[str] = list() + else: + self._batching_modes = [member.name for member in BatchingModes] + self._directions = [member.name for member in Directions] + self._forward_models = [member.name for member in ForwardModels] + self._image_gradient_methods = [member.name for member in ImageGradientMethods] + self._image_integration_methods = [member.name for member in ImageIntegrationMethods] + self._loss_functions = [member.name for member in LossFunctions] + self._noise_models = [member.name for member in NoiseModels] + self._opr_weight_smoothing_methods = [ + member.name for member in OPRWeightSmoothingMethods + ] + self._optimizers = [member.name for member in Optimizers] + self._orthogonalization_methods = [member.name for member in OrthogonalizationMethods] + self._patch_interpolation_methods = [ + member.name for member in PatchInterpolationMethods + ] + self._position_correction_types = [member.name for member in PositionCorrectionTypes] + + def batching_modes(self) -> Iterator[str]: + return iter(self._batching_modes) + + def directions(self) -> Iterator[str]: + return iter(self._directions) + + def forward_models(self) -> Iterator[str]: + return iter(self._forward_models) + + def image_gradient_methods(self) -> Iterator[str]: + return iter(self._image_gradient_methods) + + def image_integration_methods(self) -> Iterator[str]: + return iter(self._image_integration_methods) + + def loss_functions(self) -> Iterator[str]: + return iter(self._loss_functions) + + def noise_models(self) -> Iterator[str]: + return iter(self._noise_models) + + def opr_weight_smoothing_methods(self) -> Iterator[str]: + return iter(self._opr_weight_smoothing_methods) + + def optimizers(self) -> Iterator[str]: + return iter(self._optimizers) + + def orthogonalization_methods(self) -> Iterator[str]: + return iter(self._orthogonalization_methods) + + def patch_interpolation_methods(self) -> Iterator[str]: + return iter(self._patch_interpolation_methods) + + def position_correction_types(self) -> Iterator[str]: + return iter(self._position_correction_types) diff --git a/src/ptychodus/model/ptychi/epie.py b/src/ptychodus/model/ptychi/epie.py new file mode 100644 index 00000000..6d8b2965 --- /dev/null +++ b/src/ptychodus/model/ptychi/epie.py @@ -0,0 +1,171 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + EPIEOptions, + EPIEReconstructorOptions, + PIEOPRModeWeightsOptions, + PIEObjectOptions, + PIEProbeOptions, + PIEProbePositionOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import PositionSequence + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiPIESettings + +logger = logging.getLogger(__name__) + + +class EPIEReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'ePIE' + + def _create_reconstructor_options(self) -> EPIEReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return EPIEReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + use_double_precision_for_fft=helper.use_double_precision_for_fft, + allow_nondeterministic_algorithms=helper.allow_nondeterministic_algorithms, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + forward_model_options=helper.forward_model_options, + ) + + def _create_object_options(self, object_: Object) -> PIEObjectOptions: + helper = self._options_helper.object_helper + return PIEObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + slice_spacing_options=helper.slice_spacing_options, + pixel_size_m=helper.get_pixel_size_m(object_), + pixel_size_aspect_ratio=helper.get_pixel_aspect_ratio(object_), + l1_norm_constraint=helper.l1_norm_constraint, + l2_norm_constraint=helper.l2_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + remove_object_probe_ambiguity=helper.remove_object_probe_ambiguity, + build_preconditioner_with_all_modes=helper.build_preconditioner_with_all_modes, + determine_position_origin_coords_by=helper.determine_position_origin_coords_by, + position_origin_coords=helper.get_position_origin_coords(object_), + alpha=self._settings.object_alpha.get_value(), + ) + + def _create_probe_options( + self, probes: ProbeSequence, metadata: ProductMetadata + ) -> PIEProbeOptions: + helper = self._options_helper.probe_helper + return PIEProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probes), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + alpha=self._settings.probe_alpha.get_value(), + ) + + def _create_probe_position_options( + self, scan: PositionSequence, object_geometry: ObjectGeometry + ) -> PIEProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return PIEProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + affine_transform_constraint=helper.affine_transform_constraint, + ) + + def _create_opr_mode_weight_options(self, probes: ProbeSequence) -> PIEOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return PIEOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(probes), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> EPIEOptions: + product = parameters.product + return EPIEOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probes, product.metadata), + probe_position_options=self._create_probe_position_options( + product.positions, product.object_.get_geometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(product.probes), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/helper.py b/src/ptychodus/model/ptychi/helper.py new file mode 100644 index 00000000..dd4776f8 --- /dev/null +++ b/src/ptychodus/model/ptychi/helper.py @@ -0,0 +1,722 @@ +from collections.abc import Sequence +import logging + +import torch +import numpy + +from ptychi.api import ( + AffineDegreesOfFreedom, + BatchingModes, + Devices, + Directions, + Dtypes, + ImageGradientMethods, + ImageIntegrationMethods, + LossFunctions, + OPRWeightSmoothingMethods, + ObjectPosOriginCoordsMethods, + OptimizationPlan, + Optimizers, + OrthogonalizationMethods, + PatchInterpolationMethods, + PositionCorrectionTypes, + PtychographyDataOptions, +) +from ptychi.api.options.base import ( + ForwardModelOptions, + OPRModeWeightsSmoothingOptions, + ObjectL1NormConstraintOptions, + ObjectL2NormConstraintOptions, + ObjectMultisliceRegularizationOptions, + ObjectSmoothnessConstraintOptions, + ObjectTotalVariationOptions, + PositionAffineTransformConstraintOptions, + PositionCorrectionOptions, + ProbeCenterConstraintOptions, + ProbeOrthogonalizeIncoherentModesOptions, + ProbeOrthogonalizeOPRModesOptions, + ProbePowerConstraintOptions, + ProbeSupportConstraintOptions, + RemoveGridArtifactsOptions, + RemoveObjectProbeAmbiguityOptions, + SliceSpacingOptions, +) + +from ptychodus.api.object import Object, ObjectGeometry, ObjectPoint +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import Product, ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput +from ptychodus.api.scan import PositionSequence, ScanPoint +from ptychodus.api.typing import ComplexArrayType, RealArrayType + +from ..patterns import PatternSizer +from .affine import PtyChiAffineDegreesOfFreedom, PtyChiAffineDegreesOfFreedomBitField +from .settings import ( + PtyChiOPRSettings, + PtyChiObjectSettings, + PtyChiProbePositionSettings, + PtyChiProbeSettings, + PtyChiSettings, +) + + +__all__ = ['PtyChiOptionsHelper'] + +logger = logging.getLogger(__name__) + + +def create_optimization_plan(start: int, stop: int, stride: int) -> OptimizationPlan: + return OptimizationPlan(start, None if stop < 0 else stop, stride) + + +def parse_optimizer(text: str) -> Optimizers: + try: + optimizer = Optimizers[text.upper()] + except KeyError: + logger.warning(f'Failed to parse optimizer "{text}"!') + optimizer = Optimizers.SGD + + return optimizer + + +class PtyChiReconstructorOptionsHelper: + def __init__(self, settings: PtyChiSettings) -> None: + self._settings = settings + + @property + def num_epochs(self) -> int: + return self._settings.num_epochs.get_value() + + @property + def batch_size(self) -> int: + return self._settings.batch_size.get_value() + + @property + def batching_mode(self) -> BatchingModes: + batching_mode_str = self._settings.batching_mode.get_value() + + try: + return BatchingModes[batching_mode_str.upper()] + except KeyError: + logger.warning(f'Failed to parse batching mode "{batching_mode_str}"!') + return BatchingModes.RANDOM + + @property + def compact_mode_update_clustering(self) -> bool: + return self._settings.compact_mode_update_clustering.get_value() > 0 + + @property + def compact_mode_update_clustering_stride(self) -> int: + return self._settings.compact_mode_update_clustering.get_value() + + @property + def default_device(self) -> Devices: + return Devices.GPU if self._settings.use_devices.get_value() else Devices.CPU + + @property + def default_dtype(self) -> Dtypes: + return Dtypes.FLOAT64 if self._settings.use_double_precision.get_value() else Dtypes.FLOAT32 + + @property + def use_double_precision_for_fft(self) -> bool: + return self._settings.use_double_precision_for_fft.get_value() + + @property + def allow_nondeterministic_algorithms(self) -> bool: + return self._settings.allow_nondeterministic_algorithms.get_value() + + @property + def random_seed(self) -> int | None: + return None # TODO + + @property + def displayed_loss_function(self) -> LossFunctions | None: + return LossFunctions.MSE_SQRT # TODO + + @property + def forward_model_options(self) -> ForwardModelOptions: + return ForwardModelOptions( + low_memory_mode=self._settings.use_low_memory_mode.get_value(), + pad_for_shift=self._settings.pad_for_shift.get_value(), + ) + + +class PtyChiObjectOptionsHelper: + def __init__(self, settings: PtyChiObjectSettings) -> None: + self._settings = settings + + @property + def optimizable(self) -> bool: + return self._settings.is_optimizable.get_value() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimization_plan_start.get_value(), + self._settings.optimization_plan_stop.get_value(), + self._settings.optimization_plan_stride.get_value(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.get_value()) + + @property + def step_size(self) -> float: + return self._settings.step_size.get_value() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def slice_spacing_options(self) -> SliceSpacingOptions: + optimizer = parse_optimizer(self._settings.optimize_slice_spacing_optimizer.get_value()) + return SliceSpacingOptions( + optimizable=self._settings.optimize_slice_spacing.get_value(), + optimization_plan=create_optimization_plan( + self._settings.optimize_slice_spacing_start.get_value(), + self._settings.optimize_slice_spacing_stop.get_value(), + self._settings.optimize_slice_spacing_stride.get_value(), + ), + optimizer=optimizer, + step_size=self._settings.optimize_slice_spacing_step_size.get_value(), + ) + + @property + def l1_norm_constraint(self) -> ObjectL1NormConstraintOptions: + return ObjectL1NormConstraintOptions( + enabled=self._settings.constrain_l1_norm.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_l1_norm_start.get_value(), + self._settings.constrain_l1_norm_stop.get_value(), + self._settings.constrain_l1_norm_stride.get_value(), + ), + weight=self._settings.constrain_l1_norm_weight.get_value(), + ) + + @property + def l2_norm_constraint(self) -> ObjectL2NormConstraintOptions: + return ObjectL2NormConstraintOptions( + enabled=self._settings.constrain_l2_norm.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_l2_norm_start.get_value(), + self._settings.constrain_l2_norm_stop.get_value(), + self._settings.constrain_l2_norm_stride.get_value(), + ), + weight=self._settings.constrain_l2_norm_weight.get_value(), + ) + + @property + def smoothness_constraint(self) -> ObjectSmoothnessConstraintOptions: + return ObjectSmoothnessConstraintOptions( + enabled=self._settings.constrain_smoothness.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_smoothness_start.get_value(), + self._settings.constrain_smoothness_stop.get_value(), + self._settings.constrain_smoothness_stride.get_value(), + ), + alpha=self._settings.constrain_smoothness_alpha.get_value(), + ) + + @property + def total_variation(self) -> ObjectTotalVariationOptions: + return ObjectTotalVariationOptions( + enabled=self._settings.constrain_total_variation.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_total_variation_start.get_value(), + self._settings.constrain_total_variation_stop.get_value(), + self._settings.constrain_total_variation_stride.get_value(), + ), + weight=self._settings.constrain_total_variation_weight.get_value(), + ) + + @property + def remove_grid_artifacts(self) -> RemoveGridArtifactsOptions: + direction_str = self._settings.remove_grid_artifacts_direction.get_value() + + try: + direction = Directions[direction_str.upper()] + except KeyError: + logger.warning(f'Failed to parse direction "{direction_str}"!') + direction = Directions.XY + + return RemoveGridArtifactsOptions( + enabled=self._settings.remove_grid_artifacts.get_value(), + optimization_plan=create_optimization_plan( + self._settings.remove_grid_artifacts_start.get_value(), + self._settings.remove_grid_artifacts_stop.get_value(), + self._settings.remove_grid_artifacts_stride.get_value(), + ), + period_x_m=self._settings.remove_grid_artifacts_period_x_m.get_value(), + period_y_m=self._settings.remove_grid_artifacts_period_y_m.get_value(), + window_size=self._settings.remove_grid_artifacts_window_size_px.get_value(), + direction=direction, + ) + + @property + def multislice_regularization(self) -> ObjectMultisliceRegularizationOptions: + unwrap_image_grad_method_str = ( + self._settings.regularize_multislice_unwrap_phase_image_gradient_method.get_value() + ) + + try: + unwrap_image_grad_method = ImageGradientMethods[unwrap_image_grad_method_str.upper()] + except KeyError: + logger.warning( + f'Failed to parse image gradient method "{unwrap_image_grad_method_str}"!' + ) + unwrap_image_grad_method = ImageGradientMethods.FOURIER_SHIFT + + unwrap_image_integration_method_str = ( + self._settings.regularize_multislice_unwrap_phase_image_integration_method.get_value() + ) + + try: + unwrap_image_integration_method = ImageIntegrationMethods[ + unwrap_image_integration_method_str.upper() + ] + except KeyError: + logger.warning( + f'Failed to parse image integrationient method "{unwrap_image_integration_method_str}"!' + ) + unwrap_image_integration_method = ImageIntegrationMethods.DECONVOLUTION + + return ObjectMultisliceRegularizationOptions( + enabled=self._settings.regularize_multislice.get_value(), + optimization_plan=create_optimization_plan( + self._settings.regularize_multislice_start.get_value(), + self._settings.regularize_multislice_stop.get_value(), + self._settings.regularize_multislice_stride.get_value(), + ), + weight=self._settings.regularize_multislice_weight.get_value(), + unwrap_phase=self._settings.regularize_multislice_unwrap_phase.get_value(), + unwrap_image_grad_method=unwrap_image_grad_method, + unwrap_image_integration_method=unwrap_image_integration_method, + ) + + @property + def patch_interpolation_method(self) -> PatchInterpolationMethods: + method_str = self._settings.patch_interpolator.get_value() + + try: + return PatchInterpolationMethods[method_str.upper()] + except KeyError: + logger.warning(f'Failed to parse patch interpolation method "{method_str}"!') + return PatchInterpolationMethods.FOURIER + + @property + def remove_object_probe_ambiguity(self) -> RemoveObjectProbeAmbiguityOptions: + return RemoveObjectProbeAmbiguityOptions( + enabled=self._settings.remove_object_probe_ambiguity.get_value(), + optimization_plan=create_optimization_plan( + self._settings.remove_object_probe_ambiguity_start.get_value(), + self._settings.remove_object_probe_ambiguity_stop.get_value(), + self._settings.remove_object_probe_ambiguity_stride.get_value(), + ), + ) + + @property + def build_preconditioner_with_all_modes(self) -> bool: + return self._settings.build_preconditioner_with_all_modes.get_value() + + @property + def determine_position_origin_coords_by(self) -> ObjectPosOriginCoordsMethods: + return ObjectPosOriginCoordsMethods.SPECIFIED + + def get_initial_guess(self, object_: Object) -> ComplexArrayType: + return object_.get_array() + + def get_slice_spacings_m(self, object_: Object) -> RealArrayType | None: + slice_spacings_m = object_.layer_spacing_m + return numpy.array(slice_spacings_m) if slice_spacings_m else None + + def get_pixel_size_m(self, object_: Object) -> float: + pixel_geometry = object_.get_pixel_geometry() + return pixel_geometry.width_m + + def get_pixel_aspect_ratio(self, object_: Object) -> float: + pixel_geometry = object_.get_pixel_geometry() + return pixel_geometry.aspect_ratio + + def get_position_origin_coords(self, object_: Object) -> RealArrayType: + # TODO return numpy.zeros(2) + return torch.zeros(2) # type: ignore + + +class PtyChiProbeOptionsHelper: + def __init__(self, settings: PtyChiProbeSettings) -> None: + self._settings = settings + + @property + def optimizable(self) -> bool: + return self._settings.is_optimizable.get_value() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimization_plan_start.get_value(), + self._settings.optimization_plan_stop.get_value(), + self._settings.optimization_plan_stride.get_value(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.get_value()) + + @property + def step_size(self) -> float: + return self._settings.step_size.get_value() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def orthogonalize_incoherent_modes(self) -> ProbeOrthogonalizeIncoherentModesOptions: + method_str = self._settings.orthogonalize_incoherent_modes_method.get_value() + + try: + method = OrthogonalizationMethods[method_str.upper()] + except KeyError: + logger.warning(f'Failed to parse batching mode "{method_str}"!') + method = OrthogonalizationMethods.GS + + return ProbeOrthogonalizeIncoherentModesOptions( + enabled=self._settings.orthogonalize_incoherent_modes.get_value(), + optimization_plan=create_optimization_plan( + self._settings.orthogonalize_incoherent_modes_start.get_value(), + self._settings.orthogonalize_incoherent_modes_stop.get_value(), + self._settings.orthogonalize_incoherent_modes_stride.get_value(), + ), + method=method, + ) + + @property + def orthogonalize_opr_modes(self) -> ProbeOrthogonalizeOPRModesOptions: + return ProbeOrthogonalizeOPRModesOptions( + enabled=self._settings.orthogonalize_opr_modes.get_value(), + optimization_plan=create_optimization_plan( + self._settings.orthogonalize_opr_modes_start.get_value(), + self._settings.orthogonalize_opr_modes_stop.get_value(), + self._settings.orthogonalize_opr_modes_stride.get_value(), + ), + ) + + @property + def support_constraint(self) -> ProbeSupportConstraintOptions: + return ProbeSupportConstraintOptions( + enabled=self._settings.constrain_support.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_support_start.get_value(), + self._settings.constrain_support_stop.get_value(), + self._settings.constrain_support_stride.get_value(), + ), + threshold=self._settings.constrain_support_threshold.get_value(), + ) + + @property + def center_constraint(self) -> ProbeCenterConstraintOptions: + return ProbeCenterConstraintOptions( + enabled=self._settings.constrain_center.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_center_start.get_value(), + self._settings.constrain_center_stop.get_value(), + self._settings.constrain_center_stride.get_value(), + ), + ) + + @property + def eigenmode_update_relaxation(self) -> float: + return self._settings.relax_eigenmode_update.get_value() + + def get_initial_guess(self, probe: ProbeSequence) -> ComplexArrayType: + return probe.get_array() + + def get_power_constraint(self, metadata: ProductMetadata) -> ProbePowerConstraintOptions: + return ProbePowerConstraintOptions( + enabled=self._settings.constrain_probe_power.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_probe_power_start.get_value(), + self._settings.constrain_probe_power_stop.get_value(), + self._settings.constrain_probe_power_stride.get_value(), + ), + probe_power=metadata.probe_photon_count, + ) + + +class PtyChiProbePositionOptionsHelper: + def __init__(self, settings: PtyChiProbePositionSettings) -> None: + self._settings = settings + self._affine_dof = PtyChiAffineDegreesOfFreedomBitField( + settings.constrain_affine_transform_degrees_of_freedom + ) + + @property + def optimizable(self) -> bool: + return self._settings.is_optimizable.get_value() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimization_plan_start.get_value(), + self._settings.optimization_plan_stop.get_value(), + self._settings.optimization_plan_stride.get_value(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.get_value()) + + @property + def step_size(self) -> float: + return self._settings.step_size.get_value() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def constrain_position_mean(self) -> bool: + return self._settings.constrain_centroid.get_value() + + @property + def correction_options(self) -> PositionCorrectionOptions: + correction_type_str = self._settings.correction_type.get_value() + + try: + correction_type = PositionCorrectionTypes[correction_type_str.upper()] + except KeyError: + logger.warning(f'Failed to parse correction type "{correction_type_str}"!') + correction_type = PositionCorrectionTypes.GRADIENT + + differentiation_method_str = self._settings.differentiation_method.get_value() + + try: + differentiation_method = ImageGradientMethods[differentiation_method_str.upper()] + except KeyError: + logger.warning( + f'Failed to parse differentiation method "{differentiation_method_str}"!' + ) + differentiation_method = ImageGradientMethods.GAUSSIAN + + update_magnitude_limit = ( + self._settings.update_magnitude_limit.get_value() + if self._settings.limit_update_magnitude.get_value() + else float('inf') + ) + + return PositionCorrectionOptions( + correction_type=correction_type, + differentiation_method=differentiation_method, + cross_correlation_scale=self._settings.cross_correlation_scale.get_value(), + cross_correlation_real_space_width=self._settings.cross_correlation_real_space_width.get_value(), + cross_correlation_probe_threshold=self._settings.cross_correlation_probe_threshold.get_value(), + slice_for_correction=0, # TODO + update_magnitude_limit=update_magnitude_limit, + ) + + @property + def affine_transform_constraint(self) -> PositionAffineTransformConstraintOptions: + degrees_of_freedom: list[AffineDegreesOfFreedom] = list() + + if self._affine_dof.is_bit_set(PtyChiAffineDegreesOfFreedom.TRANSLATION): + degrees_of_freedom.append(AffineDegreesOfFreedom.TRANSLATION) + + if self._affine_dof.is_bit_set(PtyChiAffineDegreesOfFreedom.ROTATION): + degrees_of_freedom.append(AffineDegreesOfFreedom.ROTATION) + + if self._affine_dof.is_bit_set(PtyChiAffineDegreesOfFreedom.SCALING): + degrees_of_freedom.append(AffineDegreesOfFreedom.SCALE) + + if self._affine_dof.is_bit_set(PtyChiAffineDegreesOfFreedom.SHEARING): + degrees_of_freedom.append(AffineDegreesOfFreedom.SHEAR) + + if self._affine_dof.is_bit_set(PtyChiAffineDegreesOfFreedom.ASYMMETRY): + degrees_of_freedom.append(AffineDegreesOfFreedom.ASSYMETRY) + + return PositionAffineTransformConstraintOptions( + enabled=self._settings.constrain_affine_transform.get_value(), + optimization_plan=create_optimization_plan( + self._settings.constrain_affine_transform_start.get_value(), + self._settings.constrain_affine_transform_stop.get_value(), + self._settings.constrain_affine_transform_stride.get_value(), + ), + degrees_of_freedom=degrees_of_freedom, + position_weight_update_interval=self._settings.constrain_affine_transform_position_weight_update_interval.get_value(), + apply_constraint=self._settings.constrain_affine_transform_apply_constraint.get_value(), + max_expected_error=self._settings.constrain_affine_transform_max_expected_error_px.get_value(), + ) + + def get_positions_px( + self, scan: PositionSequence, object_geometry: ObjectGeometry + ) -> tuple[RealArrayType, RealArrayType]: + position_x_px: list[float] = list() + position_y_px: list[float] = list() + + for scan_point in scan: + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + position_x_px.append(object_point.position_x_px) + position_y_px.append(object_point.position_y_px) + + return numpy.array(position_x_px), numpy.array(position_y_px) + + +class PtyChiOPROptionsHelper: + def __init__(self, settings: PtyChiOPRSettings) -> None: + self._settings = settings + + @property + def optimizable(self) -> bool: + return self._settings.is_optimizable.get_value() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimization_plan_start.get_value(), + self._settings.optimization_plan_stop.get_value(), + self._settings.optimization_plan_stride.get_value(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.get_value()) + + @property + def step_size(self) -> float: + return self._settings.step_size.get_value() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def optimize_eigenmode_weights(self) -> bool: + return self._settings.optimize_eigenmode_weights.get_value() + + @property + def optimize_intensity_variation(self) -> bool: + return self._settings.optimize_intensities.get_value() + + @property + def smoothing(self) -> OPRModeWeightsSmoothingOptions: + method_str = self._settings.smoothing_method.get_value() + + try: + method: OPRWeightSmoothingMethods | None = OPRWeightSmoothingMethods[method_str.upper()] + except KeyError: + logger.debug('OPR weight smoothing method is None.') + method = None + + return OPRModeWeightsSmoothingOptions( + enabled=self._settings.smooth_mode_weights.get_value(), + optimization_plan=create_optimization_plan( + self._settings.smooth_mode_weights_start.get_value(), + self._settings.smooth_mode_weights_stop.get_value(), + self._settings.smooth_mode_weights_stride.get_value(), + ), + method=method, + polynomial_degree=self._settings.polynomial_smoothing_degree.get_value(), + ) + + @property + def update_relaxation(self) -> float: + return self._settings.relax_update.get_value() + + def get_initial_weights(self, probe: ProbeSequence) -> RealArrayType: + try: + return probe.get_opr_weights() + except ValueError: + pass + + initial_weights = numpy.zeros((probe.num_coherent_modes)) + initial_weights[0] = 1.0 + return initial_weights + + +class PtyChiOptionsHelper: + def __init__( + self, + reconstructor_settings: PtyChiSettings, + object_settings: PtyChiObjectSettings, + probe_settings: PtyChiProbeSettings, + probe_position_settings: PtyChiProbePositionSettings, + opr_settings: PtyChiOPRSettings, + pattern_sizer: PatternSizer, + ) -> None: + self._reconstructor_settings = reconstructor_settings + self._pattern_sizer = pattern_sizer + + self.reconstructor_helper = PtyChiReconstructorOptionsHelper(reconstructor_settings) + self.object_helper = PtyChiObjectOptionsHelper(object_settings) + self.probe_helper = PtyChiProbeOptionsHelper(probe_settings) + self.probe_position_helper = PtyChiProbePositionOptionsHelper(probe_position_settings) + self.opr_helper = PtyChiOPROptionsHelper(opr_settings) + + def create_data_options(self, parameters: ReconstructInput) -> PtychographyDataOptions: + metadata = parameters.product.metadata + pixel_geometry = self._pattern_sizer.get_processed_pixel_geometry() + free_space_propagation_distance_m = ( + numpy.inf + if self._reconstructor_settings.use_far_field_propagation + else metadata.detector_distance_m + ) + return PtychographyDataOptions( + data=parameters.patterns, + free_space_propagation_distance_m=free_space_propagation_distance_m, + wavelength_m=metadata.probe_wavelength_m, + fft_shift=self._reconstructor_settings.fft_shift_diffraction_patterns.get_value(), + detector_pixel_size_m=pixel_geometry.width_m, + valid_pixel_mask=numpy.logical_not(parameters.bad_pixels), + save_data_on_device=self._reconstructor_settings.save_data_on_device.get_value(), + ) + + def create_product( + self, + product: Product, + position_x_px: torch.Tensor | numpy.ndarray, + position_y_px: torch.Tensor | numpy.ndarray, + probe_array: torch.Tensor | numpy.ndarray, + object_array: torch.Tensor | numpy.ndarray, + opr_weights: torch.Tensor | numpy.ndarray, + costs: Sequence[float], + ) -> Product: + object_in = product.object_ + object_out = Object( + array=numpy.array(object_array), + layer_spacing_m=object_in.layer_spacing_m, + pixel_geometry=object_in.get_pixel_geometry(), + center=object_in.get_center(), + ) + + probe_out = ProbeSequence( + array=numpy.array(probe_array[0]), + opr_weights=numpy.array(opr_weights), + pixel_geometry=product.probes.get_pixel_geometry(), + ) + + corrected_scan_points: list[ScanPoint] = list() + object_geometry = object_in.get_geometry() + rx_px = object_geometry.width_px / 2 + ry_px = object_geometry.height_px / 2 + + for uncorrected_point, pos_x_px, pos_y_px in zip( + product.positions, position_x_px, position_y_px + ): + object_point = ObjectPoint( + index=uncorrected_point.index, + position_x_px=float(pos_x_px + rx_px), + position_y_px=float(pos_y_px + ry_px), + ) + scan_point = object_geometry.map_object_point_to_scan_point(object_point) + corrected_scan_points.append(scan_point) + + scan_out = PositionSequence(corrected_scan_points) + + return Product( + metadata=product.metadata, + positions=scan_out, + probes=probe_out, + object_=object_out, + costs=costs, + ) diff --git a/src/ptychodus/model/ptychi/lsqml.py b/src/ptychodus/model/ptychi/lsqml.py new file mode 100644 index 00000000..67304f16 --- /dev/null +++ b/src/ptychodus/model/ptychi/lsqml.py @@ -0,0 +1,202 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + LSQMLOPRModeWeightsOptions, + LSQMLObjectOptions, + LSQMLOptions, + LSQMLProbeOptions, + LSQMLProbePositionOptions, + LSQMLReconstructorOptions, + NoiseModels, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import PositionSequence + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiLSQMLSettings + +logger = logging.getLogger(__name__) + + +class LSQMLReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiLSQMLSettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'LSQML' + + def _create_reconstructor_options(self) -> LSQMLReconstructorOptions: + helper = self._options_helper.reconstructor_helper + + #### + + noise_model_str = self._settings.noise_model.get_value() + + try: + noise_model = NoiseModels[noise_model_str.upper()] + except KeyError: + logger.warning('Failed to parse batching mode "{noise_model_str}"!') + noise_model = NoiseModels.GAUSSIAN + + #### + + momentum_acceleration_gradient_mixing_factor: float | None = None + + if self._settings.use_momentum_acceleration_gradient_mixing_factor.get_value(): + momentum_acceleration_gradient_mixing_factor = ( + self._settings.momentum_acceleration_gradient_mixing_factor.get_value() + ) + + #### + + return LSQMLReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + use_double_precision_for_fft=helper.use_double_precision_for_fft, + allow_nondeterministic_algorithms=helper.allow_nondeterministic_algorithms, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + forward_model_options=helper.forward_model_options, + noise_model=noise_model, + gaussian_noise_std=self._settings.gaussian_noise_deviation.get_value(), + solve_obj_prb_step_size_jointly_for_first_slice_in_multislice=self._settings.solve_object_probe_step_size_jointly_for_first_slice_in_multislice.get_value(), + solve_step_sizes_only_using_first_probe_mode=self._settings.solve_step_sizes_only_using_first_probe_mode.get_value(), + momentum_acceleration_gain=self._settings.momentum_acceleration_gain.get_value(), + momentum_acceleration_gradient_mixing_factor=momentum_acceleration_gradient_mixing_factor, + rescale_probe_intensity_in_first_epoch=self._settings.rescale_probe_intensity_in_first_epoch.get_value(), + ) + + def _create_object_options(self, object_: Object) -> LSQMLObjectOptions: + helper = self._options_helper.object_helper + return LSQMLObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + slice_spacing_options=helper.slice_spacing_options, + pixel_size_m=helper.get_pixel_size_m(object_), + pixel_size_aspect_ratio=helper.get_pixel_aspect_ratio(object_), + l1_norm_constraint=helper.l1_norm_constraint, + l2_norm_constraint=helper.l2_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + remove_object_probe_ambiguity=helper.remove_object_probe_ambiguity, + build_preconditioner_with_all_modes=helper.build_preconditioner_with_all_modes, + determine_position_origin_coords_by=helper.determine_position_origin_coords_by, + position_origin_coords=helper.get_position_origin_coords(object_), + optimal_step_size_scaler=self._settings.object_optimal_step_size_scaler.get_value(), + multimodal_update=self._settings.object_multimodal_update.get_value(), + ) + + def _create_probe_options( + self, probes: ProbeSequence, metadata: ProductMetadata + ) -> LSQMLProbeOptions: + helper = self._options_helper.probe_helper + return LSQMLProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probes), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + optimal_step_size_scaler=self._settings.probe_optimal_step_size_scaler.get_value(), + ) + + def _create_probe_position_options( + self, scan: PositionSequence, object_geometry: ObjectGeometry + ) -> LSQMLProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return LSQMLProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + affine_transform_constraint=helper.affine_transform_constraint, + ) + + def _create_opr_mode_weight_options(self, probes: ProbeSequence) -> LSQMLOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return LSQMLOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(probes), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> LSQMLOptions: + product = parameters.product + return LSQMLOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probes, product.metadata), + probe_position_options=self._create_probe_position_options( + product.positions, product.object_.get_geometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(product.probes), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/pie.py b/src/ptychodus/model/ptychi/pie.py new file mode 100644 index 00000000..17ee578e --- /dev/null +++ b/src/ptychodus/model/ptychi/pie.py @@ -0,0 +1,171 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + PIEOPRModeWeightsOptions, + PIEObjectOptions, + PIEOptions, + PIEProbeOptions, + PIEProbePositionOptions, + PIEReconstructorOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import PositionSequence + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiPIESettings + +logger = logging.getLogger(__name__) + + +class PIEReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'PIE' + + def _create_reconstructor_options(self) -> PIEReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return PIEReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + use_double_precision_for_fft=helper.use_double_precision_for_fft, + allow_nondeterministic_algorithms=helper.allow_nondeterministic_algorithms, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + forward_model_options=helper.forward_model_options, + ) + + def _create_object_options(self, object_: Object) -> PIEObjectOptions: + helper = self._options_helper.object_helper + return PIEObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + slice_spacing_options=helper.slice_spacing_options, + pixel_size_m=helper.get_pixel_size_m(object_), + pixel_size_aspect_ratio=helper.get_pixel_aspect_ratio(object_), + l1_norm_constraint=helper.l1_norm_constraint, + l2_norm_constraint=helper.l2_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + remove_object_probe_ambiguity=helper.remove_object_probe_ambiguity, + build_preconditioner_with_all_modes=helper.build_preconditioner_with_all_modes, + determine_position_origin_coords_by=helper.determine_position_origin_coords_by, + position_origin_coords=helper.get_position_origin_coords(object_), + alpha=self._settings.object_alpha.get_value(), + ) + + def _create_probe_options( + self, probes: ProbeSequence, metadata: ProductMetadata + ) -> PIEProbeOptions: + helper = self._options_helper.probe_helper + return PIEProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probes), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + alpha=self._settings.probe_alpha.get_value(), + ) + + def _create_probe_position_options( + self, scan: PositionSequence, object_geometry: ObjectGeometry + ) -> PIEProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return PIEProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + affine_transform_constraint=helper.affine_transform_constraint, + ) + + def _create_opr_mode_weight_options(self, probes: ProbeSequence) -> PIEOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return PIEOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(probes), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> PIEOptions: + product = parameters.product + return PIEOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probes, product.metadata), + probe_position_options=self._create_probe_position_options( + product.positions, product.object_.get_geometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(product.probes), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/rpie.py b/src/ptychodus/model/ptychi/rpie.py new file mode 100644 index 00000000..c4862d46 --- /dev/null +++ b/src/ptychodus/model/ptychi/rpie.py @@ -0,0 +1,171 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + PIEOPRModeWeightsOptions, + PIEObjectOptions, + PIEProbeOptions, + PIEProbePositionOptions, + RPIEOptions, + RPIEReconstructorOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import PositionSequence + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiPIESettings + +logger = logging.getLogger(__name__) + + +class RPIEReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'rPIE' + + def _create_reconstructor_options(self) -> RPIEReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return RPIEReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + use_double_precision_for_fft=helper.use_double_precision_for_fft, + allow_nondeterministic_algorithms=helper.allow_nondeterministic_algorithms, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + forward_model_options=helper.forward_model_options, + ) + + def _create_object_options(self, object_: Object) -> PIEObjectOptions: + helper = self._options_helper.object_helper + return PIEObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + slice_spacing_options=helper.slice_spacing_options, + pixel_size_m=helper.get_pixel_size_m(object_), + pixel_size_aspect_ratio=helper.get_pixel_aspect_ratio(object_), + l1_norm_constraint=helper.l1_norm_constraint, + l2_norm_constraint=helper.l2_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + remove_object_probe_ambiguity=helper.remove_object_probe_ambiguity, + build_preconditioner_with_all_modes=helper.build_preconditioner_with_all_modes, + determine_position_origin_coords_by=helper.determine_position_origin_coords_by, + position_origin_coords=helper.get_position_origin_coords(object_), + alpha=self._settings.object_alpha.get_value(), + ) + + def _create_probe_options( + self, probes: ProbeSequence, metadata: ProductMetadata + ) -> PIEProbeOptions: + helper = self._options_helper.probe_helper + return PIEProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probes), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + alpha=self._settings.probe_alpha.get_value(), + ) + + def _create_probe_position_options( + self, scan: PositionSequence, object_geometry: ObjectGeometry + ) -> PIEProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return PIEProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + affine_transform_constraint=helper.affine_transform_constraint, + ) + + def _create_opr_mode_weight_options(self, probes: ProbeSequence) -> PIEOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return PIEOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(probes), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> RPIEOptions: + product = parameters.product + return RPIEOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probes, product.metadata), + probe_position_options=self._create_probe_position_options( + product.positions, product.object_.get_geometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(product.probes), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/settings.py b/src/ptychodus/model/ptychi/settings.py new file mode 100644 index 00000000..c97c5bee --- /dev/null +++ b/src/ptychodus/model/ptychi/settings.py @@ -0,0 +1,548 @@ +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.settings import SettingsRegistry + + +class PtyChiSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChi') + self._group.add_observer(self) + + # ReconstructorOptions + self.num_epochs = self._group.create_integer_parameter('NumEpochs', 100, minimum=1) + self.batch_size = self._group.create_integer_parameter('BatchSize', 100, minimum=1) + self.batching_mode = self._group.create_string_parameter('BatchingMode', 'random') + self.compact_mode_update_clustering = self._group.create_integer_parameter( + 'CompactModeUpdateClustering', 1, minimum=0 + ) + self.use_devices = self._group.create_boolean_parameter('UseDevices', True) + self.use_double_precision = self._group.create_boolean_parameter( + 'UseDoublePrecision', False + ) + self.use_double_precision_for_fft = self._group.create_boolean_parameter( + 'UseDoublePrecisionForFFT', False + ) + self.allow_nondeterministic_algorithms = self._group.create_boolean_parameter( + 'AllowNondeterministicAlgorithms', True + ) + + # ForwardModelOptions + self.use_low_memory_mode = self._group.create_boolean_parameter('UseLowMemoryMode', False) + self.pad_for_shift = self._group.create_integer_parameter('PadForShift', 0, minimum=0) + + # PtychographyDataOptions + self.use_far_field_propagation = self._group.create_boolean_parameter( + 'UseFarFieldPropagation', True + ) + self.fft_shift_diffraction_patterns = self._group.create_boolean_parameter( + 'FFTShiftDiffractionPatterns', True + ) + self.save_data_on_device = self._group.create_boolean_parameter('SaveDataOnDevice', False) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiObjectSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiObject') + self._group.add_observer(self) + + self.is_optimizable = self._group.create_boolean_parameter('IsOptimizable', True) + self.optimization_plan_start = self._group.create_integer_parameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimization_plan_stop = self._group.create_integer_parameter( + 'OptimizationPlanStop', -1 + ) + self.optimization_plan_stride = self._group.create_integer_parameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._group.create_string_parameter('Optimizer', 'SGD') + self.step_size = self._group.create_real_parameter('StepSize', 1.0, minimum=0.0) + + self.optimize_slice_spacing = self._group.create_boolean_parameter( + 'OptimizeSliceSpacing', False + ) + self.optimize_slice_spacing_start = self._group.create_integer_parameter( + 'OptimizeSliceSpacingStart', 0, minimum=0 + ) + self.optimize_slice_spacing_stop = self._group.create_integer_parameter( + 'OptimizeSliceSpacingStop', -1 + ) + self.optimize_slice_spacing_stride = self._group.create_integer_parameter( + 'OptimizeSliceSpacingStride', 1, minimum=1 + ) + self.optimize_slice_spacing_optimizer = self._group.create_string_parameter( + 'OptimizeSliceSpacingOptimizer', 'SGD' + ) + self.optimize_slice_spacing_step_size = self._group.create_real_parameter( + 'OptimizeSliceSpacingStepSize', 1.0e-10, minimum=0.0 + ) + + self.constrain_l1_norm = self._group.create_boolean_parameter('ConstrainL1Norm', False) + self.constrain_l1_norm_start = self._group.create_integer_parameter( + 'ConstrainL1NormStart', 0, minimum=0 + ) + self.constrain_l1_norm_stop = self._group.create_integer_parameter( + 'ConstrainL1NormStop', -1 + ) + self.constrain_l1_norm_stride = self._group.create_integer_parameter( + 'ConstrainL1NormStride', 1, minimum=1 + ) + self.constrain_l1_norm_weight = self._group.create_real_parameter( + 'ConstrainL1NormWeight', 0.0, minimum=0.0 + ) + + self.constrain_l2_norm = self._group.create_boolean_parameter('ConstrainL2Norm', False) + self.constrain_l2_norm_start = self._group.create_integer_parameter( + 'ConstrainL2NormStart', 0, minimum=0 + ) + self.constrain_l2_norm_stop = self._group.create_integer_parameter( + 'ConstrainL2NormStop', -1 + ) + self.constrain_l2_norm_stride = self._group.create_integer_parameter( + 'ConstrainL2NormStride', 1, minimum=1 + ) + self.constrain_l2_norm_weight = self._group.create_real_parameter( + 'ConstrainL2NormWeight', 0.0, minimum=0.0 + ) + + self.constrain_smoothness = self._group.create_boolean_parameter( + 'ConstrainSmoothness', False + ) + self.constrain_smoothness_start = self._group.create_integer_parameter( + 'ConstrainSmoothnessStart', 0, minimum=0 + ) + self.constrain_smoothness_stop = self._group.create_integer_parameter( + 'ConstrainSmoothnessStop', -1 + ) + self.constrain_smoothness_stride = self._group.create_integer_parameter( + 'ConstrainSmoothnessStride', 1, minimum=1 + ) + self.constrain_smoothness_alpha = self._group.create_real_parameter( + 'ConstrainSmoothnessAlpha', 0.0, minimum=0.0, maximum=1.0 / 8 + ) + + self.constrain_total_variation = self._group.create_boolean_parameter( + 'ConstrainTotalVariation', False + ) + self.constrain_total_variation_start = self._group.create_integer_parameter( + 'ConstrainTotalVariationStart', 0, minimum=0 + ) + self.constrain_total_variation_stop = self._group.create_integer_parameter( + 'ConstrainTotalVariationStop', -1 + ) + self.constrain_total_variation_stride = self._group.create_integer_parameter( + 'ConstrainTotalVariationStride', 1, minimum=1 + ) + self.constrain_total_variation_weight = self._group.create_real_parameter( + 'ConstrainTotalVariationWeight', 0.0, minimum=0.0 + ) + + self.remove_grid_artifacts = self._group.create_boolean_parameter( + 'RemoveGridArtifacts', False + ) + self.remove_grid_artifacts_start = self._group.create_integer_parameter( + 'RemoveGridArtifactsStart', 0, minimum=0 + ) + self.remove_grid_artifacts_stop = self._group.create_integer_parameter( + 'RemoveGridArtifactsStop', -1 + ) + self.remove_grid_artifacts_stride = self._group.create_integer_parameter( + 'RemoveGridArtifactsStride', 1, minimum=1 + ) + self.remove_grid_artifacts_period_x_m = self._group.create_real_parameter( + 'RemoveGridArtifactsPeriodXInMeters', 1e-7, minimum=0.0 + ) + self.remove_grid_artifacts_period_y_m = self._group.create_real_parameter( + 'RemoveGridArtifactsPeriodYInMeters', 1e-7, minimum=0.0 + ) + self.remove_grid_artifacts_window_size_px = self._group.create_integer_parameter( + 'RemoveGridArtifactsWindowSizeInPixels', + 5, + minimum=1, + ) + self.remove_grid_artifacts_direction = self._group.create_string_parameter( + 'RemoveGridArtifactsDirection', 'XY' + ) + + self.regularize_multislice = self._group.create_boolean_parameter( + 'RegularizeMultislice', False + ) + self.regularize_multislice_start = self._group.create_integer_parameter( + 'RegularizeMultisliceStart', 0, minimum=0 + ) + self.regularize_multislice_stop = self._group.create_integer_parameter( + 'RegularizeMultisliceStop', -1 + ) + self.regularize_multislice_stride = self._group.create_integer_parameter( + 'RegularizeMultisliceStride', 1, minimum=1 + ) + self.regularize_multislice_weight = self._group.create_real_parameter( + 'RegularizeMultisliceWeight', 0.0, minimum=0.0 + ) + self.regularize_multislice_unwrap_phase = self._group.create_boolean_parameter( + 'RegularizeMultisliceUnwrapPhase', True + ) + self.regularize_multislice_unwrap_phase_image_gradient_method = ( + self._group.create_string_parameter( + 'RegularizeMultisliceUnwrapPhaseImageGradientMethod', 'FOURIER_DIFFERENTIATION' + ) + ) + self.regularize_multislice_unwrap_phase_image_integration_method = ( + self._group.create_string_parameter( + 'RegularizeMultisliceUnwrapPhaseImageIntegrationMethod', 'FOURIER' + ) + ) + + self.patch_interpolator = self._group.create_string_parameter( + 'PatchInterpolator', 'FOURIER' + ) + + self.remove_object_probe_ambiguity = self._group.create_boolean_parameter( + 'RemoveObjectProbeAmbiguity', True + ) + self.remove_object_probe_ambiguity_start = self._group.create_integer_parameter( + 'RemoveObjectProbeAmbiguityStart', 0, minimum=0 + ) + self.remove_object_probe_ambiguity_stop = self._group.create_integer_parameter( + 'RemoveObjectProbeAmbiguityStop', -1 + ) + self.remove_object_probe_ambiguity_stride = self._group.create_integer_parameter( + 'RemoveObjectProbeAmbiguityStride', 10, minimum=1 + ) + + self.build_preconditioner_with_all_modes = self._group.create_boolean_parameter( + 'BuildPreconditionerWithAllModes', False + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiProbeSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiProbe') + self._group.add_observer(self) + + self.is_optimizable = self._group.create_boolean_parameter('IsOptimizable', True) + self.optimization_plan_start = self._group.create_integer_parameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimization_plan_stop = self._group.create_integer_parameter( + 'OptimizationPlanStop', -1 + ) + self.optimization_plan_stride = self._group.create_integer_parameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._group.create_string_parameter('Optimizer', 'SGD') + self.step_size = self._group.create_real_parameter('StepSize', 1.0, minimum=0.0) + + self.constrain_probe_power = self._group.create_boolean_parameter( + 'ConstrainProbePower', False + ) + self.constrain_probe_power_start = self._group.create_integer_parameter( + 'ConstrainProbePowerStart', 0, minimum=0 + ) + self.constrain_probe_power_stop = self._group.create_integer_parameter( + 'ConstrainProbePowerStop', -1 + ) + self.constrain_probe_power_stride = self._group.create_integer_parameter( + 'ConstrainProbePowerStride', 1, minimum=1 + ) + + self.orthogonalize_incoherent_modes = self._group.create_boolean_parameter( + 'OrthogonalizeIncoherentModes', True + ) + self.orthogonalize_incoherent_modes_start = self._group.create_integer_parameter( + 'OrthogonalizeIncoherentModesStart', 0, minimum=0 + ) + self.orthogonalize_incoherent_modes_stop = self._group.create_integer_parameter( + 'OrthogonalizeIncoherentModesStop', -1 + ) + self.orthogonalize_incoherent_modes_stride = self._group.create_integer_parameter( + 'OrthogonalizeIncoherentModesStride', 1, minimum=1 + ) + self.orthogonalize_incoherent_modes_method = self._group.create_string_parameter( + 'OrthogonalizeIncoherentModesMethod', 'SVD' + ) + + self.orthogonalize_opr_modes = self._group.create_boolean_parameter( + 'OrthogonalizeOPRModes', True + ) + self.orthogonalize_opr_modes_start = self._group.create_integer_parameter( + 'OrthogonalizeOPRModesStart', 0, minimum=0 + ) + self.orthogonalize_opr_modes_stop = self._group.create_integer_parameter( + 'OrthogonalizeOPRModesStop', -1 + ) + self.orthogonalize_opr_modes_stride = self._group.create_integer_parameter( + 'OrthogonalizeOPRModesStride', 1, minimum=1 + ) + + self.constrain_support = self._group.create_boolean_parameter('ConstrainSupport', False) + self.constrain_support_start = self._group.create_integer_parameter( + 'ConstrainSupportStart', 0, minimum=0 + ) + self.constrain_support_stop = self._group.create_integer_parameter( + 'ConstrainSupportStop', -1 + ) + self.constrain_support_stride = self._group.create_integer_parameter( + 'ConstrainSupportStride', 1, minimum=1 + ) + self.constrain_support_threshold = self._group.create_real_parameter( + 'ConstrainSupportThreshold', 0.005, minimum=0.0 + ) + + self.constrain_center = self._group.create_boolean_parameter('ConstrainCenter', False) + self.constrain_center_start = self._group.create_integer_parameter( + 'ConstrainCenterStart', 0, minimum=0 + ) + self.constrain_center_stop = self._group.create_integer_parameter('ConstrainCenterStop', -1) + self.constrain_center_stride = self._group.create_integer_parameter( + 'ConstrainCenterStride', 1, minimum=1 + ) + + self.relax_eigenmode_update = self._group.create_real_parameter( + 'RelaxEigenmodeUpdate', 1.0, minimum=0.0, maximum=1.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiProbePositionSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiProbePosition') + self._group.add_observer(self) + + self.is_optimizable = self._group.create_boolean_parameter('IsOptimizable', False) + self.optimization_plan_start = self._group.create_integer_parameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimization_plan_stop = self._group.create_integer_parameter( + 'OptimizationPlanStop', -1 + ) + self.optimization_plan_stride = self._group.create_integer_parameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._group.create_string_parameter('Optimizer', 'SGD') + self.step_size = self._group.create_real_parameter('StepSize', 1.0, minimum=0.0) + + self.constrain_centroid = self._group.create_boolean_parameter('ConstrainCentroid', False) + + self.correction_type = self._group.create_string_parameter('CorrectionType', 'GRADIENT') + self.differentiation_method = self._group.create_string_parameter( + 'DifferentiationMethod', 'GAUSSIAN' + ) + self.cross_correlation_scale = self._group.create_integer_parameter( + 'CrossCorrelationScale', 20000, minimum=1 + ) + self.cross_correlation_real_space_width = self._group.create_real_parameter( + 'CrossCorrelationRealSpaceWidth', 0.01, minimum=0.0 + ) + self.cross_correlation_probe_threshold = self._group.create_real_parameter( + 'CrossCorrelationProbeThreshold', 0.1, minimum=0.0, maximum=1.0 + ) + + self.limit_update_magnitude = self._group.create_boolean_parameter( + 'LimitUpdateMagnitude', False + ) + self.update_magnitude_limit = self._group.create_real_parameter( + 'UpdateMagnitudeLimit', 1.0, minimum=0.0 + ) + + self.constrain_affine_transform = self._group.create_boolean_parameter( + 'ConstrainAffineTransform', False + ) + self.constrain_affine_transform_start = self._group.create_integer_parameter( + 'ConstrainAffineTransformStart', 0, minimum=0 + ) + self.constrain_affine_transform_stop = self._group.create_integer_parameter( + 'ConstrainAffineTransformStop', -1 + ) + self.constrain_affine_transform_stride = self._group.create_integer_parameter( + 'ConstrainAffineTransformStride', 1, minimum=1 + ) + self.constrain_affine_transform_degrees_of_freedom = self._group.create_integer_parameter( + 'ConstrainAffineTransformDegreesOfFreedom', 0, minimum=0 + ) + self.constrain_affine_transform_position_weight_update_interval = ( + self._group.create_integer_parameter( + 'ConstrainAffineTransformPositionWeightUpdateInterval', 10, minimum=1 + ) + ) + self.constrain_affine_transform_apply_constraint = self._group.create_boolean_parameter( + 'ConstrainAffineTransformApplyConstraint', True + ) + self.constrain_affine_transform_max_expected_error_px = self._group.create_real_parameter( + 'ConstrainAffineTransformMaxExpectedErrorInPixels', 1.0, minimum=0.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiOPRSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiOPR') + self._group.add_observer(self) + + self.is_optimizable = self._group.create_boolean_parameter('IsOptimizable', False) + self.optimization_plan_start = self._group.create_integer_parameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimization_plan_stop = self._group.create_integer_parameter( + 'OptimizationPlanStop', -1 + ) + self.optimization_plan_stride = self._group.create_integer_parameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._group.create_string_parameter('Optimizer', 'SGD') + self.step_size = self._group.create_real_parameter('StepSize', 1.0, minimum=0.0) + + self.optimize_eigenmode_weights = self._group.create_boolean_parameter( + 'OptimizeEigenmodeWeigts', True + ) + self.optimize_intensities = self._group.create_boolean_parameter( + 'OptimizeIntensities', False + ) + + self.smooth_mode_weights = self._group.create_boolean_parameter('SmoothModeWeights', False) + self.smooth_mode_weights_start = self._group.create_integer_parameter( + 'SmoothModeWeightsStart', 0, minimum=0 + ) + self.smooth_mode_weights_stop = self._group.create_integer_parameter( + 'SmoothModeWeightsStop', -1 + ) + self.smooth_mode_weights_stride = self._group.create_integer_parameter( + 'SmoothModeWeightsStride', 1, minimum=1 + ) + self.smoothing_method = self._group.create_string_parameter('SmoothingMethod', '') + self.polynomial_smoothing_degree = self._group.create_integer_parameter( + 'PolynomialSmoothingDegree', 4, minimum=0, maximum=10 + ) + + self.relax_update = self._group.create_real_parameter( + 'RelaxUpdate', 1.0, minimum=0.0, maximum=1.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiAutodiffSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiAutodiff') + self._group.add_observer(self) + + self.loss_function = self._group.create_string_parameter('LossFunction', 'MSE_SQRT') + self.forward_model_class = self._group.create_string_parameter( + 'ForwardModelClass', 'PLANAR_PTYCHOGRAPHY' + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiDMSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiDM') + self._group.add_observer(self) + + self.exit_wave_update_relaxation = self._group.create_real_parameter( + 'ExitWaveUpdateRelaxation', 1.0, minimum=0.0, maximum=1.0 + ) + self.chunk_length = self._group.create_integer_parameter('ChunkLength', 1, minimum=1) + self.object_amplitude_clamp_limit = self._group.create_real_parameter( + 'ObjectAmplitudeClampLimit', 1000, minimum=0.0 + ) + self.object_inertia = self._group.create_real_parameter( + 'ObjectInertia', 0.0, minimum=0.0, maximum=1.0 + ) + self.probe_inertia = self._group.create_real_parameter( + 'ProbeInertia', 0.0, minimum=0.0, maximum=1.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiLSQMLSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiLSQML') + self._group.add_observer(self) + + self.noise_model = self._group.create_string_parameter('NoiseModel', 'GAUSSIAN') + self.gaussian_noise_deviation = self._group.create_real_parameter( + 'GaussianNoiseDeviation', 0.5 + ) + self.solve_object_probe_step_size_jointly_for_first_slice_in_multislice = ( + self._group.create_boolean_parameter( + 'SolveObjectProbeStepSizeJointlyForFirstSliceInMultislice', False + ) + ) + self.solve_step_sizes_only_using_first_probe_mode = self._group.create_boolean_parameter( + 'SolveStepSizesOnlyUsingFirstProbeMode', True + ) + self.momentum_acceleration_gain = self._group.create_real_parameter( + 'MomentumAccelerationGain', 0.0, minimum=0.0 + ) + self.use_momentum_acceleration_gradient_mixing_factor = ( + self._group.create_boolean_parameter( + 'UseMomentumAccelerationGradientMixingFactor', False + ) + ) + self.momentum_acceleration_gradient_mixing_factor = self._group.create_real_parameter( + 'MomentumAccelerationGradientMixingFactor', 1.0 + ) + self.rescale_probe_intensity_in_first_epoch = self._group.create_boolean_parameter( + 'RescaleProbeIntensityInFirstEpoch', True + ) + + self.object_optimal_step_size_scaler = self._group.create_real_parameter( + 'ObjectOptimalStepSizeScaler', 0.9, minimum=0.0 + ) + self.object_multimodal_update = self._group.create_boolean_parameter( + 'ObjectMultimodalUpdate', True + ) + self.probe_optimal_step_size_scaler = self._group.create_real_parameter( + 'ProbeOptimalStepSizeScaler', 0.9, minimum=0.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtyChiPIESettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtyChiPIE') + self._group.add_observer(self) + + self.probe_alpha = self._group.create_real_parameter( + 'ProbeAlpha', 0.1, minimum=0.0, maximum=1.0 + ) + self.object_alpha = self._group.create_real_parameter( + 'ObjectAlpha', 0.1, minimum=0.0, maximum=1.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/ptychonn/__init__.py b/src/ptychodus/model/ptychonn/__init__.py index 72abe188..b4f67d39 100644 --- a/src/ptychodus/model/ptychonn/__init__.py +++ b/src/ptychodus/model/ptychonn/__init__.py @@ -1,11 +1,11 @@ from .core import ( - PtychoNNModelPresenter, + PtychoNNModelSettings, PtychoNNReconstructorLibrary, - PtychoNNTrainingPresenter, + PtychoNNTrainingSettings, ) __all__ = [ - 'PtychoNNModelPresenter', + 'PtychoNNModelSettings', 'PtychoNNReconstructorLibrary', - 'PtychoNNTrainingPresenter', + 'PtychoNNTrainingSettings', ] diff --git a/src/ptychodus/model/ptychonn/buffers.py b/src/ptychodus/model/ptychonn/buffers.py index 60e71d8c..4b30d00f 100644 --- a/src/ptychodus/model/ptychonn/buffers.py +++ b/src/ptychodus/model/ptychonn/buffers.py @@ -1,31 +1,33 @@ from __future__ import annotations +from typing import Any, TypeAlias import logging import numpy import numpy.typing from ptychodus.api.geometry import ImageExtent -from ptychodus.api.object import ObjectArrayType -from ptychodus.api.typing import Float32ArrayType +from ptychodus.api.typing import ComplexArrayType + +Float32ArrayType: TypeAlias = numpy.typing.NDArray[numpy.float32] logger = logging.getLogger(__name__) class PatternCircularBuffer: - def __init__(self, extent: ImageExtent, maxSize: int) -> None: + def __init__(self, extent: ImageExtent, max_size: int) -> None: self._buffer: Float32ArrayType = numpy.zeros( - (maxSize, *extent.shape), + (max_size, *extent.shape), dtype=numpy.float32, ) self._pos = 0 self._full = False @classmethod - def createZeroSized(cls) -> PatternCircularBuffer: + def create_zero_sized(cls) -> PatternCircularBuffer: return cls(ImageExtent(0, 0), 0) @property - def isZeroSized(self) -> bool: + def is_zero_sized(self) -> bool: return self._buffer.size == 0 def append(self, array: Float32ArrayType) -> None: @@ -36,33 +38,33 @@ def append(self, array: Float32ArrayType) -> None: self._pos = 0 self._full = True - def getBuffer(self) -> Float32ArrayType: + def get_buffer(self) -> Float32ArrayType: return self._buffer if self._full else self._buffer[: self._pos] - def setBuffer(self, array: Float32ArrayType) -> None: + def set_buffer(self, array: Float32ArrayType) -> None: self._buffer = array self._pos = 0 self._full = True class ObjectPatchCircularBuffer: - def __init__(self, extent: ImageExtent, channels: int, maxSize: int) -> None: + def __init__(self, extent: ImageExtent, channels: int, max_size: int) -> None: self._buffer: Float32ArrayType = numpy.zeros( - (maxSize, channels, *extent.shape), + (max_size, channels, *extent.shape), dtype=numpy.float32, ) self._pos = 0 self._full = False @classmethod - def createZeroSized(cls) -> ObjectPatchCircularBuffer: + def create_zero_sized(cls) -> ObjectPatchCircularBuffer: return cls(ImageExtent(0, 0), 0, 0) @property - def isZeroSized(self) -> bool: + def is_zero_sized(self) -> bool: return self._buffer.size == 0 - def append(self, array: ObjectArrayType) -> None: + def append(self, array: ComplexArrayType) -> None: self._buffer[self._pos, 0, :, :] = numpy.angle(array).astype(numpy.float32) if self._buffer.shape[1] > 1: @@ -74,10 +76,10 @@ def append(self, array: ObjectArrayType) -> None: self._pos = 0 self._full = True - def getBuffer(self) -> Float32ArrayType: + def get_buffer(self) -> Float32ArrayType: return self._buffer if self._full else self._buffer[: self._pos] - def setBuffer(self, array: Float32ArrayType) -> None: + def set_buffer(self, array: Float32ArrayType) -> None: self._buffer = array self._pos = 0 self._full = True diff --git a/src/ptychodus/model/ptychonn/common.py b/src/ptychodus/model/ptychonn/common.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/ptychodus/model/ptychonn/core.py b/src/ptychodus/model/ptychonn/core.py index b5f343d4..36088f2b 100644 --- a/src/ptychodus/model/ptychonn/core.py +++ b/src/ptychodus/model/ptychonn/core.py @@ -19,134 +19,26 @@ logger = logging.getLogger(__name__) -class PtychoNNModelPresenter(Observable, Observer): - MAX_INT: Final[int] = 0x7FFFFFFF - - def __init__(self, settings: PtychoNNModelSettings) -> None: - super().__init__() - self._settings = settings - - settings.addObserver(self) - - def getNumberOfConvolutionKernelsLimits(self) -> Interval[int]: - return Interval[int](1, self.MAX_INT) - - def getNumberOfConvolutionKernels(self) -> int: - limits = self.getNumberOfConvolutionKernelsLimits() - return limits.clamp(self._settings.numberOfConvolutionKernels.getValue()) - - def setNumberOfConvolutionKernels(self, value: int) -> None: - self._settings.numberOfConvolutionKernels.setValue(value) - - def getBatchSizeLimits(self) -> Interval[int]: - return Interval[int](1, self.MAX_INT) - - def getBatchSize(self) -> int: - limits = self.getBatchSizeLimits() - return limits.clamp(self._settings.batchSize.getValue()) - - def setBatchSize(self, value: int) -> None: - self._settings.batchSize.setValue(value) - - def isBatchNormalizationEnabled(self) -> bool: - return self._settings.useBatchNormalization.getValue() - - def setBatchNormalizationEnabled(self, enabled: bool) -> None: - self._settings.useBatchNormalization.setValue(enabled) - - def update(self, observable: Observable) -> None: - if observable is self._settings: - self.notifyObservers() - - -class PtychoNNTrainingPresenter(Observable, Observer): - MAX_INT: Final[int] = 0x7FFFFFFF - - def __init__(self, settings: PtychoNNTrainingSettings) -> None: - super().__init__() - self._settings = settings - - settings.addObserver(self) - - def getValidationSetFractionalSizeLimits(self) -> Interval[Decimal]: - return Interval[Decimal](Decimal(0), Decimal(1)) - - def getValidationSetFractionalSize(self) -> Decimal: - limits = self.getValidationSetFractionalSizeLimits() - return limits.clamp( - Decimal.from_float(self._settings.validationSetFractionalSize.getValue()) - ) - - def setValidationSetFractionalSize(self, value: Decimal) -> None: - self._settings.validationSetFractionalSize.setValue(float(value)) - - def getMaximumLearningRateLimits(self) -> Interval[Decimal]: - return Interval[Decimal](Decimal(0), Decimal(1)) - - def getMaximumLearningRate(self) -> Decimal: - limits = self.getMaximumLearningRateLimits() - return limits.clamp(Decimal.from_float(self._settings.maximumLearningRate.getValue())) - - def setMaximumLearningRate(self, value: Decimal) -> None: - self._settings.maximumLearningRate.setValue(float(value)) - - def getMinimumLearningRateLimits(self) -> Interval[Decimal]: - return Interval[Decimal](Decimal(0), Decimal(1)) - - def getMinimumLearningRate(self) -> Decimal: - limits = self.getMinimumLearningRateLimits() - return limits.clamp(Decimal.from_float(self._settings.minimumLearningRate.getValue())) - - def setMinimumLearningRate(self, value: Decimal) -> None: - self._settings.minimumLearningRate.setValue(float(value)) - - def getTrainingEpochsLimits(self) -> Interval[int]: - return Interval[int](1, self.MAX_INT) - - def getTrainingEpochs(self) -> int: - limits = self.getTrainingEpochsLimits() - return limits.clamp(self._settings.trainingEpochs.getValue()) - - def setTrainingEpochs(self, value: int) -> None: - self._settings.trainingEpochs.setValue(value) - - def getStatusIntervalInEpochsLimits(self) -> Interval[int]: - return Interval[int](1, self.MAX_INT) - - def getStatusIntervalInEpochs(self) -> int: - limits = self.getStatusIntervalInEpochsLimits() - return limits.clamp(self._settings.statusIntervalInEpochs.getValue()) - - def setStatusIntervalInEpochs(self, value: int) -> None: - self._settings.statusIntervalInEpochs.setValue(value) - - def update(self, observable: Observable) -> None: - if observable is self._settings: - self.notifyObservers() - - class PtychoNNReconstructorLibrary(ReconstructorLibrary): def __init__( self, - modelSettings: PtychoNNModelSettings, - trainingSettings: PtychoNNTrainingSettings, + model_settings: PtychoNNModelSettings, + training_settings: PtychoNNTrainingSettings, reconstructors: Sequence[Reconstructor], ) -> None: super().__init__() - self._modelSettings = modelSettings - self._trainingSettings = trainingSettings - self.modelPresenter = PtychoNNModelPresenter(modelSettings) - self.trainingPresenter = PtychoNNTrainingPresenter(trainingSettings) + self.model_settings = model_settings + self.training_settings = training_settings self._reconstructors = reconstructors @classmethod - def createInstance( - cls, settingsRegistry: SettingsRegistry, isDeveloperModeEnabled: bool + def create_instance( + cls, settings_registry: SettingsRegistry, is_developer_mode_enabled: bool ) -> PtychoNNReconstructorLibrary: - modelSettings = PtychoNNModelSettings(settingsRegistry) - trainingSettings = PtychoNNTrainingSettings(settingsRegistry) - phaseOnlyReconstructor: TrainableReconstructor = NullReconstructor('PhaseOnly') - amplitudePhaseReconstructor: TrainableReconstructor = NullReconstructor('AmplitudePhase') + model_settings = PtychoNNModelSettings(settings_registry) + training_settings = PtychoNNTrainingSettings(settings_registry) + phase_only_reconstructor: TrainableReconstructor = NullReconstructor('PhaseOnly') + amplitude_phase_reconstructor: TrainableReconstructor = NullReconstructor('AmplitudePhase') reconstructors: list[TrainableReconstructor] = list() try: @@ -155,30 +47,34 @@ def createInstance( except ModuleNotFoundError: logger.info('PtychoNN not found.') - if isDeveloperModeEnabled: - reconstructors.append(phaseOnlyReconstructor) - reconstructors.append(amplitudePhaseReconstructor) + if is_developer_mode_enabled: + reconstructors.append(phase_only_reconstructor) + reconstructors.append(amplitude_phase_reconstructor) else: - phaseOnlyModelProvider = PtychoNNModelProvider( - modelSettings, trainingSettings, enableAmplitude=False + phase_only_model_provider = PtychoNNModelProvider( + model_settings, training_settings, enable_amplitude=False ) - phaseOnlyReconstructor = PtychoNNTrainableReconstructor( - modelSettings, trainingSettings, phaseOnlyModelProvider + phase_only_reconstructor = PtychoNNTrainableReconstructor( + model_settings, training_settings, phase_only_model_provider ) - amplitudePhaseModelProvider = PtychoNNModelProvider( - modelSettings, trainingSettings, enableAmplitude=True + amplitude_phase_model_provider = PtychoNNModelProvider( + model_settings, training_settings, enable_amplitude=True ) - amplitudePhaseReconstructor = PtychoNNTrainableReconstructor( - modelSettings, trainingSettings, amplitudePhaseModelProvider + amplitude_phase_reconstructor = PtychoNNTrainableReconstructor( + model_settings, training_settings, amplitude_phase_model_provider ) - reconstructors.append(phaseOnlyReconstructor) - reconstructors.append(amplitudePhaseReconstructor) + reconstructors.append(phase_only_reconstructor) + reconstructors.append(amplitude_phase_reconstructor) - return cls(modelSettings, trainingSettings, reconstructors) + return cls(model_settings, training_settings, reconstructors) @property def name(self) -> str: return 'PtychoNN' + @property + def logger_name(self) -> str: + return 'ptychonn' + def __iter__(self) -> Iterator[Reconstructor]: return iter(self._reconstructors) diff --git a/src/ptychodus/model/ptychonn/model.py b/src/ptychodus/model/ptychonn/model.py index 0ea7c935..6ca28356 100644 --- a/src/ptychodus/model/ptychonn/model.py +++ b/src/ptychodus/model/ptychonn/model.py @@ -12,24 +12,24 @@ class PtychoNNModelProvider: def __init__( self, - modelSettings: PtychoNNModelSettings, - trainingSettings: PtychoNNTrainingSettings, + model_settings: PtychoNNModelSettings, + training_settings: PtychoNNTrainingSettings, *, - enableAmplitude: bool, + enable_amplitude: bool, ) -> None: - self._modelSettings = modelSettings - self._trainingSettings = trainingSettings - self._enableAmplitude = enableAmplitude + self._model_settings = model_settings + self._training_settings = training_settings + self._enable_amplitude = enable_amplitude self._model: ptychonn.LitReconSmallModel | None = None self._trainer: lightning.Trainer | None = None - def getModelName(self) -> str: - return 'AmplitudePhase' if self._enableAmplitude else 'PhaseOnly' + def get_model_name(self) -> str: + return 'AmplitudePhase' if self._enable_amplitude else 'PhaseOnly' - def getNumberOfChannels(self) -> int: - return 2 if self._enableAmplitude else 1 + def get_num_channels(self) -> int: + return 2 if self._enable_amplitude else 1 - def getModel(self) -> ptychonn.LitReconSmallModel: + def get_model(self) -> ptychonn.LitReconSmallModel: if ( self._model is None and self._trainer is not None @@ -39,27 +39,27 @@ def getModel(self) -> ptychonn.LitReconSmallModel: else: logger.debug('Initializing model from settings') self._model = ptychonn.LitReconSmallModel( - nconv=self._modelSettings.numberOfConvolutionKernels.getValue(), - use_batch_norm=self._modelSettings.useBatchNormalization.getValue(), - enable_amplitude=self._enableAmplitude, - max_lr=float(self._trainingSettings.maximumLearningRate.getValue()), - min_lr=float(self._trainingSettings.minimumLearningRate.getValue()), + nconv=self._model_settings.num_convolution_kernels.get_value(), + use_batch_norm=self._model_settings.use_batch_normalization.get_value(), + enable_amplitude=self._enable_amplitude, + max_lr=float(self._training_settings.max_learning_rate.get_value()), + min_lr=float(self._training_settings.min_learning_rate.get_value()), ) return self._model - def openModel(self, filePath: Path) -> None: - logger.debug(f'Reading model from "{filePath}"') - self._model = ptychonn.LitReconSmallModel.load_from_checkpoint(filePath) + def open_model(self, file_path: Path) -> None: + logger.debug(f'Reading model from "{file_path}"') + self._model = ptychonn.LitReconSmallModel.load_from_checkpoint(file_path) self._trainer = None - def setTrainer(self, trainer: lightning.Trainer) -> None: + def set_trainer(self, trainer: lightning.Trainer) -> None: self._model = None self._trainer = trainer - def saveModel(self, filePath: Path) -> None: + def save_model(self, file_path: Path) -> None: if self._trainer is None: logger.warning('Need trainer to save model!') else: - logger.debug(f'Writing model to "{filePath}"') - self._trainer.save_checkpoint(filePath) + logger.debug(f'Writing model to "{file_path}"') + self._trainer.save_checkpoint(file_path) diff --git a/src/ptychodus/model/ptychonn/reconstructor.py b/src/ptychodus/model/ptychonn/reconstructor.py index 9903fcb2..9e59d2bd 100644 --- a/src/ptychodus/model/ptychonn/reconstructor.py +++ b/src/ptychodus/model/ptychonn/reconstructor.py @@ -1,4 +1,3 @@ -from collections.abc import Sequence from importlib.metadata import version from pathlib import Path from typing import Final @@ -9,22 +8,39 @@ import ptychonn from ptychodus.api.geometry import ImageExtent +from ptychodus.api.object import Object from ptychodus.api.product import Product from ptychodus.api.reconstructor import ( + LossValue, ReconstructInput, ReconstructOutput, - TrainableReconstructor, TrainOutput, + TrainableReconstructor, ) +from ptychodus.api.typing import ComplexArrayType -from ..analysis import ObjectLinearInterpolator, ObjectStitcher -from .buffers import ObjectPatchCircularBuffer, PatternCircularBuffer +from ..analysis import BarycentricArrayInterpolator, BarycentricArrayStitcher from .model import PtychoNNModelProvider from .settings import PtychoNNModelSettings, PtychoNNTrainingSettings logger = logging.getLogger(__name__) +class CenterBoxMeanPhaseCenteringStrategy: # TODO USE + def __call__(self, array: ComplexArrayType) -> ComplexArrayType: + one_third_height = array.shape[-2] // 3 + one_third_width = array.shape[-1] // 3 + + amplitude = numpy.absolute(array) + phase = numpy.angle(array) + + center_box_mean_phase = phase[ + one_third_height : one_third_height * 2, one_third_width : one_third_width * 2 + ].mean() + + return amplitude * numpy.exp(1j * (phase - center_box_mean_phase)) + + class PtychoNNTrainableReconstructor(TrainableReconstructor): MODEL_FILE_FILTER: Final[str] = 'PyTorch Lightning Checkpoint Files (*.ckpt)' TRAINING_DATA_FILE_FILTER: Final[str] = 'NumPy Zipped Archive (*.npz)' @@ -33,191 +49,160 @@ class PtychoNNTrainableReconstructor(TrainableReconstructor): def __init__( self, - modelSettings: PtychoNNModelSettings, - trainingSettings: PtychoNNTrainingSettings, - modelProvider: PtychoNNModelProvider, + model_settings: PtychoNNModelSettings, + training_settings: PtychoNNTrainingSettings, + model_provider: PtychoNNModelProvider, ) -> None: - self._modelSettings = modelSettings - self._trainingSettings = trainingSettings - self._modelProvider = modelProvider - self._patternBuffer = PatternCircularBuffer.createZeroSized() - self._objectPatchBuffer = ObjectPatchCircularBuffer.createZeroSized() + self._model_settings = model_settings + self._training_settings = training_settings + self._model_provider = model_provider - ptychonnVersion = version('ptychonn') - logger.info(f'\tPtychoNN {ptychonnVersion}') + ptychonn_version = version('ptychonn') + logger.info(f'\tPtychoNN {ptychonn_version}') @property def name(self) -> str: - return self._modelProvider.getModelName() + return self._model_provider.get_model_name() def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: # TODO data size/shape requirements to GUI data = parameters.patterns - dataSize = data.shape[-1] + data_size = data.shape[-1] - if dataSize != data.shape[-2]: + if data_size != data.shape[-2]: raise ValueError('PtychoNN expects square diffraction data!') - isDataSizePow2 = dataSize & (dataSize - 1) == 0 and dataSize > 0 + is_data_size_pow2 = data_size & (data_size - 1) == 0 and data_size > 0 - if not isDataSizePow2: + if not is_data_size_pow2: raise ValueError('PtychoNN expects that the diffraction data size is a power of two!') - # Bin diffraction data - # TODO extract binning to data loading (and verify that x-y coordinates are correct) - inputSize = dataSize - binSize = dataSize // inputSize - - if binSize == 1: - binnedData = data - else: - binnedData = numpy.zeros((data.shape[0], inputSize, inputSize), dtype=data.dtype) - - for i in range(inputSize): - for j in range(inputSize): - binnedData[:, i, j] = numpy.sum( - data[ - :, - binSize * i : binSize * (i + 1), - binSize * j : binSize * (j + 1), - ] - ) - - model = self._modelProvider.getModel() + model = self._model_provider.get_model() logger.debug('Inferring...') - objectPatches = ptychonn.infer( - data=binnedData.astype(numpy.float32), + object_patches = ptychonn.infer( + data=data.astype(numpy.float32), model=model, ) logger.debug('Stitching...') - stitcher = ObjectStitcher(parameters.product.object_.getGeometry()) + object_array = parameters.product.object_.get_array() + object_geometry = parameters.product.object_.get_geometry() + stitcher = BarycentricArrayStitcher( + upper=numpy.zeros_like(object_array), lower=numpy.zeros_like(object_array, dtype=float) + ) - for scanPoint, objectPatchChannels in zip(parameters.product.scan, objectPatches): - patchArray = numpy.exp(1j * objectPatchChannels[0]) + for scan_point, object_patch_channels in zip(parameters.product.positions, object_patches): + patch_array = numpy.exp(1j * object_patch_channels[0]) - if objectPatchChannels.shape[0] == 2: - patchArray *= objectPatchChannels[1] + if object_patch_channels.shape[0] == 2: + patch_array *= object_patch_channels[1] else: - patchArray *= 0.5 + patch_array *= 0.5 - stitcher.addPatch(scanPoint, patchArray) + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + stitcher.add_patch(object_point.position_x_px, object_point.position_y_px, patch_array) + + object_ = Object( + array=stitcher.stitch(), + pixel_geometry=object_geometry.get_pixel_geometry(), + center=object_geometry.get_center(), + layer_spacing_m=parameters.product.object_.layer_spacing_m, + ) product = Product( metadata=parameters.product.metadata, - scan=parameters.product.scan, - probe=parameters.product.probe, - object_=stitcher.build(), + positions=parameters.product.positions, + probes=parameters.product.probes, + object_=object_, costs=list(), # TODO put something here? ) return ReconstructOutput(product, 0) - def ingestTrainingData(self, parameters: ReconstructInput) -> None: - interpolator = ObjectLinearInterpolator(parameters.product.object_) - probeExtent = parameters.product.probe.getExtent() - - if self._patternBuffer.isZeroSized: - patternExtent = ImageExtent( - widthInPixels=parameters.patterns.shape[-1], - heightInPixels=parameters.patterns.shape[-2], - ) - maximumSize = max(1, self._trainingSettings.maximumTrainingDatasetSize.getValue()) - self._patternBuffer = PatternCircularBuffer(patternExtent, maximumSize) - self._objectPatchBuffer = ObjectPatchCircularBuffer( - patternExtent, self._modelProvider.getNumberOfChannels(), maximumSize - ) - - for scanPoint in parameters.product.scan: - objectPatch = interpolator.getPatch(scanPoint, probeExtent) - self._objectPatchBuffer.append(objectPatch.array) + def get_model_file_filter(self) -> str: + return self.MODEL_FILE_FILTER - for pattern in parameters.patterns.astype(numpy.float32): - self._patternBuffer.append(pattern) + def open_model(self, file_path: Path) -> None: + self._model_provider.open_model(file_path) - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - return [self.getOpenTrainingDataFileFilter()] + def save_model(self, file_path: Path) -> None: + self._model_provider.save_model(file_path) - def getOpenTrainingDataFileFilter(self) -> str: + def get_training_data_file_filter(self) -> str: return self.TRAINING_DATA_FILE_FILTER - def openTrainingData(self, filePath: Path) -> None: - logger.debug(f'Reading "{filePath}" as "NPZ"') - trainingData = numpy.load(filePath) - self._patternBuffer.setBuffer(trainingData[self.PATTERNS_KW]) - self._objectPatchBuffer.setBuffer(trainingData[self.PATCHES_KW]) + def export_training_data(self, file_path: Path, parameters: ReconstructInput) -> None: + object_geometry = parameters.product.object_.get_geometry() + interpolator = BarycentricArrayInterpolator(parameters.product.object_.get_array()) + num_channels = self._model_provider.get_num_channels() + probe_extent = ImageExtent( + width_px=parameters.product.probes.width_px, + height_px=parameters.product.probes.height_px, + ) + patches = numpy.zeros( + (len(parameters.product.positions), num_channels, *probe_extent.shape), + dtype=numpy.float32, + ) - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: - return [self.getSaveTrainingDataFileFilter()] + for index, scan_point in enumerate(parameters.product.positions): + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + patch = interpolator.get_patch( + object_point.position_x_px, + object_point.position_y_px, + probe_extent.width_px, + probe_extent.height_px, + ) + patches[index, 0, :, :] = numpy.angle(patch) - def getSaveTrainingDataFileFilter(self) -> str: - return self.TRAINING_DATA_FILE_FILTER + if num_channels > 1: + patches[index, 1, :, :] = numpy.absolute(patch) - def saveTrainingData(self, filePath: Path) -> None: - logger.debug(f'Writing "{filePath}" as "NPZ"') - trainingData = { - self.PATTERNS_KW: self._patternBuffer.getBuffer(), - self.PATCHES_KW: self._objectPatchBuffer.getBuffer(), + logger.debug(f'Writing "{file_path}" as "NPZ"') + training_data = { + self.PATTERNS_KW: parameters.patterns.astype(numpy.float32), + self.PATCHES_KW: patches, } - numpy.savez_compressed(filePath, **trainingData) + numpy.savez_compressed(file_path, **training_data) + + def get_training_data_path(self) -> Path: + return self._training_settings.training_data_path.get_value() - def train(self) -> TrainOutput: - model = self._modelProvider.getModel() + def train(self, data_path: Path) -> TrainOutput: + logger.debug(f'Reading "{data_path}" as "NPZ"') + training_data = numpy.load(data_path) + self._training_settings.training_data_path.set_value(data_path) + + model = self._model_provider.get_model() logger.debug('Training...') - trainingSetFractionalSize = ( - 1 - self._trainingSettings.validationSetFractionalSize.getValue() + training_set_fractional_size = ( + 1 - self._training_settings.validation_set_fractional_size.get_value() ) - trainer, trainerLog = ptychonn.train( + trainer, trainer_log = ptychonn.train( model=model, - batch_size=self._modelSettings.batchSize.getValue(), + batch_size=self._model_settings.batch_size.get_value(), out_dir=None, - X_train=self._patternBuffer.getBuffer(), - Y_train=self._objectPatchBuffer.getBuffer(), - epochs=self._trainingSettings.trainingEpochs.getValue(), - training_fraction=float(trainingSetFractionalSize), - log_frequency=self._trainingSettings.statusIntervalInEpochs.getValue(), + X_train=training_data[self.PATTERNS_KW], + Y_train=training_data[self.PATCHES_KW], + epochs=self._training_settings.training_epochs.get_value(), + training_fraction=float(training_set_fractional_size), + log_frequency=self._training_settings.status_interval_in_epochs.get_value(), strategy='ddp_notebook', ) - self._modelProvider.setTrainer(trainer) + self._model_provider.set_trainer(trainer) - trainingLoss: list[float] = list() - validationLoss: list[float] = list() + losses: list[LossValue] = [] - for entry in trainerLog.logs: + for epoch, entry in enumerate(trainer_log.logs): try: tloss = entry['training_loss'] vloss = entry['validation_loss'] except KeyError: pass else: - trainingLoss.append(tloss) - validationLoss.append(vloss) + losses.append(LossValue(epoch=epoch, training_loss=tloss, validation_loss=vloss)) return TrainOutput( - trainingLoss=trainingLoss, - validationLoss=validationLoss, + losses=losses, result=0, ) - - def clearTrainingData(self) -> None: - self._patternBuffer = PatternCircularBuffer.createZeroSized() - self._objectPatchBuffer = ObjectPatchCircularBuffer.createZeroSized() - - def getOpenModelFileFilterList(self) -> Sequence[str]: - return [self.getOpenModelFileFilter()] - - def getOpenModelFileFilter(self) -> str: - return self.MODEL_FILE_FILTER - - def openModel(self, filePath: Path) -> None: - self._modelProvider.openModel(filePath) - - def getSaveModelFileFilterList(self) -> Sequence[str]: - return [self.getSaveModelFileFilter()] - - def getSaveModelFileFilter(self) -> str: - return self.MODEL_FILE_FILTER - - def saveModel(self, filePath: Path) -> None: - self._modelProvider.saveModel(filePath) diff --git a/src/ptychodus/model/ptychonn/settings.py b/src/ptychodus/model/ptychonn/settings.py index 3199971e..5b2a9ac6 100644 --- a/src/ptychodus/model/ptychonn/settings.py +++ b/src/ptychodus/model/ptychonn/settings.py @@ -1,3 +1,5 @@ +from pathlib import Path + from ptychodus.api.observer import Observable, Observer from ptychodus.api.settings import SettingsRegistry @@ -5,45 +7,45 @@ class PtychoNNModelSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('PtychoNN') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('PtychoNN') + self._group.add_observer(self) - self.numberOfConvolutionKernels = self._settingsGroup.createIntegerParameter( - 'NumberOfConvolutionKernels', 16 + self.num_convolution_kernels = self._group.create_integer_parameter( + 'NumberOfConvolutionKernels', 16, minimum=1 ) - self.batchSize = self._settingsGroup.createIntegerParameter('BatchSize', 64) - self.useBatchNormalization = self._settingsGroup.createBooleanParameter( + self.batch_size = self._group.create_integer_parameter('BatchSize', 64, minimum=1) + self.use_batch_normalization = self._group.create_boolean_parameter( 'UseBatchNormalization', False ) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() class PtychoNNTrainingSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('PtychoNNTraining') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('PtychoNNTraining') + self._group.add_observer(self) - self.maximumTrainingDatasetSize = self._settingsGroup.createIntegerParameter( - 'MaximumTrainingDatasetSize', 100000 + self.training_data_path = self._group.create_path_parameter( + 'TrainingDataPath', Path('/path/to/training_data') ) - self.validationSetFractionalSize = self._settingsGroup.createRealParameter( - 'ValidationSetFractionalSize', 0.1 + self.validation_set_fractional_size = self._group.create_real_parameter( + 'ValidationSetFractionalSize', 0.1, minimum=0.0, maximum=1.0 ) - self.maximumLearningRate = self._settingsGroup.createRealParameter( - 'MaximumLearningRate', 1e-3 + self.max_learning_rate = self._group.create_real_parameter( + 'MaximumLearningRate', 1e-3, minimum=0.0, maximum=1.0 ) - self.minimumLearningRate = self._settingsGroup.createRealParameter( - 'MinimumLearningRate', 1e-4 + self.min_learning_rate = self._group.create_real_parameter( + 'MinimumLearningRate', 1e-4, minimum=0.0, maximum=1.0 ) - self.trainingEpochs = self._settingsGroup.createIntegerParameter('TrainingEpochs', 50) - self.statusIntervalInEpochs = self._settingsGroup.createIntegerParameter( - 'StatusIntervalInEpochs', 1 + self.training_epochs = self._group.create_integer_parameter('TrainingEpochs', 50, minimum=1) + self.status_interval_in_epochs = self._group.create_integer_parameter( + 'StatusIntervalInEpochs', 1, minimum=1 ) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/ptychopinn/__init__.py b/src/ptychodus/model/ptychopinn/__init__.py new file mode 100644 index 00000000..5ebc2e91 --- /dev/null +++ b/src/ptychodus/model/ptychopinn/__init__.py @@ -0,0 +1,5 @@ +from .core import PtychoPINNReconstructorLibrary + +__all__ = [ + 'PtychoPINNReconstructorLibrary', +] diff --git a/src/ptychodus/model/ptychopinn/core.py b/src/ptychodus/model/ptychopinn/core.py new file mode 100644 index 00000000..ad8dfc9c --- /dev/null +++ b/src/ptychodus/model/ptychopinn/core.py @@ -0,0 +1,65 @@ +from __future__ import annotations +from collections.abc import Iterator +import logging + +from ...api.reconstructor import ( + NullReconstructor, + Reconstructor, + ReconstructorLibrary, + TrainableReconstructor, +) +from ...api.settings import SettingsRegistry +from .enums import PtychoPINNEnumerators +from .settings import ( + PtychoPINNInferenceSettings, + PtychoPINNModelSettings, + PtychoPINNTrainingSettings, +) + +logger = logging.getLogger(__name__) + + +class PtychoPINNReconstructorLibrary(ReconstructorLibrary): + def __init__( + self, settings_registry: SettingsRegistry, is_developer_mode_enabled: bool + ) -> None: + super().__init__() + self.model_settings = PtychoPINNModelSettings(settings_registry) + self.training_settings = PtychoPINNTrainingSettings(settings_registry) + self.inference_settings = PtychoPINNInferenceSettings(settings_registry) + self.enumerators = PtychoPINNEnumerators() + self._reconstructors: list[TrainableReconstructor] = list() + + try: + from .reconstructor import PtychoPINNTrainableReconstructor + except ModuleNotFoundError: + logger.info('PtychoPINN not found.') + + if is_developer_mode_enabled: + self._reconstructors.append(NullReconstructor('PINN')) + self._reconstructors.append(NullReconstructor('Supervised')) + else: + self._reconstructors.append( + PtychoPINNTrainableReconstructor( + 'PINN', self.model_settings, self.inference_settings, self.training_settings + ) + ) + self._reconstructors.append( + PtychoPINNTrainableReconstructor( + 'Supervised', + self.model_settings, + self.inference_settings, + self.training_settings, + ) + ) + + @property + def name(self) -> str: + return 'PtychoPINN' + + @property + def logger_name(self) -> str: + return 'ptychopinn' + + def __iter__(self) -> Iterator[Reconstructor]: + return iter(self._reconstructors) diff --git a/src/ptychodus/model/ptychopinn/enums.py b/src/ptychodus/model/ptychopinn/enums.py new file mode 100644 index 00000000..0e351111 --- /dev/null +++ b/src/ptychodus/model/ptychopinn/enums.py @@ -0,0 +1,9 @@ +from collections.abc import Iterator, Sequence + + +class PtychoPINNEnumerators: + def __init__(self) -> None: + self._amp_activations: Sequence[str] = ['sigmoid', 'swish', 'softplus', 'relu'] + + def get_amp_activations(self) -> Iterator[str]: + return iter(self._amp_activations) diff --git a/src/ptychodus/model/ptychopinn/reconstructor.py b/src/ptychodus/model/ptychopinn/reconstructor.py new file mode 100644 index 00000000..22421151 --- /dev/null +++ b/src/ptychodus/model/ptychopinn/reconstructor.py @@ -0,0 +1,273 @@ +from __future__ import annotations +from collections.abc import Sequence +from importlib.metadata import version +from pathlib import Path +from typing import Any, Final +import logging + +import numpy +import numpy.typing + +from ptycho.config.config import InferenceConfig, ModelConfig, TrainingConfig, update_legacy_dict +from ptycho.raw_data import RawData +import ptycho.loader +import ptycho.model_manager +import ptycho.params + +from ptychodus.api.object import Object +from ptychodus.api.product import Product +from ptychodus.api.reconstructor import ( + LossValue, + ReconstructInput, + ReconstructOutput, + TrainOutput, + TrainableReconstructor, +) + +from .settings import ( + PtychoPINNInferenceSettings, + PtychoPINNModelSettings, + PtychoPINNTrainingSettings, +) + +__all__ = [ + 'PtychoPINNTrainableReconstructor', +] + +logger = logging.getLogger(__name__) + + +def create_raw_data(parameters: ReconstructInput) -> RawData: + object_geometry = parameters.product.object_.get_geometry() + position_x_px: list[float] = list() + position_y_px: list[float] = list() + + for scan_point in parameters.product.positions: + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + position_x_px.append(object_point.position_x_px) + position_y_px.append(object_point.position_y_px) + + return RawData.from_coords_without_pc( + xcoords=numpy.array(position_x_px), + ycoords=numpy.array(position_y_px), + diff3d=parameters.patterns, + probeGuess=parameters.product.probes.get_probe_no_opr().get_incoherent_mode(0), + # assume that all patches are from the same object + scan_index=numpy.zeros(len(parameters.product.positions), dtype=int), + objectGuess=parameters.product.object_.get_layer(0), + ) + + +class PtychoPINNTrainableReconstructor(TrainableReconstructor): + MODEL_FILE_FILTER: Final[str] = 'Zipped Archive (*.zip)' + TRAINING_DATA_FILE_FILTER: Final[str] = 'NumPy Zipped Archive (*.npz)' + + # TODO datasets for testing: xpp, "u", ALS + # TODO normalize data in preprocessing step (see note in slack) + # TODO ptychodus stitches + + def __init__( + self, + name: str, + model_settings: PtychoPINNModelSettings, + inference_settings: PtychoPINNInferenceSettings, + training_settings: PtychoPINNTrainingSettings, + *in_developer_mode: bool, + ) -> None: + super().__init__() + self._name = name + self._model_settings = model_settings + self._inference_settings = inference_settings + self._training_settings = training_settings + self._model_dict: dict[str, Any] | None = None + self._in_developer_mode = in_developer_mode + + ptychopinn_version = version('ptychopinn') + logger.info(f'\tPtychoPINN {ptychopinn_version}') + + def _create_model_config(self, model_size: int) -> ModelConfig: + return ModelConfig( + N=model_size, + gridsize=self._model_settings.gridsize.get_value(), + n_filters_scale=self._model_settings.n_filters_scale.get_value(), + model_type=self._name.lower(), + amp_activation=self._model_settings.amp_activation.get_value(), + object_big=self._model_settings.object_big.get_value(), + probe_big=self._model_settings.probe_big.get_value(), + probe_mask=self._model_settings.probe_mask.get_value(), + pad_object=self._model_settings.pad_object.get_value(), + probe_scale=self._model_settings.probe_scale.get_value(), + gaussian_smoothing_sigma=self._model_settings.gaussian_smoothing_sigma.get_value(), + ) + + @property + def name(self) -> str: + return self._name + + def _reconstruct_image(self, test_data: ptycho.loader.PtychoDataContainer) -> Any: + if self._model_dict is None: + raise RuntimeError('Model not loaded!') + + import ptycho.model + + diffraction_to_obj = self._model_dict['diffraction_to_obj'] # tf.keras.Model + intensity_scale = ptycho.model.params()['intensity_scale'] + return diffraction_to_obj.predict([test_data.X * intensity_scale, test_data.local_offsets]) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + model_size = parameters.patterns.shape[-1] + + if parameters.patterns.shape[-2] != model_size: + raise ValueError('Model requires square diffraction patterns!') + + if self._model_dict is None: + raise ValueError('Model not loaded.') + + model_config = self._create_model_config(model_size) + inference_config = InferenceConfig( + model=model_config, + model_path=Path(), # not used + test_data_file=Path(), # not used + debug=self._in_developer_mode, + output_dir=Path(), # not used + ) + + # Update global params with new-style config + update_legacy_dict(ptycho.params.cfg, inference_config) + + # Create RawData + test_raw_data = create_raw_data(parameters) + ptycho.probe.set_probe_guess(None, test_raw_data.probeGuess) + + # Group overlapping scan positions + test_dataset = test_raw_data.generate_grouped_data( + model_config.N, + K=self._inference_settings.n_nearest_neighbors.get_value(), + nsamples=self._inference_settings.n_samples.get_value(), + ) + + # Create PtychoDataContainer + test_data_container = ptycho.loader.load( + lambda: test_dataset, test_raw_data.probeGuess, which=None, create_split=False + ) + + # Perform reconstruction + obj_tensor_full = self._reconstruct_image(test_data_container) + + # Process the reconstructed image + object_out_array = ptycho.tf_helper.reassemble_position( + obj_tensor_full, test_data_container.global_offsets, M=20 + ) + + object_in = parameters.product.object_ + object_out = Object( + array=numpy.squeeze(object_out_array), + layer_spacing_m=object_in.layer_spacing_m, + pixel_geometry=object_in.get_pixel_geometry(), + center=object_in.get_center(), + ) + costs: Sequence[float] = list() + + product = Product( + metadata=parameters.product.metadata, + positions=parameters.product.positions, + probes=parameters.product.probes, + object_=object_out, + costs=costs, + ) + + return ReconstructOutput(product, 0) + + def get_model_file_filter(self) -> str: + return self.MODEL_FILE_FILTER + + def open_model(self, file_path: Path) -> None: + # TODO model path to/from settings + self._inference_settings.model_path.set_value(file_path) + # ModelManager updates global config (ptycho.params.cfg) when loading + self._model_dict = ptycho.model_manager.ModelManager.load_multiple_models( + file_path.parent / file_path.stem + ) + # TODO update settings from ptycho.params.cfg after loading + + def save_model(self, file_path: Path) -> None: + ptycho.model_manager.save(file_path) + + def get_training_data_file_filter(self) -> str: + return self.TRAINING_DATA_FILE_FILTER + + def export_training_data(self, file_path: Path, parameters: ReconstructInput) -> None: + object_geometry = parameters.product.object_.get_geometry() + position_x_px: list[float] = list() + position_y_px: list[float] = list() + + for scan_point in parameters.product.positions: + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + position_x_px.append(object_point.position_x_px) + position_y_px.append(object_point.position_y_px) + + xcoords = numpy.array(position_x_px) + ycoords = numpy.array(position_y_px) + + numpy.savez( + file_path, + xcoords=xcoords, + ycoords=ycoords, + xcoords_start=xcoords, + ycoords_start=ycoords, + diff3d=parameters.patterns, + probeGuess=parameters.product.probes.get_probe_no_opr().get_incoherent_mode(0), + # assume that all patches are from the same object + objectGuess=parameters.product.object_.get_layer(0), + scan_index=numpy.zeros(len(parameters.product.positions), dtype=int), + ) + + def get_training_data_path(self) -> Path: + return self._training_settings.data_dir.get_value() + + def train(self, data_path: Path) -> TrainOutput: + self._training_settings.data_dir.set_value(data_path) + + test_raw_data = RawData.from_file(data_path / 'test_data.npz') # TODO RawData | None + train_raw_data = RawData.from_file(data_path / 'train_data.npz') + + model_size = train_raw_data.diff3d.shape[-1] + + if train_raw_data.diff3d.shape[-2] != model_size: + raise ValueError('Model requires square diffraction patterns!') + + model_config = self._create_model_config(model_size) + training_config = TrainingConfig( + model=model_config, + train_data_file=Path(), # not used + test_data_file=None, # not used + batch_size=self._training_settings.batch_size.get_value(), + nepochs=self._training_settings.nepochs.get_value(), + mae_weight=self._training_settings.mae_weight.get_value(), + nll_weight=self._training_settings.nll_weight.get_value(), + realspace_mae_weight=self._training_settings.realspace_mae_weight.get_value(), + realspace_weight=self._training_settings.realspace_weight.get_value(), + nphotons=self._training_settings.nphotons.get_value(), # TODO get from product + positions_provided=self._training_settings.positions_provided.get_value(), + probe_trainable=self._training_settings.probe_trainable.get_value(), + intensity_scale_trainable=self._training_settings.intensity_scale_trainable.get_value(), + output_dir=Path(), # not used + ) + + # Update global params with new-style config + update_legacy_dict(ptycho.params.cfg, training_config) + + from ptycho.workflows.components import run_cdi_example, save_outputs + + recon_amp, recon_phase, train_results = run_cdi_example( + train_raw_data, test_raw_data, training_config + ) + output_dir = self._training_settings.output_dir.get_value() + self.save_model(output_dir) + save_outputs(recon_amp, recon_phase, train_results, str(output_dir)) + print(train_results.keys()) # TODO remove + # dict_keys(['history', 'model_instance', 'reconstructed_obj', 'pred_amp', 'reconstructed_obj_cdi', 'stitched_obj', 'train_container', 'test_container', 'obj_tensor_full', 'global_offsets', 'recon_amp', 'recon_phase']) + # TODO self._model_dict = train_results + + losses: Sequence[LossValue] = [] + return TrainOutput(losses, 0) # TODO diff --git a/src/ptychodus/model/ptychopinn/settings.py b/src/ptychodus/model/ptychopinn/settings.py new file mode 100644 index 00000000..5f715a4d --- /dev/null +++ b/src/ptychodus/model/ptychopinn/settings.py @@ -0,0 +1,88 @@ +from pathlib import Path + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.settings import SettingsRegistry + + +class PtychoPINNModelSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtychoPINNModel') + self._group.add_observer(self) + + self.gridsize = self._group.create_integer_parameter('gridsize', 1, minimum=1, maximum=5) + self.n_filters_scale = self._group.create_integer_parameter( + 'n_filters_scale', 2, minimum=1, maximum=4 + ) + self.amp_activation = self._group.create_string_parameter('amp_activation', 'sigmoid') + self.object_big = self._group.create_boolean_parameter('object_big', True) + self.probe_big = self._group.create_boolean_parameter('probe_big', True) + self.probe_mask = self._group.create_boolean_parameter('probe_mask', False) + self.pad_object = self._group.create_boolean_parameter('pad_object', True) + self.probe_scale = self._group.create_real_parameter('probe_scale', 4.0, minimum=0.0) + self.gaussian_smoothing_sigma = self._group.create_real_parameter( + 'gaussian_smoothing_sigma', 0.0, minimum=0.0 + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtychoPINNTrainingSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtychoPINNTraining') + self._group.add_observer(self) + + self.nphotons = self._group.create_real_parameter('NPhotons', 1e6) # TODO remove + self.data_dir = self._group.create_path_parameter( + 'data_dir', Path('/path/to/training_data') + ) + self.batch_size = self._group.create_integer_parameter( + 'batch_size', 16, minimum=1, maximum=1 << 30 + ) # must be positive powers of two + self.nepochs = self._group.create_integer_parameter('nepochs', 50, minimum=1) + self.mae_weight = self._group.create_real_parameter( + 'mae_weight', 0.0, minimum=0.0, maximum=1.0 + ) + self.nll_weight = self._group.create_real_parameter( + 'nll_weight', 1.0, minimum=0.0, maximum=1.0 + ) + self.realspace_mae_weight = self._group.create_real_parameter( + 'realspace_mae_weight', 0.0, minimum=0.0, maximum=1.0 + ) + self.realspace_weight = self._group.create_real_parameter( + 'realspace_weight', 0.0, minimum=0.0, maximum=1.0 + ) + self.positions_provided = self._group.create_boolean_parameter('positions_provided', True) + self.probe_trainable = self._group.create_boolean_parameter('probe_trainable', False) + self.intensity_scale_trainable = self._group.create_boolean_parameter( + 'intensity_scale_trainable', True + ) + self.output_dir = self._group.create_path_parameter( + 'output_dir', Path('/path/to/output_data') + ) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() + + +class PtychoPINNInferenceSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._group = registry.create_group('PtychoPINNInference') + self._group.add_observer(self) + + self.model_path = self._group.create_path_parameter( + 'model_path', Path('/path/to/model.zip') + ) + self.n_nearest_neighbors = self._group.create_integer_parameter( + 'n_nearest_neighbors', 7, minimum=0 + ) + self.n_samples = self._group.create_integer_parameter('n_samples', 1, minimum=1) + + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/reconstructor/api.py b/src/ptychodus/model/reconstructor/api.py index c04fd699..c2f0880c 100644 --- a/src/ptychodus/model/reconstructor/api.py +++ b/src/ptychodus/model/reconstructor/api.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from pathlib import Path import logging import time @@ -9,8 +10,9 @@ TrainOutput, ) -from ..product import ProductRepository +from ..product import ProductAPI from .matcher import DiffractionPatternPositionMatcher, ScanIndexFilter +from .queue import ReconstructionQueue logger = logging.getLogger(__name__) @@ -18,142 +20,138 @@ class ReconstructorAPI: def __init__( self, - dataMatcher: DiffractionPatternPositionMatcher, - productRepository: ProductRepository, - reconstructorChooser: PluginChooser[Reconstructor], + reconstruction_queue: ReconstructionQueue, + data_matcher: DiffractionPatternPositionMatcher, + product_api: ProductAPI, + reconstructor_chooser: PluginChooser[Reconstructor], ) -> None: - self._dataMatcher = dataMatcher - self._productRepository = productRepository - self._reconstructorChooser = reconstructorChooser + self._reconstruction_queue = reconstruction_queue + self._data_matcher = data_matcher + self._product_api = product_api + self._reconstructor_chooser = reconstructor_chooser + + @property + def is_reconstructing(self) -> bool: + return self._reconstruction_queue.is_reconstructing + + def process_results(self, *, block: bool) -> None: + self._reconstruction_queue.process_results(block=block) def reconstruct( self, - inputProductIndex: int, - outputProductName: str, - indexFilter: ScanIndexFilter = ScanIndexFilter.ALL, + input_product_index: int, + *, + output_product_suffix: str = '', + transform: int | None = None, + index_filter: ScanIndexFilter = ScanIndexFilter.ALL, ) -> int: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( - inputProductIndex, indexFilter - ) + reconstructor = self._reconstructor_chooser.get_current_plugin() + input_product_item = self._product_api.get_item(input_product_index) + output_product_index = self._product_api.insert_product(input_product_item.get_product()) + output_product_item = self._product_api.get_item(output_product_index) + output_product_name = f'{input_product_item.get_name()}_{reconstructor.simple_name}' - outputProductIndex = self._productRepository.insertNewProduct(likeIndex=inputProductIndex) - outputProduct = self._productRepository[outputProductIndex] + if output_product_suffix: + output_product_name += f'_{output_product_suffix}' - tic = time.perf_counter() - result = reconstructor.reconstruct(parameters) - toc = time.perf_counter() - logger.info(f'Reconstruction time {toc - tic:.4f} seconds. (code={result.result})') + output_product_item.set_name(output_product_name) - outputProduct.assign(result.product) + if transform is not None: + scan_item_transform = output_product_item.get_scan_item().get_transform() + scan_item_transform.apply_presets(transform) - return outputProductIndex + object_item = output_product_item.get_object_item() + object_item.rebuild(recenter=True) - def reconstructSplit(self, inputProductIndex: int, outputProductName: str) -> tuple[int, int]: - outputProductIndexOdd = self.reconstruct( - inputProductIndex, - f'{outputProductName}_odd', - ScanIndexFilter.ODD, + self._reconstruction_queue.put(reconstructor.strategy, output_product_index, index_filter) + return output_product_index + + def reconstruct_split(self, input_product_index: int) -> tuple[int, int]: + output_product_index_odd = self.reconstruct( + input_product_index, + output_product_suffix='odd', + index_filter=ScanIndexFilter.ODD, ) - outputProductIndexEven = self.reconstruct( - inputProductIndex, - f'{outputProductName}_even', - ScanIndexFilter.EVEN, + output_product_index_even = self.reconstruct( + input_product_index, + output_product_suffix='even', + index_filter=ScanIndexFilter.EVEN, ) - return outputProductIndexOdd, outputProductIndexEven + return output_product_index_odd, output_product_index_even - def ingestTrainingData(self, inputProductIndex: int) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def reconstruct_transformed(self, input_product_index: int) -> Sequence[int]: + output_product_indexes: list[int] = [] + input_product = self._product_api.get_item(input_product_index) - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Preparing input data...') - tic = time.perf_counter() - parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( - inputProductIndex, ScanIndexFilter.ALL + for preset_value, preset_label in enumerate( + input_product.get_scan_item().get_transform().labels_for_presets() + ): + output_product_index = self.reconstruct( + input_product_index, + output_product_suffix=preset_label, + transform=preset_value, + index_filter=ScanIndexFilter.ALL, ) - toc = time.perf_counter() - logger.info(f'Data preparation time {toc - tic:.4f} seconds.') + output_product_indexes.append(output_product_index) - logger.info('Ingesting...') - tic = time.perf_counter() - reconstructor.ingestTrainingData(parameters) - toc = time.perf_counter() - logger.info(f'Ingest time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return output_product_indexes - def openTrainingData(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def open_model(self, file_path: Path) -> None: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy if isinstance(reconstructor, TrainableReconstructor): - logger.info('Opening training data...') + logger.info('Opening model...') tic = time.perf_counter() - reconstructor.openTrainingData(filePath) + reconstructor.open_model(file_path) toc = time.perf_counter() logger.info(f'Open time {toc - tic:.4f} seconds.') else: logger.warning('Reconstructor is not trainable!') - def saveTrainingData(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def save_model(self, file_path: Path) -> None: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy if isinstance(reconstructor, TrainableReconstructor): - logger.info('Saving training data...') + logger.info('Saving model...') tic = time.perf_counter() - reconstructor.saveTrainingData(filePath) + reconstructor.save_model(file_path) toc = time.perf_counter() logger.info(f'Save time {toc - tic:.4f} seconds.') else: logger.warning('Reconstructor is not trainable!') - def train(self) -> TrainOutput: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - result = TrainOutput([], [], -1) + def export_training_data(self, file_path: Path, input_product_index: int) -> None: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy if isinstance(reconstructor, TrainableReconstructor): - logger.info('Training...') + logger.info('Preparing input data...') tic = time.perf_counter() - result = reconstructor.train() + parameters = self._data_matcher.match_diffraction_patterns_with_positions( + input_product_index, ScanIndexFilter.ALL + ) toc = time.perf_counter() - logger.info(f'Training time {toc - tic:.4f} seconds. (code={result.result})') - else: - logger.warning('Reconstructor is not trainable!') - - return result - - def clearTrainingData(self) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + logger.info(f'Data preparation time {toc - tic:.4f} seconds.') - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Resetting...') + logger.info('Exporting...') tic = time.perf_counter() - reconstructor.clearTrainingData() + reconstructor.export_training_data(file_path, parameters) toc = time.perf_counter() - logger.info(f'Reset time {toc - tic:.4f} seconds.') + logger.info(f'Export time {toc - tic:.4f} seconds.') else: logger.warning('Reconstructor is not trainable!') - def openModel(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def train(self, data_path: Path) -> TrainOutput: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy + result = TrainOutput([], -1) if isinstance(reconstructor, TrainableReconstructor): - logger.info('Opening model...') + logger.info('Training...') tic = time.perf_counter() - reconstructor.openModel(filePath) + result = reconstructor.train(data_path) toc = time.perf_counter() - logger.info(f'Open time {toc - tic:.4f} seconds.') + logger.info(f'Training time {toc - tic:.4f} seconds. (code={result.result})') else: logger.warning('Reconstructor is not trainable!') - def saveModel(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Saving model...') - tic = time.perf_counter() - reconstructor.saveModel(filePath) - toc = time.perf_counter() - logger.info(f'Save time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return result diff --git a/src/ptychodus/model/reconstructor/core.py b/src/ptychodus/model/reconstructor/core.py index b501efcf..1edb1be8 100644 --- a/src/ptychodus/model/reconstructor/core.py +++ b/src/ptychodus/model/reconstructor/core.py @@ -9,41 +9,61 @@ ) from ptychodus.api.settings import SettingsRegistry -from ..patterns import ActiveDiffractionDataset -from ..product import ProductRepository +from ..patterns import AssembledDiffractionDataset +from ..product import ProductAPI from .api import ReconstructorAPI +from .log import ReconstructorLogHandler from .matcher import DiffractionPatternPositionMatcher from .presenter import ReconstructorPresenter +from .queue import ReconstructionQueue from .settings import ReconstructorSettings -logger = logging.getLogger(__name__) - class ReconstructorCore: def __init__( self, - settingsRegistry: SettingsRegistry, - diffractionDataset: ActiveDiffractionDataset, - productRepository: ProductRepository, - librarySeq: Sequence[ReconstructorLibrary], + settings_registry: SettingsRegistry, + dataset: AssembledDiffractionDataset, + product_api: ProductAPI, + library_seq: Sequence[ReconstructorLibrary], ) -> None: - self.settings = ReconstructorSettings(settingsRegistry) - self._pluginChooser = PluginChooser[Reconstructor]() + self.settings = ReconstructorSettings(settings_registry) + self._plugin_chooser = PluginChooser[Reconstructor]() + self._log_handler = ReconstructorLogHandler() + self._log_handler.setFormatter( + logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') + ) - for library in librarySeq: + for library in library_seq: for reconstructor in library: - self._pluginChooser.registerPlugin( + self._plugin_chooser.register_plugin( reconstructor, - displayName=f'{library.name}/{reconstructor.name}', + simple_name=f'{library.name}_{reconstructor.name}', + display_name=f'{library.name}/{reconstructor.name}', ) - if not self._pluginChooser: - self._pluginChooser.registerPlugin(NullReconstructor('None'), displayName='None/None') + library_logger = logging.getLogger(library.logger_name) + library_logger.addHandler(self._log_handler) + + if not self._plugin_chooser: + self._plugin_chooser.register_plugin( + NullReconstructor('None'), display_name='None/None' + ) - self.dataMatcher = DiffractionPatternPositionMatcher(diffractionDataset, productRepository) - self.reconstructorAPI = ReconstructorAPI( - self.dataMatcher, productRepository, self._pluginChooser + self.data_matcher = DiffractionPatternPositionMatcher(dataset, product_api) + self._reconstruction_queue = ReconstructionQueue(self.data_matcher) + self.reconstructor_api = ReconstructorAPI( + self._reconstruction_queue, self.data_matcher, product_api, self._plugin_chooser ) self.presenter = ReconstructorPresenter( - self.settings, self._pluginChooser, self.reconstructorAPI, settingsRegistry + self.settings, + self._plugin_chooser, + self._log_handler, + self.reconstructor_api, ) + + def start(self) -> None: + self._reconstruction_queue.start() + + def stop(self) -> None: + self._reconstruction_queue.stop() diff --git a/src/ptychodus/model/reconstructor/log.py b/src/ptychodus/model/reconstructor/log.py new file mode 100644 index 00000000..ee3c5d47 --- /dev/null +++ b/src/ptychodus/model/reconstructor/log.py @@ -0,0 +1,21 @@ +from collections.abc import Iterator +import queue +import logging + + +class ReconstructorLogHandler(logging.Handler): + def __init__(self) -> None: + super().__init__() + self._log: queue.Queue[str] = queue.Queue() + + def messages(self) -> Iterator[str]: + while True: + try: + yield self._log.get(block=False) + self._log.task_done() + except queue.Empty: + break + + def emit(self, record: logging.LogRecord) -> None: + text = self.format(record) + self._log.put(text) diff --git a/src/ptychodus/model/reconstructor/matcher.py b/src/ptychodus/model/reconstructor/matcher.py index a3c47c7f..6c974fa9 100644 --- a/src/ptychodus/model/reconstructor/matcher.py +++ b/src/ptychodus/model/reconstructor/matcher.py @@ -6,10 +6,10 @@ from ptychodus.api.geometry import PixelGeometry from ptychodus.api.product import Product from ptychodus.api.reconstructor import ReconstructInput -from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanPoint -from ..patterns import ActiveDiffractionDataset -from ..product import ProductRepository +from ..patterns import AssembledDiffractionDataset +from ..product import ProductAPI, ProductRepositoryItem logger = logging.getLogger(__name__) @@ -34,57 +34,56 @@ def __call__(self, index: int) -> bool: class DiffractionPatternPositionMatcher: def __init__( self, - diffractionDataset: ActiveDiffractionDataset, - productRepository: ProductRepository, + dataset: AssembledDiffractionDataset, + product_api: ProductAPI, ) -> None: - self._diffractionDataset = diffractionDataset - self._productRepository = productRepository + self._dataset = dataset + self._product_api = product_api - def getProductName(self, inputProductIndex: int) -> str: - inputProductItem = self._productRepository[inputProductIndex] - return inputProductItem.getName() + def get_product_item(self, input_product_index: int) -> ProductRepositoryItem: + return self._product_api.get_item(input_product_index) - def getObjectPlanePixelGeometry(self, inputProductIndex: int) -> PixelGeometry: - inputProductItem = self._productRepository[inputProductIndex] - objectGeometry = inputProductItem.getGeometry().getObjectGeometry() - return objectGeometry.getPixelGeometry() + def get_object_plane_pixel_geometry(self, input_product_index: int) -> PixelGeometry: + input_product_item = self._product_api.get_item(input_product_index) + object_geometry = input_product_item.get_geometry().get_object_geometry() + return object_geometry.get_pixel_geometry() - def matchDiffractionPatternsWithPositions( - self, inputProductIndex: int, indexFilter: ScanIndexFilter = ScanIndexFilter.ALL + def match_diffraction_patterns_with_positions( + self, input_product_index: int, index_filter: ScanIndexFilter = ScanIndexFilter.ALL ) -> ReconstructInput: - goodPixelMask = self._diffractionDataset.getGoodPixelMask() - - inputProductItem = self._productRepository[inputProductIndex] - inputProduct = inputProductItem.getProduct() - dataIndexes = self._diffractionDataset.getAssembledIndexes() - scanIndexes = [point.index for point in inputProduct.scan if indexFilter(point.index)] - commonIndexes = sorted(set(dataIndexes).intersection(scanIndexes)) + input_product_item = self._product_api.get_item(input_product_index) + input_product = input_product_item.get_product() + data_indexes = self._dataset.get_assembled_indexes() + scan_indexes = [ + point.index for point in input_product.positions if index_filter(point.index) + ] + common_indexes = sorted(set(data_indexes).intersection(scan_indexes)) patterns = numpy.take( - self._diffractionDataset.getAssembledData(), - commonIndexes, + self._dataset.get_assembled_patterns(), + common_indexes, axis=0, ) - pointList: list[ScanPoint] = list() - pointIter = iter(inputProduct.scan) + point_list: list[ScanPoint] = list() + point_iterator = iter(input_product.positions) - for index in commonIndexes: + for index in common_indexes: while True: - point = next(pointIter) + point = next(point_iterator) if point.index == index: - pointList.append(point) + point_list.append(point) break - probe = inputProduct.probe # TODO remap if needed + probe = input_product.probes # TODO remap if needed product = Product( - metadata=inputProduct.metadata, - scan=Scan(pointList), - probe=probe, - object_=inputProduct.object_, - costs=inputProduct.costs, + metadata=input_product.metadata, + positions=PositionSequence(point_list), + probes=probe, + object_=input_product.object_, + costs=input_product.costs, ) - return ReconstructInput(patterns, goodPixelMask, product) + return ReconstructInput(patterns, self._dataset.get_processed_bad_pixels(), product) diff --git a/src/ptychodus/model/reconstructor/presenter.py b/src/ptychodus/model/reconstructor/presenter.py index d310c2e6..ce36fba8 100644 --- a/src/ptychodus/model/reconstructor/presenter.py +++ b/src/ptychodus/model/reconstructor/presenter.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from pathlib import Path import logging @@ -11,7 +11,7 @@ ) from .api import ReconstructorAPI -from .matcher import ScanIndexFilter +from .log import ReconstructorLogHandler from .settings import ReconstructorSettings logger = logging.getLogger(__name__) @@ -21,155 +21,96 @@ class ReconstructorPresenter(Observable, Observer): def __init__( self, settings: ReconstructorSettings, - reconstructorChooser: PluginChooser[Reconstructor], - reconstructorAPI: ReconstructorAPI, - reinitObservable: Observable, + reconstructor_chooser: PluginChooser[Reconstructor], + log_handler: ReconstructorLogHandler, + reconstructor_api: ReconstructorAPI, ) -> None: super().__init__() self._settings = settings - self._reconstructorChooser = reconstructorChooser - self._reconstructorAPI = reconstructorAPI - self._reinitObservable = reinitObservable + self._reconstructor_chooser = reconstructor_chooser + self._log_handler = log_handler + self._reconstructor_api = reconstructor_api - reconstructorChooser.addObserver(self) - reinitObservable.addObserver(self) - self._syncFromSettings() + reconstructor_chooser.synchronize_with_parameter(settings.algorithm) + reconstructor_chooser.add_observer(self) - def getReconstructorList(self) -> Sequence[str]: - return self._reconstructorChooser.getDisplayNameList() + def reconstructors(self) -> Iterator[str]: + for plugin in self._reconstructor_chooser: + yield plugin.display_name - def getReconstructor(self) -> str: - return self._reconstructorChooser.currentPlugin.displayName + def get_reconstructor(self) -> str: + return self._reconstructor_chooser.get_current_plugin().display_name - def setReconstructor(self, name: str) -> None: - self._reconstructorChooser.setCurrentPluginByName(name) + def set_reconstructor(self, name: str) -> None: + self._reconstructor_chooser.set_current_plugin(name) - def _syncFromSettings(self) -> None: - self.setReconstructor(self._settings.algorithm.getValue()) + def reconstruct(self, input_product_index: int) -> int: + return self._reconstructor_api.reconstruct(input_product_index) - def _syncToSettings(self) -> None: - self._settings.algorithm.setValue(self._reconstructorChooser.currentPlugin.simpleName) + def reconstruct_split(self, input_product_index: int) -> tuple[int, int]: + return self._reconstructor_api.reconstruct_split(input_product_index) - def reconstruct( - self, - inputProductIndex: int, - outputProductName: str, - indexFilter: ScanIndexFilter = ScanIndexFilter.ALL, - ) -> int: - return self._reconstructorAPI.reconstruct(inputProductIndex, outputProductName, indexFilter) - - def reconstructSplit(self, inputProductIndex: int, outputProductName: str) -> tuple[int, int]: - return self._reconstructorAPI.reconstructSplit(inputProductIndex, outputProductName) + def reconstruct_transformed(self, input_product_index: int) -> Sequence[int]: + return self._reconstructor_api.reconstruct_transformed(input_product_index) @property - def isTrainable(self) -> bool: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - return isinstance(reconstructor, TrainableReconstructor) - - def ingestTrainingData(self, inputProductIndex: int) -> None: - return self._reconstructorAPI.ingestTrainingData(inputProductIndex) - - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def is_reconstructing(self) -> bool: + return self._reconstructor_api.is_reconstructing - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenTrainingDataFileFilterList() - else: - logger.warning('Reconstructor is not trainable!') + def flush_log(self) -> Iterator[str]: + for text in self._log_handler.messages(): + yield text - return list() + def process_results(self, *, block: bool) -> None: + self._reconstructor_api.process_results(block=block) - def getOpenTrainingDataFileFilter(self) -> str: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenTrainingDataFileFilter() - else: - logger.warning('Reconstructor is not trainable!') - - return str() - - def openTrainingData(self, filePath: Path) -> None: - return self._reconstructorAPI.openTrainingData(filePath) - - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveTrainingDataFileFilterList() - else: - logger.warning('Reconstructor is not trainable!') - - return list() + @property + def is_trainable(self) -> bool: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy + return isinstance(reconstructor, TrainableReconstructor) - def getSaveTrainingDataFileFilter(self) -> str: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def get_model_file_filter(self) -> str: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveTrainingDataFileFilter() + return reconstructor.get_model_file_filter() else: logger.warning('Reconstructor is not trainable!') return str() - def saveTrainingData(self, filePath: Path) -> None: - return self._reconstructorAPI.saveTrainingData(filePath) - - def train(self) -> TrainOutput: - return self._reconstructorAPI.train() - - def clearTrainingData(self) -> None: - self._reconstructorAPI.clearTrainingData() - - def getOpenModelFileFilterList(self) -> Sequence[str]: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenModelFileFilterList() - else: - logger.warning('Reconstructor is not trainable!') + def open_model(self, file_path: Path) -> None: + return self._reconstructor_api.open_model(file_path) - return list() + def save_model(self, file_path: Path) -> None: + return self._reconstructor_api.save_model(file_path) - def getOpenModelFileFilter(self) -> str: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def get_training_data_file_filter(self) -> str: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenModelFileFilter() + return reconstructor.get_training_data_file_filter() else: logger.warning('Reconstructor is not trainable!') return str() - def openModel(self, filePath: Path) -> None: - return self._reconstructorAPI.openModel(filePath) + def export_training_data(self, file_path: Path, input_product_index: int) -> None: + return self._reconstructor_api.export_training_data(file_path, input_product_index) - def getSaveModelFileFilterList(self) -> Sequence[str]: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + def get_training_data_path(self) -> Path: + reconstructor = self._reconstructor_chooser.get_current_plugin().strategy if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveModelFileFilterList() + return reconstructor.get_training_data_path() else: logger.warning('Reconstructor is not trainable!') - return list() - - def getSaveModelFileFilter(self) -> str: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveModelFileFilter() - else: - logger.warning('Reconstructor is not trainable!') - - return str() + return Path() - def saveModel(self, filePath: Path) -> None: - return self._reconstructorAPI.saveModel(filePath) + def train(self, data_path: Path) -> TrainOutput: + return self._reconstructor_api.train(data_path) - def update(self, observable: Observable) -> None: - if observable is self._reconstructorChooser: - self._syncToSettings() - self.notifyObservers() - elif observable is self._reinitObservable: - self._syncFromSettings() + def _update(self, observable: Observable) -> None: + if observable is self._reconstructor_chooser: + self.notify_observers() diff --git a/src/ptychodus/model/reconstructor/queue.py b/src/ptychodus/model/reconstructor/queue.py new file mode 100644 index 00000000..8b6fcf47 --- /dev/null +++ b/src/ptychodus/model/reconstructor/queue.py @@ -0,0 +1,135 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +import logging +import queue +import threading +import time + +from ptychodus.api.reconstructor import Reconstructor, ReconstructInput, ReconstructOutput + +from ..product import ProductRepositoryItem +from .matcher import DiffractionPatternPositionMatcher, ScanIndexFilter + +logger = logging.getLogger(__name__) + +__all__ = ['ReconstructionQueue'] + + +class ReconstructionTask(ABC): + @abstractmethod + def execute(self) -> ReconstructionTask | None: + pass + + +class UpdateProductTask(ReconstructionTask): + def __init__(self, product: ProductRepositoryItem, result: ReconstructOutput) -> None: + self._product = product + self._result = result + + def execute(self) -> None: + name = self._product.get_name() + self._product.assign(self._result.product) + self._product.set_name(name) + + +class ExecuteReconstructorTask(ReconstructionTask): + def __init__( + self, + data_matcher: DiffractionPatternPositionMatcher, + reconstructor: Reconstructor, + product_index: int, + index_filter: ScanIndexFilter = ScanIndexFilter.ALL, + ) -> None: + self._data_matcher = data_matcher + self._reconstructor = reconstructor + self._product_index = product_index + self._index_filter = index_filter + + def execute(self) -> UpdateProductTask: + product_item = self._data_matcher.get_product_item(self._product_index) + logger.info(f'Reconstructing {product_item.get_name()}...') + + logger.info('Preparing input data...') + tic = time.perf_counter() + parameters = self._data_matcher.match_diffraction_patterns_with_positions( + self._product_index, self._index_filter + ) + toc = time.perf_counter() + logger.info(f'Data preparation time {toc - tic:.4f} seconds.') + + tic = time.perf_counter() + result = self._reconstructor.reconstruct(parameters) + toc = time.perf_counter() + logger.info(f'Reconstruction time {toc - tic:.4f} seconds. (code={result.result})') + + return UpdateProductTask(product_item, result) + + +class ReconstructionQueue: + def __init__(self, data_matcher: DiffractionPatternPositionMatcher) -> None: + self._data_matcher = data_matcher + self._input_queue: queue.Queue[ExecuteReconstructorTask] = queue.Queue() + self._output_queue: queue.Queue[UpdateProductTask] = queue.Queue() + self._stop_work_event = threading.Event() + self._worker = threading.Thread(target=self._reconstruct) + + @property + def is_reconstructing(self) -> bool: + return self._input_queue.unfinished_tasks > 0 + + def _reconstruct(self) -> None: + while not self._stop_work_event.is_set(): + try: + input_task = self._input_queue.get(block=True, timeout=1) + + try: + output_task = input_task.execute() + except Exception: + logger.exception('Reconstructor error!') + else: + self._output_queue.put(output_task) + finally: + self._input_queue.task_done() + except queue.Empty: + pass + + def put( + self, + reconstructor: Reconstructor, + product_index: int, + index_filter: ScanIndexFilter = ScanIndexFilter.ALL, + ) -> None: + task = ExecuteReconstructorTask( + self._data_matcher, reconstructor, product_index, index_filter + ) + self._input_queue.put(task) + + def process_results(self, *, block: bool) -> None: + while True: + try: + task = self._output_queue.get(block=block) + + try: + task.execute() + finally: + self._output_queue.task_done() + except queue.Empty: + break + + def start(self) -> None: + logger.info('Starting reconstructor...') + self._worker.start() + logger.info('Reconstructor started.') + + def stop(self) -> None: + logger.info('Finishing reconstructions...') + self._input_queue.join() + + logger.info('Stopping reconstructor...') + self._stop_work_event.set() + self._worker.join() + self.process_results(block=False) + logger.info('Reconstructor stopped.') + + def __len__(self) -> int: + return self._input_queue.qsize() diff --git a/src/ptychodus/model/reconstructor/settings.py b/src/ptychodus/model/reconstructor/settings.py index a9bc8fc3..7e8fcacc 100644 --- a/src/ptychodus/model/reconstructor/settings.py +++ b/src/ptychodus/model/reconstructor/settings.py @@ -5,11 +5,11 @@ class ReconstructorSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Reconstructor') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Reconstructor') + self._group.add_observer(self) - self.algorithm = self._settingsGroup.createStringParameter('Algorithm', 'Tike/lstsq_grad') + self.algorithm = self._group.create_string_parameter('Algorithm', 'Tike/lstsq_grad') - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/tike/core.py b/src/ptychodus/model/tike/core.py index cc7c2d79..877cb23b 100644 --- a/src/ptychodus/model/tike/core.py +++ b/src/ptychodus/model/tike/core.py @@ -1,5 +1,6 @@ from __future__ import annotations from collections.abc import Iterator +from importlib.metadata import version import logging from ptychodus.api.reconstructor import ( @@ -21,21 +22,21 @@ class TikeReconstructorLibrary(ReconstructorLibrary): - def __init__(self, settingsRegistry: SettingsRegistry) -> None: + def __init__(self, settings_registry: SettingsRegistry) -> None: super().__init__() - self.settings = TikeSettings(settingsRegistry) - self.multigridSettings = TikeMultigridSettings(settingsRegistry) - self.positionCorrectionSettings = TikePositionCorrectionSettings(settingsRegistry) - self.probeCorrectionSettings = TikeProbeCorrectionSettings(settingsRegistry) - self.objectCorrectionSettings = TikeObjectCorrectionSettings(settingsRegistry) + self.settings = TikeSettings(settings_registry) + self.multigrid_settings = TikeMultigridSettings(settings_registry) + self.position_correction_settings = TikePositionCorrectionSettings(settings_registry) + self.probe_correction_settings = TikeProbeCorrectionSettings(settings_registry) + self.object_correction_settings = TikeObjectCorrectionSettings(settings_registry) - self.reconstructorList: list[Reconstructor] = list() + self.reconstructor_list: list[Reconstructor] = list() @classmethod - def createInstance( - cls, settingsRegistry: SettingsRegistry, isDeveloperModeEnabled: bool + def create_instance( + cls, settings_registry: SettingsRegistry, is_developer_mode_enabled: bool ) -> TikeReconstructorLibrary: - core = cls(settingsRegistry) + core = cls(settings_registry) try: from .reconstructor import IterativeLeastSquaresReconstructor @@ -44,19 +45,22 @@ def createInstance( except ModuleNotFoundError: logger.info('Tike not found.') - if isDeveloperModeEnabled: - core.reconstructorList.append(NullReconstructor('rpie')) - core.reconstructorList.append(NullReconstructor('lstsq_grad')) + if is_developer_mode_enabled: + core.reconstructor_list.append(NullReconstructor('rpie')) + core.reconstructor_list.append(NullReconstructor('lstsq_grad')) else: - tikeReconstructor = TikeReconstructor( + tike_version = version('tike') + logger.info(f'Tike {tike_version}') + + tike_reconstructor = TikeReconstructor( core.settings, - core.multigridSettings, - core.positionCorrectionSettings, - core.probeCorrectionSettings, - core.objectCorrectionSettings, + core.multigrid_settings, + core.position_correction_settings, + core.probe_correction_settings, + core.object_correction_settings, ) - core.reconstructorList.append(RegularizedPIEReconstructor(tikeReconstructor)) - core.reconstructorList.append(IterativeLeastSquaresReconstructor(tikeReconstructor)) + core.reconstructor_list.append(RegularizedPIEReconstructor(tike_reconstructor)) + core.reconstructor_list.append(IterativeLeastSquaresReconstructor(tike_reconstructor)) return core @@ -64,5 +68,9 @@ def createInstance( def name(self) -> str: return 'Tike' + @property + def logger_name(self) -> str: + return 'tike' + def __iter__(self) -> Iterator[Reconstructor]: - return iter(self.reconstructorList) + return iter(self.reconstructor_list) diff --git a/src/ptychodus/model/tike/reconstructor.py b/src/ptychodus/model/tike/reconstructor.py index c97469b2..4aee73c5 100644 --- a/src/ptychodus/model/tike/reconstructor.py +++ b/src/ptychodus/model/tike/reconstructor.py @@ -1,4 +1,3 @@ -from importlib.metadata import version from typing import Any import logging import pprint @@ -9,14 +8,14 @@ import tike.ptycho from ptychodus.api.object import Object, ObjectPoint -from ptychodus.api.probe import Probe +from ptychodus.api.probe import ProbeSequence from ptychodus.api.product import Product from ptychodus.api.reconstructor import ( Reconstructor, ReconstructInput, ReconstructOutput, ) -from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.scan import PositionSequence, ScanPoint from .settings import ( TikeMultigridSettings, @@ -33,254 +32,253 @@ class TikeReconstructor: def __init__( self, settings: TikeSettings, - multigridSettings: TikeMultigridSettings, - positionCorrectionSettings: TikePositionCorrectionSettings, - probeCorrectionSettings: TikeProbeCorrectionSettings, - objectCorrectionSettings: TikeObjectCorrectionSettings, + multigrid_settings: TikeMultigridSettings, + position_correction_settings: TikePositionCorrectionSettings, + probe_correction_settings: TikeProbeCorrectionSettings, + object_correction_settings: TikeObjectCorrectionSettings, ) -> None: self._settings = settings - self._multigridSettings = multigridSettings - self._positionCorrectionSettings = positionCorrectionSettings - self._probeCorrectionSettings = probeCorrectionSettings - self._objectCorrectionSettings = objectCorrectionSettings + self._multigrid_settings = multigrid_settings + self._position_correction_settings = position_correction_settings + self._probe_correction_settings = probe_correction_settings + self._object_correction_settings = object_correction_settings - tikeVersion = version('tike') - logger.info(f'\tTike {tikeVersion}') - - def getObjectOptions(self) -> tike.ptycho.ObjectOptions: - settings = self._objectCorrectionSettings + def get_object_options(self) -> tike.ptycho.ObjectOptions: + settings = self._object_correction_settings options = None - if settings.useObjectCorrection.getValue(): + if settings.use_object_correction.get_value(): options = tike.ptycho.ObjectOptions( - positivity_constraint=float(settings.positivityConstraint.getValue()), - smoothness_constraint=float(settings.smoothnessConstraint.getValue()), - use_adaptive_moment=settings.useAdaptiveMoment.getValue(), - vdecay=float(settings.vdecay.getValue()), - mdecay=float(settings.mdecay.getValue()), - clip_magnitude=settings.useMagnitudeClipping.getValue(), + positivity_constraint=float(settings.positivity_constraint.get_value()), + smoothness_constraint=float(settings.smoothness_constraint.get_value()), + use_adaptive_moment=settings.use_adaptive_moment.get_value(), + vdecay=float(settings.vdecay.get_value()), + mdecay=float(settings.mdecay.get_value()), + clip_magnitude=settings.use_magnitude_clipping.get_value(), ) return options - def getPositionOptions( - self, initialScan: numpy.typing.NDArray[Any] + def get_position_options( + self, initial_scan: numpy.typing.NDArray[Any] ) -> tike.ptycho.PositionOptions: - settings = self._positionCorrectionSettings + settings = self._position_correction_settings options = None - if settings.usePositionCorrection.getValue(): + if settings.use_position_correction.get_value(): options = tike.ptycho.PositionOptions( - initial_scan=initialScan, - use_adaptive_moment=settings.useAdaptiveMoment.getValue(), - vdecay=float(settings.vdecay.getValue()), - mdecay=float(settings.mdecay.getValue()), - use_position_regularization=settings.usePositionRegularization.getValue(), - update_magnitude_limit=float(settings.updateMagnitudeLimit.getValue()), + initial_scan=initial_scan, + use_adaptive_moment=settings.use_adaptive_moment.get_value(), + vdecay=float(settings.vdecay.get_value()), + mdecay=float(settings.mdecay.get_value()), + use_position_regularization=settings.use_position_regularization.get_value(), + update_magnitude_limit=float(settings.update_magnitude_limit.get_value()), ) return options - def getProbeOptions(self) -> tike.ptycho.ProbeOptions: - settings = self._probeCorrectionSettings + def get_probe_options(self) -> tike.ptycho.ProbeOptions: + settings = self._probe_correction_settings options = None - if settings.useProbeCorrection.getValue(): - probeSupport = ( - float(settings.probeSupportWeight.getValue()) - if settings.useFiniteProbeSupport.getValue() + if settings.use_probe_correction.get_value(): + probe_support = ( + float(settings.probe_support_weight.get_value()) + if settings.use_finite_probe_support.get_value() else 0.0 ) options = tike.ptycho.ProbeOptions( - force_orthogonality=settings.forceOrthogonality.getValue(), - force_centered_intensity=settings.forceCenteredIntensity.getValue(), - force_sparsity=float(settings.forceSparsity.getValue()), - use_adaptive_moment=settings.useAdaptiveMoment.getValue(), - vdecay=float(settings.vdecay.getValue()), - mdecay=float(settings.mdecay.getValue()), - probe_support=probeSupport, - probe_support_radius=float(settings.probeSupportRadius.getValue()), - probe_support_degree=float(settings.probeSupportDegree.getValue()), - additional_probe_penalty=float(settings.additionalProbePenalty.getValue()), + force_orthogonality=settings.force_orthogonality.get_value(), + force_centered_intensity=settings.force_centered_intensity.get_value(), + force_sparsity=float(settings.force_sparsity.get_value()), + use_adaptive_moment=settings.use_adaptive_moment.get_value(), + vdecay=float(settings.vdecay.get_value()), + mdecay=float(settings.mdecay.get_value()), + probe_support=probe_support, + probe_support_radius=float(settings.probe_support_radius.get_value()), + probe_support_degree=float(settings.probe_support_degree.get_value()), + additional_probe_penalty=float(settings.additional_probe_penalty.get_value()), ) return options - def getNumGpus(self) -> int | tuple[int, ...]: - numGpus = self._settings.numGpus.getValue() - onlyDigitsAndCommas = all(c.isdigit() or c == ',' for c in numGpus) - hasDigit = any(c.isdigit() for c in numGpus) + def get_num_gpus(self) -> int | tuple[int, ...]: + num_gpus = self._settings.num_gpus.get_value() + only_digits_and_commas = all(c.isdigit() or c == ',' for c in num_gpus) + has_digit = any(c.isdigit() for c in num_gpus) - if onlyDigitsAndCommas and hasDigit: - if ',' in numGpus: - return tuple(int(n) for n in numGpus.split(',') if n) + if only_digits_and_commas and has_digit: + if ',' in num_gpus: + return tuple(int(n) for n in num_gpus.split(',') if n) else: - return int(numGpus) + return int(num_gpus) return 1 def __call__( self, parameters: ReconstructInput, - algorithmOptions: tike.ptycho.solvers.IterativeOptions, + algorithm_options: tike.ptycho.solvers.IterativeOptions, ) -> ReconstructOutput: - patternsArray = numpy.fft.ifftshift(parameters.patterns, axes=(-2, -1)) + patterns_array = numpy.fft.ifftshift(parameters.patterns, axes=(-2, -1)) + + object_input = parameters.product.object_ + object_geometry = object_input.get_geometry() + object_input_array = object_input.get_array().astype('complex64') + num_layers = object_input.num_layers - objectInput = parameters.product.object_ - objectGeometry = objectInput.getGeometry() - # TODO change array[0] -> array when multislice is available - objectInputArray = objectInput.array[0].astype('complex64') + if num_layers == 1: + object_input_array = object_input_array[0] + else: + raise ValueError(f'Tike does not support multislice (layers={num_layers})!') - probeInput = parameters.product.probe - probeInputArray = probeInput.array[numpy.newaxis, numpy.newaxis, ...].astype('complex64') + probe_input = parameters.product.probes + probe_input_array = probe_input.get_array().astype('complex64') - scanInput = parameters.product.scan - scanInputCoords: list[float] = list() + scan_input = parameters.product.positions + scan_input_coords: list[float] = list() # Tike coordinate system origin is top-left corner; requires padding - ux = -probeInputArray.shape[-1] / 2 - uy = -probeInputArray.shape[-2] / 2 + ux = -probe_input_array.shape[-1] / 2 + uy = -probe_input_array.shape[-2] / 2 - for scanPoint in scanInput: - objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) - scanInputCoords.append(objectPoint.positionYInPixels + uy) - scanInputCoords.append(objectPoint.positionXInPixels + ux) + for scan_point in scan_input: + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + scan_input_coords.append(object_point.position_y_px + uy) + scan_input_coords.append(object_point.position_x_px + ux) - scanInputArray = numpy.array( - scanInputCoords, + scan_input_array = numpy.array( + scan_input_coords, dtype=numpy.float32, - ).reshape(len(scanInput), 2) - scanMin = scanInputArray.min(axis=0) - scanMax = scanInputArray.max(axis=0) - logger.debug(f'Scan range [px]: {scanMin} -> {scanMax}') - numGpus = self.getNumGpus() - - logger.debug(f'data shape={patternsArray.shape}') - logger.debug(f'scan shape={scanInputArray.shape}') - logger.debug(f'probe shape={probeInputArray.shape}') - logger.debug(f'object shape={objectInputArray.shape}') - logger.debug(f'num_gpu={numGpus}') + ).reshape(len(scan_input), 2) + scan_min = scan_input_array.min(axis=0) + scan_max = scan_input_array.max(axis=0) + logger.debug(f'Scan range [px]: {scan_min} -> {scan_max}') + num_gpus = self.get_num_gpus() + + logger.debug(f'data shape={patterns_array.shape}') + logger.debug(f'scan shape={scan_input_array.shape}') + logger.debug(f'probe shape={probe_input_array.shape}') + logger.debug(f'object shape={object_input_array.shape}') + logger.debug(f'num_gpu={num_gpus}') exitwave_options = tike.ptycho.ExitWaveOptions( - # TODO: Use a user supplied `measured_pixels` instead - measured_pixels=numpy.ones(probeInputArray.shape[-2:], dtype=numpy.bool_), - noise_model=self._settings.noiseModel.getValue(), + measured_pixels=numpy.logical_not(parameters.bad_pixels), + noise_model=self._settings.noise_model.get_value(), ) - ptychoParameters = tike.ptycho.solvers.PtychoParameters( - probe=probeInputArray, - psi=objectInputArray, - scan=scanInputArray, - algorithm_options=algorithmOptions, - probe_options=self.getProbeOptions(), - object_options=self.getObjectOptions(), - position_options=self.getPositionOptions(scanInputArray), + ptycho_parameters = tike.ptycho.solvers.PtychoParameters( + probe=probe_input_array, + psi=object_input_array, + scan=scan_input_array, + algorithm_options=algorithm_options, + probe_options=self.get_probe_options(), + object_options=self.get_object_options(), + position_options=self.get_position_options(scan_input_array), exitwave_options=exitwave_options, ) - if self._multigridSettings.useMultigrid.getValue(): + if self._multigrid_settings.use_multigrid.get_value(): result = tike.ptycho.reconstruct_multigrid( - data=patternsArray, - parameters=ptychoParameters, - num_gpu=numGpus, + data=patterns_array, + parameters=ptycho_parameters, + num_gpu=num_gpus, use_mpi=False, - num_levels=self._multigridSettings.numLevels.getValue(), + num_levels=self._multigrid_settings.num_levels.get_value(), interp=None, # TODO does this have other options? ) else: # TODO support interactive reconstructions with tike.ptycho.Reconstruction( - data=patternsArray, - parameters=ptychoParameters, - num_gpu=numGpus, + data=patterns_array, + parameters=ptycho_parameters, + num_gpu=num_gpus, use_mpi=False, ) as context: - context.iterate(ptychoParameters.algorithm_options.num_iter) + context.iterate(ptycho_parameters.algorithm_options.num_iter) result = context.parameters logger.debug(f'Result: {pprint.pformat(result)}') - scanOutputPoints: list[ScanPoint] = list() + scan_output_points: list[ScanPoint] = list() - for uncorrectedPoint, xy in zip(scanInput, result.scan): - objectPoint = ObjectPoint(uncorrectedPoint.index, xy[1] - ux, xy[0] - uy) - scanPoint = objectGeometry.mapObjectPointToScanPoint(objectPoint) - scanOutputPoints.append(scanPoint) + for uncorrected_point, xy in zip(scan_input, result.scan): + object_point = ObjectPoint(uncorrected_point.index, xy[1] - ux, xy[0] - uy) + scan_point = object_geometry.map_object_point_to_scan_point(object_point) + scan_output_points.append(scan_point) - scanOutput = Scan(scanOutputPoints) + scan_output = PositionSequence(scan_output_points) - if self._probeCorrectionSettings.useProbeCorrection.getValue(): - probeOutput = Probe( - array=result.probe[0, 0], - pixelWidthInMeters=probeInput.pixelWidthInMeters, - pixelHeightInMeters=probeInput.pixelHeightInMeters, + if self._probe_correction_settings.use_probe_correction.get_value(): + probe_output = ProbeSequence( + array=result.probe, + opr_weights=None, + pixel_geometry=probe_input.get_pixel_geometry(), ) else: - probeOutput = probeInput.copy() + probe_output = probe_input.copy() - if self._objectCorrectionSettings.useObjectCorrection.getValue(): - objectOutput = Object( + if self._object_correction_settings.use_object_correction.get_value(): + object_output = Object( array=result.psi, - layerDistanceInMeters=objectInput.layerDistanceInMeters, - pixelWidthInMeters=objectInput.pixelWidthInMeters, - pixelHeightInMeters=objectInput.pixelHeightInMeters, - centerXInMeters=objectInput.centerXInMeters, - centerYInMeters=objectInput.centerYInMeters, + layer_spacing_m=object_input.layer_spacing_m, + pixel_geometry=object_input.get_pixel_geometry(), + center=object_input.get_center(), ) else: - objectOutput = objectInput.copy() + object_output = object_input.copy() product = Product( metadata=parameters.product.metadata, - scan=scanOutput, - probe=probeOutput, - object_=objectOutput, + positions=scan_output, + probes=probe_output, + object_=object_output, costs=[float(numpy.mean(values)) for values in result.algorithm_options.costs], ) return ReconstructOutput(product, 0) class RegularizedPIEReconstructor(Reconstructor): - def __init__(self, tikeReconstructor: TikeReconstructor) -> None: + def __init__(self, tike_reconstructor: TikeReconstructor) -> None: super().__init__() - self._algorithmOptions = tike.ptycho.solvers.RpieOptions() - self._tikeReconstructor = tikeReconstructor + self._algorithm_options = tike.ptycho.solvers.RpieOptions() + self._tike_reconstructor = tike_reconstructor @property def name(self) -> str: - return self._algorithmOptions.name + return self._algorithm_options.name @property def _settings(self) -> TikeSettings: - return self._tikeReconstructor._settings + return self._tike_reconstructor._settings def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: - self._algorithmOptions.num_batch = self._settings.numBatch.getValue() - self._algorithmOptions.batch_method = self._settings.batchMethod.getValue() - self._algorithmOptions.num_iter = self._settings.numIter.getValue() - self._algorithmOptions.convergence_window = self._settings.convergenceWindow.getValue() - self._algorithmOptions.alpha = float(self._settings.alpha.getValue()) - return self._tikeReconstructor(parameters, self._algorithmOptions) + self._algorithm_options.num_batch = self._settings.num_batch.get_value() + self._algorithm_options.batch_method = self._settings.batch_method.get_value() + self._algorithm_options.num_iter = self._settings.num_iter.get_value() + self._algorithm_options.convergence_window = self._settings.convergence_window.get_value() + self._algorithm_options.alpha = float(self._settings.alpha.get_value()) + return self._tike_reconstructor(parameters, self._algorithm_options) class IterativeLeastSquaresReconstructor(Reconstructor): - def __init__(self, tikeReconstructor: TikeReconstructor) -> None: + def __init__(self, tike_reconstructor: TikeReconstructor) -> None: super().__init__() - self._algorithmOptions = tike.ptycho.solvers.LstsqOptions() - self._tikeReconstructor = tikeReconstructor + self._algorithm_options = tike.ptycho.solvers.LstsqOptions() + self._tike_reconstructor = tike_reconstructor @property def name(self) -> str: - return self._algorithmOptions.name + return self._algorithm_options.name @property def _settings(self) -> TikeSettings: - return self._tikeReconstructor._settings + return self._tike_reconstructor._settings def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: - self._algorithmOptions.num_batch = self._settings.numBatch.getValue() - self._algorithmOptions.batch_method = self._settings.batchMethod.getValue() - self._algorithmOptions.num_iter = self._settings.numIter.getValue() - self._algorithmOptions.convergence_window = self._settings.convergenceWindow.getValue() - return self._tikeReconstructor(parameters, self._algorithmOptions) + self._algorithm_options.num_batch = self._settings.num_batch.get_value() + self._algorithm_options.batch_method = self._settings.batch_method.get_value() + self._algorithm_options.num_iter = self._settings.num_iter.get_value() + self._algorithm_options.convergence_window = self._settings.convergence_window.get_value() + return self._tike_reconstructor(parameters, self._algorithm_options) diff --git a/src/ptychodus/model/tike/settings.py b/src/ptychodus/model/tike/settings.py index 0b410dbe..b9d115e1 100644 --- a/src/ptychodus/model/tike/settings.py +++ b/src/ptychodus/model/tike/settings.py @@ -10,39 +10,37 @@ class TikeSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Tike') - self._settingsGroup.addObserver(self) - - self.numGpus = self._settingsGroup.createStringParameter('NumGpus', '1') - self.noiseModel = self._settingsGroup.createStringParameter('NoiseModel', 'gaussian') - self.numBatch = self._settingsGroup.createIntegerParameter('NumBatch', 10, minimum=1) - self.batchMethod = self._settingsGroup.createStringParameter('BatchMethod', 'wobbly_center') - self.numIter = self._settingsGroup.createIntegerParameter('NumIter', 1, minimum=1) - self.convergenceWindow = self._settingsGroup.createIntegerParameter( + self._group = registry.create_group('Tike') + self._group.add_observer(self) + + self.num_gpus = self._group.create_string_parameter('NumGpus', '1') + self.noise_model = self._group.create_string_parameter('NoiseModel', 'gaussian') + self.num_batch = self._group.create_integer_parameter('NumBatch', 10, minimum=1) + self.batch_method = self._group.create_string_parameter('BatchMethod', 'wobbly_center') + self.num_iter = self._group.create_integer_parameter('NumIter', 1, minimum=1) + self.convergence_window = self._group.create_integer_parameter( 'ConvergenceWindow', 0, minimum=0 ) - self.alpha = self._settingsGroup.createRealParameter( - 'Alpha', 0.05, minimum=0.0, maximum=1.0 - ) + self.alpha = self._group.create_real_parameter('Alpha', 0.05, minimum=0.0, maximum=1.0) self._logger = logging.getLogger('tike') - def getNoiseModels(self) -> Sequence[str]: + def get_noise_models(self) -> Sequence[str]: return ['poisson', 'gaussian'] - def getBatchMethods(self) -> Sequence[str]: + def get_batch_methods(self) -> Sequence[str]: return ['wobbly_center', 'wobbly_center_random_bootstrap', 'compact'] - def getLogLevels(self) -> Sequence[str]: + def get_log_levels(self) -> Sequence[str]: return ['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'] - def getLogLevel(self) -> str: + def get_log_level(self) -> str: level = self._logger.getEffectiveLevel() return logging.getLevelName(level) - def setLogLevel(self, name: str) -> None: - nameBefore = self.getLogLevel() + def set_log_level(self, name: str) -> None: + name_before = self.get_log_level() - if name == nameBefore: + if name == name_before: return try: @@ -50,143 +48,121 @@ def setLogLevel(self, name: str) -> None: except ValueError: logger.error(f'Bad log level "{name}".') - nameAfter = self.getLogLevel() - logger.info(f'Changed Tike log level {nameBefore} -> {nameAfter}') - self.notifyObservers() + name_after = self.get_log_level() + logger.info(f'Changed Tike log level {name_before} -> {name_after}') + self.notify_observers() - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() class TikeMultigridSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('TikeMultigrid') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('TikeMultigrid') + self._group.add_observer(self) - self.useMultigrid = self._settingsGroup.createBooleanParameter('UseMultigrid', False) - self.numLevels = self._settingsGroup.createIntegerParameter('NumLevels', 3, minimum=1) + self.use_multigrid = self._group.create_boolean_parameter('UseMultigrid', False) + self.num_levels = self._group.create_integer_parameter('NumLevels', 3, minimum=1) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() class TikeObjectCorrectionSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('TikeObjectCorrection') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('TikeObjectCorrection') + self._group.add_observer(self) - self.useObjectCorrection = self._settingsGroup.createBooleanParameter( + self.use_object_correction = self._group.create_boolean_parameter( 'UseObjectCorrection', True ) - self.positivityConstraint = self._settingsGroup.createRealParameter( + self.positivity_constraint = self._group.create_real_parameter( 'PositivityConstraint', 0.0, minimum=0.0, maximum=1.0 ) - self.smoothnessConstraint = self._settingsGroup.createRealParameter( + self.smoothness_constraint = self._group.create_real_parameter( 'SmoothnessConstraint', 0.0, minimum=0.0, maximum=1.0 / 8, ) - self.useMagnitudeClipping = self._settingsGroup.createBooleanParameter( + self.use_magnitude_clipping = self._group.create_boolean_parameter( 'UseMagnitudeClipping', False ) - self.useAdaptiveMoment = self._settingsGroup.createBooleanParameter( - 'UseAdaptiveMoment', False - ) - self.mdecay = self._settingsGroup.createRealParameter( - 'MDecay', 0.9, minimum=0.0, maximum=1.0 - ) - self.vdecay = self._settingsGroup.createRealParameter( - 'VDecay', 0.999, minimum=0.0, maximum=1.0 - ) + self.use_adaptive_moment = self._group.create_boolean_parameter('UseAdaptiveMoment', False) + self.mdecay = self._group.create_real_parameter('MDecay', 0.9, minimum=0.0, maximum=1.0) + self.vdecay = self._group.create_real_parameter('VDecay', 0.999, minimum=0.0, maximum=1.0) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() class TikeProbeCorrectionSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('TikeProbeCorrection') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('TikeProbeCorrection') + self._group.add_observer(self) - self.useProbeCorrection = self._settingsGroup.createBooleanParameter( - 'UseProbeCorrection', True - ) - self.forceOrthogonality = self._settingsGroup.createBooleanParameter( - 'ForceOrthogonality', False - ) - self.forceCenteredIntensity = self._settingsGroup.createBooleanParameter( + self.use_probe_correction = self._group.create_boolean_parameter('UseProbeCorrection', True) + self.force_orthogonality = self._group.create_boolean_parameter('ForceOrthogonality', False) + self.force_centered_intensity = self._group.create_boolean_parameter( 'ForceCenteredIntensity', False ) - self.forceSparsity = self._settingsGroup.createRealParameter( + self.force_sparsity = self._group.create_real_parameter( 'ForceSparsity', 0.0, minimum=0.0, maximum=1.0 ) - self.useFiniteProbeSupport = self._settingsGroup.createBooleanParameter( + self.use_finite_probe_support = self._group.create_boolean_parameter( 'UseFiniteProbeSupport', False ) - self.probeSupportWeight = self._settingsGroup.createRealParameter( + self.probe_support_weight = self._group.create_real_parameter( 'ProbeSupportWeight', 10, minimum=0.0 ) - self.probeSupportRadius = self._settingsGroup.createRealParameter( + self.probe_support_radius = self._group.create_real_parameter( 'ProbeSupportRadius', 0.35, minimum=0.0, maximum=0.5 ) - self.probeSupportDegree = self._settingsGroup.createRealParameter( + self.probe_support_degree = self._group.create_real_parameter( 'ProbeSupportDegree', 2.5, minimum=0.0 ) - self.additionalProbePenalty = self._settingsGroup.createRealParameter( + self.additional_probe_penalty = self._group.create_real_parameter( 'AdditionalProbePenalty', 0.0, minimum=0.0 ) - self.useAdaptiveMoment = self._settingsGroup.createBooleanParameter( - 'UseAdaptiveMoment', False - ) - self.mdecay = self._settingsGroup.createRealParameter( - 'MDecay', 0.9, minimum=0.0, maximum=1.0 - ) - self.vdecay = self._settingsGroup.createRealParameter( - 'VDecay', 0.999, minimum=0.0, maximum=1.0 - ) + self.use_adaptive_moment = self._group.create_boolean_parameter('UseAdaptiveMoment', False) + self.mdecay = self._group.create_real_parameter('MDecay', 0.9, minimum=0.0, maximum=1.0) + self.vdecay = self._group.create_real_parameter('VDecay', 0.999, minimum=0.0, maximum=1.0) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() class TikePositionCorrectionSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('TikePositionCorrection') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('TikePositionCorrection') + self._group.add_observer(self) - self.usePositionCorrection = self._settingsGroup.createBooleanParameter( + self.use_position_correction = self._group.create_boolean_parameter( 'UsePositionCorrection', False ) - self.usePositionRegularization = self._settingsGroup.createBooleanParameter( + self.use_position_regularization = self._group.create_boolean_parameter( 'UsePositionRegularization', False ) - self.updateMagnitudeLimit = self._settingsGroup.createRealParameter( + self.update_magnitude_limit = self._group.create_real_parameter( 'UpdateMagnitudeLimit', 0.0, minimum=0.0 ) # TODO transform: Global transform of positions. # TODO origin: The rotation center of the transformation. - self.useAdaptiveMoment = self._settingsGroup.createBooleanParameter( - 'UseAdaptiveMoment', False - ) - self.mdecay = self._settingsGroup.createRealParameter( - 'MDecay', 0.9, minimum=0.0, maximum=1.0 - ) - self.vdecay = self._settingsGroup.createRealParameter( - 'VDecay', 0.999, minimum=0.0, maximum=1.0 - ) + self.use_adaptive_moment = self._group.create_boolean_parameter('UseAdaptiveMoment', False) + self.mdecay = self._group.create_real_parameter('MDecay', 0.9, minimum=0.0, maximum=1.0) + self.vdecay = self._group.create_real_parameter('VDecay', 0.999, minimum=0.0, maximum=1.0) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/visualization/colorModelRenderer.py b/src/ptychodus/model/visualization/colorModelRenderer.py deleted file mode 100644 index a72260bc..00000000 --- a/src/ptychodus/model/visualization/colorModelRenderer.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations -from collections.abc import Iterator -from matplotlib.colors import Normalize -import numpy - -from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.visualization import ( - NumberArrayType, - RealArrayType, - VisualizationProduct, -) - -from .colorAxis import ColorAxis -from .colorModel import CylindricalColorModelParameter -from .components import AmplitudeArrayComponent, PhaseInRadiansArrayComponent -from .renderer import Renderer -from .transformation import ScalarTransformationParameter - - -class CylindricalColorModelRenderer(Renderer): - def __init__( - self, - amplitudeComponent: AmplitudeArrayComponent, - phaseComponent: PhaseInRadiansArrayComponent, - transformation: ScalarTransformationParameter, - colorAxis: ColorAxis, - ) -> None: - super().__init__('Complex') - self._amplitudeComponent = amplitudeComponent - self._phaseComponent = phaseComponent - self._transformation = transformation - self._addParameter('transformation', transformation) - self._colorAxis = colorAxis - self._addGroup('color_axis', colorAxis, observe=True) - self._colorModel = CylindricalColorModelParameter() - self._addParameter('color_model', self._colorModel) - - def variants(self) -> Iterator[str]: - return self._colorModel.choices() - - def getVariant(self) -> str: - return self._colorModel.getValue() - - def setVariant(self, variant: str) -> None: - self._colorModel.setValue(variant) - - def isCyclic(self) -> bool: - return True - - def _colorize( - self, amplitudeTransformed: RealArrayType, phaseInRadians: RealArrayType - ) -> RealArrayType: - vrange = self._colorAxis.getRange() - norm = Normalize(vmin=vrange.lower, vmax=vrange.upper, clip=False) - - model = numpy.vectorize(self._colorModel.getPlugin()) - h = (phaseInRadians + numpy.pi) / (2 * numpy.pi) - r, g, b, a = model(h, norm(amplitudeTransformed)) - return numpy.stack((r, g, b, a), axis=-1) - - def colorize(self, array: NumberArrayType) -> RealArrayType: - amplitude = self._amplitudeComponent.calculate(array) - amplitudeTransformed = self._transformation.transform(amplitude) - phaseInRadians = self._phaseComponent.calculate(array) - return self._colorize(amplitudeTransformed, phaseInRadians) - - def render( - self, - array: NumberArrayType, - pixelGeometry: PixelGeometry, - *, - autoscaleColorAxis: bool, - ) -> VisualizationProduct: - amplitude = self._amplitudeComponent.calculate(array) - amplitudeTransformed = self._transformation.transform(amplitude) - phaseInRadians = self._phaseComponent.calculate(array) - - if autoscaleColorAxis: - self._colorAxis.setToDataRange(amplitudeTransformed) - - rgba = self._colorize(amplitudeTransformed, phaseInRadians) - - return VisualizationProduct( - valueLabel=self._transformation.decorateText(self._amplitudeComponent.name), - values=array, - rgba=rgba, - pixelGeometry=pixelGeometry, - ) diff --git a/src/ptychodus/model/visualization/colorAxis.py b/src/ptychodus/model/visualization/color_axis.py similarity index 56% rename from src/ptychodus/model/visualization/colorAxis.py rename to src/ptychodus/model/visualization/color_axis.py index 826f2a7c..698bccec 100644 --- a/src/ptychodus/model/visualization/colorAxis.py +++ b/src/ptychodus/model/visualization/color_axis.py @@ -4,7 +4,7 @@ from ptychodus.api.geometry import Interval from ptychodus.api.parametric import ParameterGroup -from ptychodus.api.visualization import RealArrayType +from ptychodus.api.typing import RealArrayType logger = logging.getLogger(__name__) @@ -12,21 +12,21 @@ class ColorAxis(ParameterGroup): def __init__(self) -> None: super().__init__() - self.lower = self.createRealParameter('lower', 0.0) - self.upper = self.createRealParameter('upper', 1.0) + self.lower = self.create_real_parameter('lower', 0.0) + self.upper = self.create_real_parameter('upper', 1.0) - def getRange(self) -> Interval[float]: - return Interval[float].createProper( - self.lower.getValue(), - self.upper.getValue(), + def get_range(self) -> Interval[float]: + return Interval[float].create_proper( + self.lower.get_value(), + self.upper.get_value(), ) - def setRange(self, lower: float, upper: float): - self.lower.setValue(lower, notify=False) - self.upper.setValue(upper, notify=False) - self.notifyObservers() + def set_range(self, lower: float, upper: float): + self.lower.set_value(lower, notify=False) + self.upper.set_value(upper, notify=False) + self.notify_observers() - def setToDataRange(self, array: RealArrayType) -> None: + def set_to_data_range(self, array: RealArrayType) -> None: if array.size > 0: lower = array.min().item() upper = array.max().item() @@ -41,6 +41,6 @@ def setToDataRange(self, array: RealArrayType) -> None: lower = 0.0 upper = 1.0 - self.setRange(lower, upper) + self.set_range(lower, upper) else: logger.warning('Array has zero size!') diff --git a/src/ptychodus/model/visualization/colorModel.py b/src/ptychodus/model/visualization/color_model.py similarity index 61% rename from src/ptychodus/model/visualization/colorModel.py rename to src/ptychodus/model/visualization/color_model.py index 58bb5c16..df83de66 100644 --- a/src/ptychodus/model/visualization/colorModel.py +++ b/src/ptychodus/model/visualization/color_model.py @@ -52,63 +52,63 @@ class CylindricalColorModelParameter(Parameter[str], Observer): def __init__(self) -> None: super().__init__() self._chooser = PluginChooser[CylindricalColorModel]() - self._chooser.registerPlugin( + self._chooser.register_plugin( HSVSaturationColorModel(), - simpleName='HSV-S', - displayName='HSV Saturation', + simple_name='HSV-S', + display_name='HSV Saturation', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( HSVValueColorModel(), - simpleName='HSV-V', - displayName='HSV Value', + simple_name='HSV-V', + display_name='HSV Value', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( HSVAlphaColorModel(), - simpleName='HSV-A', - displayName='HSV Alpha', + simple_name='HSV-A', + display_name='HSV Alpha', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( HLSLightnessColorModel(), - simpleName='HLS-L', - displayName='HLS Lightness', + simple_name='HLS-L', + display_name='HLS Lightness', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( HLSSaturationColorModel(), - simpleName='HLS-S', - displayName='HLS Saturation', + simple_name='HLS-S', + display_name='HLS Saturation', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( HLSAlphaColorModel(), - simpleName='HLS-A', - displayName='HLS Alpha', + simple_name='HLS-A', + display_name='HLS Alpha', ) - self.setValue('HSV-V') - self._chooser.addObserver(self) + self.set_value('HSV-V') + self._chooser.add_observer(self) def choices(self) -> Iterator[str]: - for name in self._chooser.getDisplayNameList(): - yield name + for plugin in self._chooser: + yield plugin.display_name - def getValue(self) -> str: - return self._chooser.currentPlugin.displayName + def get_value(self) -> str: + return self._chooser.get_current_plugin().display_name - def setValue(self, value: str, *, notify: bool = True) -> None: - self._chooser.setCurrentPluginByName(value) + def set_value(self, value: str, *, notify: bool = True) -> None: + self._chooser.set_current_plugin(value) - def getValueAsString(self) -> str: - return self.getValue() + def get_value_as_string(self) -> str: + return self.get_value() - def setValueFromString(self, value: str) -> None: - self.setValue(value) + def set_value_from_string(self, value: str) -> None: + self.set_value(value) def copy(self) -> Parameter[str]: parameter = CylindricalColorModelParameter() - parameter.setValue(self.getValue()) + parameter.set_value(self.get_value()) return parameter - def getPlugin(self) -> CylindricalColorModel: - return self._chooser.currentPlugin.strategy + def get_plugin(self) -> CylindricalColorModel: + return self._chooser.get_current_plugin().strategy - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._chooser: - self.notifyObservers() + self.notify_observers() diff --git a/src/ptychodus/model/visualization/color_model_renderer.py b/src/ptychodus/model/visualization/color_model_renderer.py new file mode 100644 index 00000000..3c0982af --- /dev/null +++ b/src/ptychodus/model/visualization/color_model_renderer.py @@ -0,0 +1,84 @@ +from __future__ import annotations +from collections.abc import Iterator +from matplotlib.colors import Normalize +import numpy + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.visualization import ( + NumberArrayType, + RealArrayType, + VisualizationProduct, +) + +from .color_axis import ColorAxis +from .color_model import CylindricalColorModelParameter +from .components import AmplitudeArrayComponent, PhaseInRadiansArrayComponent +from .renderer import Renderer +from .transformation import ScalarTransformationParameter + + +class CylindricalColorModelRenderer(Renderer): + def __init__( + self, + amplitude_component: AmplitudeArrayComponent, + phase_component: PhaseInRadiansArrayComponent, + transformation: ScalarTransformationParameter, + color_axis: ColorAxis, + ) -> None: + super().__init__('Complex') + self._amplitude_component = amplitude_component + self._phase_component = phase_component + self._transformation = transformation + self._add_parameter('transformation', transformation) + self._color_axis = color_axis + self._add_group('color_axis', color_axis, observe=True) + self._color_model = CylindricalColorModelParameter() + self._add_parameter('color_model', self._color_model) + + def variants(self) -> Iterator[str]: + return self._color_model.choices() + + def get_variant(self) -> str: + return self._color_model.get_value() + + def set_variant(self, variant: str) -> None: + self._color_model.set_value(variant) + + def is_cyclic(self) -> bool: + return True + + def _colorize( + self, amplitude_transformed: RealArrayType, phase_rad: RealArrayType + ) -> RealArrayType: + vrange = self._color_axis.get_range() + norm = Normalize(vmin=vrange.lower, vmax=vrange.upper, clip=False) + + model = numpy.vectorize(self._color_model.get_plugin()) + h = (phase_rad + numpy.pi) / (2 * numpy.pi) + r, g, b, a = model(h, norm(amplitude_transformed)) + return numpy.stack((r, g, b, a), axis=-1) + + def colorize(self, array: NumberArrayType) -> RealArrayType: + amplitude = self._amplitude_component.calculate(array) + amplitude_transformed = self._transformation.transform(amplitude) + phase_rad = self._phase_component.calculate(array) + return self._colorize(amplitude_transformed, phase_rad) + + def render( + self, array: NumberArrayType, pixel_geometry: PixelGeometry, *, autoscale_color_axis: bool + ) -> VisualizationProduct: + amplitude = self._amplitude_component.calculate(array) + amplitude_transformed = self._transformation.transform(amplitude) + phase_rad = self._phase_component.calculate(array) + + if autoscale_color_axis: + self._color_axis.set_to_data_range(amplitude_transformed) + + rgba = self._colorize(amplitude_transformed, phase_rad) + + return VisualizationProduct( + value_label=self._transformation.decorate_text(self._amplitude_component.name), + values=array, + rgba=rgba, + pixel_geometry=pixel_geometry, + ) diff --git a/src/ptychodus/model/visualization/colormap.py b/src/ptychodus/model/visualization/colormap.py index 2dd3b71c..ae5d5b36 100644 --- a/src/ptychodus/model/visualization/colormap.py +++ b/src/ptychodus/model/visualization/colormap.py @@ -13,44 +13,44 @@ class ColormapParameter(Parameter[str], Observer): # See https://matplotlib.org/stable/gallery/color/colormap_reference.html CYCLIC_COLORMAPS: Final[tuple[str, ...]] = ('hsv', 'twilight', 'twilight_shifted') - def __init__(self, *, isCyclic: bool) -> None: + def __init__(self, *, is_cyclic: bool) -> None: super().__init__() - self._isCyclic = isCyclic + self._is_cyclic = is_cyclic self._chooser = PluginChooser[Colormap]() for name, cmap in matplotlib.colormaps.items(): - isCyclicColormap = name in ColormapParameter.CYCLIC_COLORMAPS + is_cyclic_colormap = name in ColormapParameter.CYCLIC_COLORMAPS - if isCyclic == isCyclicColormap: - self._chooser.registerPlugin(cmap, displayName=name) + if is_cyclic == is_cyclic_colormap: + self._chooser.register_plugin(cmap, display_name=name) - self.setValue('hsv' if isCyclic else 'gray') - self._chooser.addObserver(self) + self.set_value('hsv' if is_cyclic else 'gray') + self._chooser.add_observer(self) def choices(self) -> Iterator[str]: - for name in self._chooser.getDisplayNameList(): - yield name + for plugin in self._chooser: + yield plugin.display_name - def getValue(self) -> str: - return self._chooser.currentPlugin.displayName + def get_value(self) -> str: + return self._chooser.get_current_plugin().display_name - def setValue(self, value: str, *, notify: bool = True) -> None: - self._chooser.setCurrentPluginByName(value) + def set_value(self, value: str, *, notify: bool = True) -> None: + self._chooser.set_current_plugin(value) - def getValueAsString(self) -> str: - return self.getValue() + def get_value_as_string(self) -> str: + return self.get_value() - def setValueFromString(self, value: str) -> None: - self.setValue(value) + def set_value_from_string(self, value: str) -> None: + self.set_value(value) def copy(self) -> Parameter[str]: - parameter = ColormapParameter(isCyclic=self._isCyclic) - parameter.setValue(self.getValue()) + parameter = ColormapParameter(is_cyclic=self._is_cyclic) + parameter.set_value(self.get_value()) return parameter - def getPlugin(self) -> Colormap: - return self._chooser.currentPlugin.strategy + def get_plugin(self) -> Colormap: + return self._chooser.get_current_plugin().strategy - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._chooser: - self.notifyObservers() + self.notify_observers() diff --git a/src/ptychodus/model/visualization/colormapRenderer.py b/src/ptychodus/model/visualization/colormap_renderer.py similarity index 50% rename from src/ptychodus/model/visualization/colormapRenderer.py rename to src/ptychodus/model/visualization/colormap_renderer.py index 7a9c0ae5..0ceb24bb 100644 --- a/src/ptychodus/model/visualization/colormapRenderer.py +++ b/src/ptychodus/model/visualization/colormap_renderer.py @@ -10,7 +10,7 @@ VisualizationProduct, ) -from .colorAxis import ColorAxis +from .color_axis import ColorAxis from .colormap import ColormapParameter from .components import DataArrayComponent from .renderer import Renderer @@ -22,60 +22,60 @@ def __init__( self, component: DataArrayComponent, transformation: ScalarTransformationParameter, - colorAxis: ColorAxis, + color_axis: ColorAxis, colormap: ColormapParameter, ) -> None: super().__init__(component.name) self._component = component self._transformation = transformation - self._addParameter('transformation', transformation) - self._colorAxis = colorAxis - self._addGroup('color_axis', colorAxis, observe=True) + self._add_parameter('transformation', transformation) + self._color_axis = color_axis + self._add_group('color_axis', color_axis, observe=True) self._colormap = colormap - self._addParameter('colormap', colormap) + self._add_parameter('colormap', colormap) def variants(self) -> Iterator[str]: return self._colormap.choices() - def getVariant(self) -> str: - return self._colormap.getValue() + def get_variant(self) -> str: + return self._colormap.get_value() - def setVariant(self, variant: str) -> None: - self._colormap.setValue(variant) + def set_variant(self, variant: str) -> None: + self._colormap.set_value(variant) - def isCyclic(self) -> bool: - return self._component.isCyclic + def is_cyclic(self) -> bool: + return self._component.is_cyclic - def _colorize(self, valuesTransformed: RealArrayType) -> RealArrayType: - vrange = self._colorAxis.getRange() + def _colorize(self, values_transformed: RealArrayType) -> RealArrayType: + vrange = self._color_axis.get_range() norm = Normalize(vmin=vrange.lower, vmax=vrange.upper, clip=False) - cmap = self._colormap.getPlugin() - scalarMappable = ScalarMappable(norm, cmap) - return scalarMappable.to_rgba(valuesTransformed) + cmap = self._colormap.get_plugin() + scalar_mappable = ScalarMappable(norm, cmap) + return scalar_mappable.to_rgba(values_transformed) def colorize(self, array: NumberArrayType) -> RealArrayType: values = self._component.calculate(array) - valuesTransformed = self._transformation.transform(values) - return self._colorize(valuesTransformed) + values_transformed = self._transformation.transform(values) + return self._colorize(values_transformed) def render( self, array: NumberArrayType, - pixelGeometry: PixelGeometry, + pixel_geometry: PixelGeometry, *, - autoscaleColorAxis: bool, + autoscale_color_axis: bool, ) -> VisualizationProduct: values = self._component.calculate(array) - valuesTransformed = self._transformation.transform(values) + values_transformed = self._transformation.transform(values) - if autoscaleColorAxis: - self._colorAxis.setToDataRange(valuesTransformed) + if autoscale_color_axis: + self._color_axis.set_to_data_range(values_transformed) - rgba = self._colorize(valuesTransformed) + rgba = self._colorize(values_transformed) return VisualizationProduct( - valueLabel=self._transformation.decorateText(self._component.name), + value_label=self._transformation.decorate_text(self._component.name), values=array, rgba=rgba, - pixelGeometry=pixelGeometry, + pixel_geometry=pixel_geometry, ) diff --git a/src/ptychodus/model/visualization/components.py b/src/ptychodus/model/visualization/components.py index c7ce88ed..5280c2a9 100644 --- a/src/ptychodus/model/visualization/components.py +++ b/src/ptychodus/model/visualization/components.py @@ -3,21 +3,21 @@ from skimage.restoration import unwrap_phase import numpy -from ptychodus.api.visualization import NumberArrayType, RealArrayType +from ptychodus.api.typing import NumberArrayType, RealArrayType class DataArrayComponent(ABC): - def __init__(self, name: str, *, isCyclic: bool) -> None: + def __init__(self, name: str, *, is_cyclic: bool) -> None: self._name = name - self._isCyclic = isCyclic + self._is_cyclic = is_cyclic @property def name(self) -> str: return self._name @property - def isCyclic(self) -> bool: - return self._isCyclic + def is_cyclic(self) -> bool: + return self._is_cyclic @abstractmethod def calculate(self, array: NumberArrayType) -> RealArrayType: @@ -26,7 +26,7 @@ def calculate(self, array: NumberArrayType) -> RealArrayType: class RealArrayComponent(DataArrayComponent): def __init__(self) -> None: - super().__init__('real', isCyclic=False) + super().__init__('real', is_cyclic=False) def calculate(self, array: NumberArrayType) -> RealArrayType: return numpy.real(array).astype(numpy.single) @@ -34,7 +34,7 @@ def calculate(self, array: NumberArrayType) -> RealArrayType: class ImaginaryArrayComponent(DataArrayComponent): def __init__(self) -> None: - super().__init__('imaginary', isCyclic=False) + super().__init__('imaginary', is_cyclic=False) def calculate(self, array: NumberArrayType) -> RealArrayType: return numpy.imag(array).astype(numpy.single) @@ -42,7 +42,7 @@ def calculate(self, array: NumberArrayType) -> RealArrayType: class AmplitudeArrayComponent(DataArrayComponent): def __init__(self) -> None: - super().__init__('amplitude', isCyclic=False) + super().__init__('amplitude', is_cyclic=False) def calculate(self, array: NumberArrayType) -> RealArrayType: return numpy.absolute(array).astype(numpy.single) @@ -50,7 +50,7 @@ def calculate(self, array: NumberArrayType) -> RealArrayType: class PhaseInRadiansArrayComponent(DataArrayComponent): def __init__(self) -> None: - super().__init__('phase', isCyclic=True) + super().__init__('phase', is_cyclic=True) def calculate(self, array: NumberArrayType) -> RealArrayType: return numpy.angle(array).astype(numpy.single) # type: ignore @@ -58,8 +58,8 @@ def calculate(self, array: NumberArrayType) -> RealArrayType: class UnwrappedPhaseInRadiansArrayComponent(DataArrayComponent): def __init__(self) -> None: - super().__init__('unwrapped_phase', isCyclic=False) + super().__init__('unwrapped_phase', is_cyclic=False) def calculate(self, array: NumberArrayType) -> RealArrayType: - phaseInRadians = numpy.angle(array).astype(numpy.single) # type: ignore - return unwrap_phase(phaseInRadians) + phase_rad = numpy.angle(array).astype(numpy.single) # type: ignore + return unwrap_phase(phase_rad) diff --git a/src/ptychodus/model/visualization/core.py b/src/ptychodus/model/visualization/core.py index 754706de..7f6edfd6 100644 --- a/src/ptychodus/model/visualization/core.py +++ b/src/ptychodus/model/visualization/core.py @@ -11,10 +11,10 @@ VisualizationProduct, ) -from .colorAxis import ColorAxis -from .colorModelRenderer import CylindricalColorModelRenderer +from .color_axis import ColorAxis +from .color_model_renderer import CylindricalColorModelRenderer from .colormap import ColormapParameter -from .colormapRenderer import ColormapRenderer +from .colormap_renderer import ColormapRenderer from .components import ( AmplitudeArrayComponent, ImaginaryArrayComponent, @@ -30,144 +30,144 @@ class VisualizationEngine(Observable, Observer): - def __init__(self, *, isComplex: bool) -> None: + def __init__(self, *, is_complex: bool) -> None: super().__init__() - self._rendererChooser = PluginChooser[Renderer]() + self._renderer_chooser = PluginChooser[Renderer]() self._transformation = ScalarTransformationParameter() - self._colorAxis = ColorAxis() - acyclicColormap = ColormapParameter(isCyclic=False) + self._color_axis = ColorAxis() + acyclic_colormap = ColormapParameter(is_cyclic=False) - self._rendererChooser.registerPlugin( + self._renderer_chooser.register_plugin( ColormapRenderer( RealArrayComponent(), self._transformation, - self._colorAxis, - acyclicColormap, + self._color_axis, + acyclic_colormap, ), - displayName='Real', + display_name='Real', ) - if isComplex: - amplitudeComponent = AmplitudeArrayComponent() - phaseComponent = PhaseInRadiansArrayComponent() - cyclicColormap = ColormapParameter(isCyclic=True) + if is_complex: + amplitude_component = AmplitudeArrayComponent() + phase_component = PhaseInRadiansArrayComponent() + cyclic_colormap = ColormapParameter(is_cyclic=True) - self._rendererChooser.registerPlugin( + self._renderer_chooser.register_plugin( ColormapRenderer( ImaginaryArrayComponent(), self._transformation, - self._colorAxis, - acyclicColormap, + self._color_axis, + acyclic_colormap, ), - displayName='Imaginary', + display_name='Imaginary', ) - self._rendererChooser.registerPlugin( + self._renderer_chooser.register_plugin( ColormapRenderer( - amplitudeComponent, + amplitude_component, self._transformation, - self._colorAxis, - acyclicColormap, + self._color_axis, + acyclic_colormap, ), - displayName='Amplitude', + display_name='Amplitude', ) - self._rendererChooser.registerPlugin( + self._renderer_chooser.register_plugin( ColormapRenderer( - phaseComponent, + phase_component, self._transformation, - self._colorAxis, - cyclicColormap, + self._color_axis, + cyclic_colormap, ), - displayName='Phase', + display_name='Phase', ) - self._rendererChooser.registerPlugin( + self._renderer_chooser.register_plugin( ColormapRenderer( UnwrappedPhaseInRadiansArrayComponent(), self._transformation, - self._colorAxis, - acyclicColormap, + self._color_axis, + acyclic_colormap, ), - displayName='Phase (Unwrapped)', + display_name='Phase (Unwrapped)', ) - self._rendererChooser.registerPlugin( + self._renderer_chooser.register_plugin( CylindricalColorModelRenderer( - amplitudeComponent, - phaseComponent, + amplitude_component, + phase_component, self._transformation, - self._colorAxis, + self._color_axis, ), - displayName='Complex', + display_name='Complex', ) - self._rendererChooser.setCurrentPluginByName('Complex') + self._renderer_chooser.set_current_plugin('Complex') - self._rendererChooser.addObserver(self) - self._rendererPlugin = self._rendererChooser.currentPlugin - self._rendererPlugin.strategy.addObserver(self) + self._renderer_chooser.add_observer(self) + self._renderer_plugin = self._renderer_chooser.get_current_plugin() + self._renderer_plugin.strategy.add_observer(self) def renderers(self) -> Iterator[str]: - for plugin in self._rendererChooser: - yield plugin.displayName + for plugin in self._renderer_chooser: + yield plugin.display_name - def getRenderer(self) -> str: - return self._rendererPlugin.displayName + def get_renderer(self) -> str: + return self._renderer_plugin.display_name - def setRenderer(self, value: str) -> None: - self._rendererChooser.setCurrentPluginByName(value) + def set_renderer(self, value: str) -> None: + self._renderer_chooser.set_current_plugin(value) - def isRendererCyclic(self) -> bool: - return self._rendererPlugin.strategy.isCyclic() + def is_renderer_cyclic(self) -> bool: + return self._renderer_plugin.strategy.is_cyclic() def transformations(self) -> Iterator[str]: return self._transformation.choices() - def getTransformation(self) -> str: - return self._transformation.getValue() + def get_transformation(self) -> str: + return self._transformation.get_value() - def setTransformation(self, value: str) -> None: - self._transformation.setValue(value) + def set_transformation(self, value: str) -> None: + self._transformation.set_value(value) def variants(self) -> Iterator[str]: - return self._rendererPlugin.strategy.variants() + return self._renderer_plugin.strategy.variants() - def getVariant(self) -> str: - return self._rendererPlugin.strategy.getVariant() + def get_variant(self) -> str: + return self._renderer_plugin.strategy.get_variant() - def setVariant(self, value: str) -> None: - return self._rendererPlugin.strategy.setVariant(value) + def set_variant(self, value: str) -> None: + return self._renderer_plugin.strategy.set_variant(value) - def getMinDisplayValue(self) -> float: - return self._colorAxis.lower.getValue() + def get_min_display_value(self) -> float: + return self._color_axis.lower.get_value() - def setMinDisplayValue(self, value: float) -> None: - self._colorAxis.lower.setValue(value) + def set_min_display_value(self, value: float) -> None: + self._color_axis.lower.set_value(value) - def getMaxDisplayValue(self) -> float: - return self._colorAxis.upper.getValue() + def get_max_display_value(self) -> float: + return self._color_axis.upper.get_value() - def setMaxDisplayValue(self, value: float) -> None: - self._colorAxis.upper.setValue(value) + def set_max_display_value(self, value: float) -> None: + self._color_axis.upper.set_value(value) - def setDisplayValueRange(self, lower: float, upper: float) -> None: - self._colorAxis.setRange(lower, upper) + def set_display_value_range(self, lower: float, upper: float) -> None: + self._color_axis.set_range(lower, upper) def colorize(self, array: NumberArrayType) -> RealArrayType: - return self._rendererPlugin.strategy.colorize(array) + return self._renderer_plugin.strategy.colorize(array) def render( self, array: NumberArrayType, - pixelGeometry: PixelGeometry, + pixel_geometry: PixelGeometry, *, - autoscaleColorAxis: bool, + autoscale_color_axis: bool, ) -> VisualizationProduct: - return self._rendererPlugin.strategy.render( - array, pixelGeometry, autoscaleColorAxis=autoscaleColorAxis + return self._renderer_plugin.strategy.render( + array, pixel_geometry, autoscale_color_axis=autoscale_color_axis ) - def update(self, observable: Observable) -> None: - if observable is self._rendererChooser: - self._rendererPlugin.strategy.removeObserver(self) - self._rendererPlugin = self._rendererChooser.currentPlugin - self._rendererPlugin.strategy.addObserver(self) - self.notifyObservers() - elif observable is self._rendererPlugin.strategy: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._renderer_chooser: + self._renderer_plugin.strategy.remove_observer(self) + self._renderer_plugin = self._renderer_chooser.get_current_plugin() + self._renderer_plugin.strategy.add_observer(self) + self.notify_observers() + elif observable is self._renderer_plugin.strategy: + self.notify_observers() diff --git a/src/ptychodus/model/visualization/renderer.py b/src/ptychodus/model/visualization/renderer.py index 59951009..9befe623 100644 --- a/src/ptychodus/model/visualization/renderer.py +++ b/src/ptychodus/model/visualization/renderer.py @@ -13,25 +13,25 @@ class Renderer(ParameterGroup): def __init__(self, name: str) -> None: super().__init__() - self._name = self.createStringParameter('name', name) + self._name = self.create_string_parameter('name', name) - def getName(self) -> str: - return self._name.getValue() + def get_name(self) -> str: + return self._name.get_value() @abstractmethod def variants(self) -> Iterator[str]: pass @abstractmethod - def getVariant(self) -> str: + def get_variant(self) -> str: pass @abstractmethod - def setVariant(self, variant: str) -> None: + def set_variant(self, variant: str) -> None: pass @abstractmethod - def isCyclic(self) -> bool: + def is_cyclic(self) -> bool: pass @abstractmethod @@ -40,6 +40,6 @@ def colorize(self, array: NumberArrayType) -> RealArrayType: @abstractmethod def render( - self, array: NumberArrayType, pixelGeometry: PixelGeometry, *, autoscaleColorAxis: bool + self, array: NumberArrayType, pixel_geometry: PixelGeometry, *, autoscale_color_axis: bool ) -> VisualizationProduct: pass diff --git a/src/ptychodus/model/visualization/transformation.py b/src/ptychodus/model/visualization/transformation.py index a14e6897..98702e43 100644 --- a/src/ptychodus/model/visualization/transformation.py +++ b/src/ptychodus/model/visualization/transformation.py @@ -6,7 +6,7 @@ from ptychodus.api.observer import Observable, Observer from ptychodus.api.parametric import Parameter from ptychodus.api.plugins import PluginChooser -from ptychodus.api.visualization import RealArrayType +from ptychodus.api.typing import RealArrayType __all__ = [ 'ScalarTransformation', @@ -16,7 +16,7 @@ class ScalarTransformation(ABC): @abstractmethod - def decorateText(self, text: str) -> str: + def decorate_text(self, text: str) -> str: pass @abstractmethod @@ -25,7 +25,7 @@ def __call__(self, array: RealArrayType) -> RealArrayType: class IdentityScalarTransformation(ScalarTransformation): - def decorateText(self, text: str) -> str: + def decorate_text(self, text: str) -> str: return text def __call__(self, array: RealArrayType) -> RealArrayType: @@ -33,7 +33,7 @@ def __call__(self, array: RealArrayType) -> RealArrayType: class SquareRootScalarTransformation(ScalarTransformation): - def decorateText(self, text: str) -> str: + def decorate_text(self, text: str) -> str: return f'$\\sqrt{{\\mathrm{{{text}}}}}$' def __call__(self, array: RealArrayType) -> RealArrayType: @@ -42,7 +42,7 @@ def __call__(self, array: RealArrayType) -> RealArrayType: class Log2ScalarTransformation(ScalarTransformation): - def decorateText(self, text: str) -> str: + def decorate_text(self, text: str) -> str: return f'$\\log_2{{\\left(\\mathrm{{{text}}}\\right)}}$' def __call__(self, array: RealArrayType) -> RealArrayType: @@ -51,7 +51,7 @@ def __call__(self, array: RealArrayType) -> RealArrayType: class LogScalarTransformation(ScalarTransformation): - def decorateText(self, text: str) -> str: + def decorate_text(self, text: str) -> str: return f'$\\ln{{\\left(\\mathrm{{{text}}}\\right)}}$' def __call__(self, array: RealArrayType) -> RealArrayType: @@ -60,7 +60,7 @@ def __call__(self, array: RealArrayType) -> RealArrayType: class Log10ScalarTransformation(ScalarTransformation): - def decorateText(self, text: str) -> str: + def decorate_text(self, text: str) -> str: return f'$\\log_{{10}}{{\\left(\\mathrm{{{text}}}\\right)}}$' def __call__(self, array: RealArrayType) -> RealArrayType: @@ -72,61 +72,61 @@ class ScalarTransformationParameter(Parameter[str], Observer): def __init__(self) -> None: super().__init__() self._chooser = PluginChooser[ScalarTransformation]() - self._chooser.registerPlugin( + self._chooser.register_plugin( IdentityScalarTransformation(), - displayName='Identity', + display_name='Identity', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( SquareRootScalarTransformation(), - simpleName='sqrt', - displayName='Square Root', + simple_name='sqrt', + display_name='Square Root', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( Log2ScalarTransformation(), - simpleName='log2', - displayName='Logarithm (Base 2)', + simple_name='log2', + display_name='Logarithm (Base 2)', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( LogScalarTransformation(), - simpleName='ln', - displayName='Natural Logarithm', + simple_name='ln', + display_name='Natural Logarithm', ) - self._chooser.registerPlugin( + self._chooser.register_plugin( Log10ScalarTransformation(), - simpleName='log10', - displayName='Logarithm (Base 10)', + simple_name='log10', + display_name='Logarithm (Base 10)', ) - self.setValue('Identity') - self._chooser.addObserver(self) + self.set_value('Identity') + self._chooser.add_observer(self) def choices(self) -> Iterator[str]: - for name in self._chooser.getDisplayNameList(): - yield name + for plugin in self._chooser: + yield plugin.display_name - def getValue(self) -> str: - return self._chooser.currentPlugin.displayName + def get_value(self) -> str: + return self._chooser.get_current_plugin().display_name - def setValue(self, value: str, *, notify: bool = True) -> None: - self._chooser.setCurrentPluginByName(value) + def set_value(self, value: str, *, notify: bool = True) -> None: + self._chooser.set_current_plugin(value) - def getValueAsString(self) -> str: - return self.getValue() + def get_value_as_string(self) -> str: + return self.get_value() - def setValueFromString(self, value: str) -> None: - self.setValue(value) + def set_value_from_string(self, value: str) -> None: + self.set_value(value) def copy(self) -> Parameter[str]: parameter = ScalarTransformationParameter() - parameter.setValue(self.getValue()) + parameter.set_value(self.get_value()) return parameter - def decorateText(self, text: str) -> str: - return self._chooser.currentPlugin.strategy.decorateText(text) + def decorate_text(self, text: str) -> str: + return self._chooser.get_current_plugin().strategy.decorate_text(text) def transform(self, values: RealArrayType) -> RealArrayType: - return self._chooser.currentPlugin.strategy(values) + return self._chooser.get_current_plugin().strategy(values) - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._chooser: - self.notifyObservers() + self.notify_observers() diff --git a/src/ptychodus/model/workflow/api.py b/src/ptychodus/model/workflow/api.py index 49f90319..f3a42782 100644 --- a/src/ptychodus/model/workflow/api.py +++ b/src/ptychodus/model/workflow/api.py @@ -20,156 +20,159 @@ class ConcreteWorkflowProductAPI(WorkflowProductAPI): def __init__( self, - productAPI: ProductAPI, - scanAPI: ScanAPI, - probeAPI: ProbeAPI, - objectAPI: ObjectAPI, - reconstructorAPI: ReconstructorAPI, + product_api: ProductAPI, + scan_api: ScanAPI, + probe_api: ProbeAPI, + object_api: ObjectAPI, + reconstructor_api: ReconstructorAPI, executor: WorkflowExecutor, - productIndex: int, + product_index: int, ) -> None: - self._productAPI = productAPI - self._scanAPI = scanAPI - self._probeAPI = probeAPI - self._objectAPI = objectAPI - self._reconstructorAPI = reconstructorAPI + self._product_api = product_api + self._scan_api = scan_api + self._probe_api = probe_api + self._object_api = object_api + self._reconstructor_api = reconstructor_api self._executor = executor - self._productIndex = productIndex + self._product_index = product_index - def openScan(self, filePath: Path, *, fileType: str | None = None) -> None: - self._scanAPI.openScan(self._productIndex, filePath, fileType=fileType) + def open_scan(self, file_path: Path, *, file_type: str | None = None) -> None: + self._scan_api.open_scan(self._product_index, file_path, file_type=file_type) - def buildScan( - self, builderName: str | None = None, builderParameters: Mapping[str, Any] = {} + def build_scan( + self, builder_name: str | None = None, builder_parameters: Mapping[str, Any] = {} ) -> None: - if builderName is None: - self._scanAPI.buildScanFromSettings(self._productIndex) + if builder_name is None: + self._scan_api.build_scan_from_settings(self._product_index) else: - self._scanAPI.buildScan(self._productIndex, builderName, builderParameters) + self._scan_api.build_scan(self._product_index, builder_name, builder_parameters) - def openProbe(self, filePath: Path, *, fileType: str | None = None) -> None: - self._probeAPI.openProbe(self._productIndex, filePath, fileType=fileType) + def open_probe(self, file_path: Path, *, file_type: str | None = None) -> None: + self._probe_api.open_probe(self._product_index, file_path, file_type=file_type) - def buildProbe( - self, builderName: str | None = None, builderParameters: Mapping[str, Any] = {} + def build_probe( + self, builder_name: str | None = None, builder_parameters: Mapping[str, Any] = {} ) -> None: - if builderName is None: - self._probeAPI.buildProbeFromSettings(self._productIndex) + if builder_name is None: + self._probe_api.build_probe_from_settings(self._product_index) else: - self._probeAPI.buildProbe(self._productIndex, builderName, builderParameters) + self._probe_api.build_probe(self._product_index, builder_name, builder_parameters) - def openObject(self, filePath: Path, *, fileType: str | None = None) -> None: - self._objectAPI.openObject(self._productIndex, filePath, fileType=fileType) + def open_object(self, file_path: Path, *, file_type: str | None = None) -> None: + self._object_api.open_object(self._product_index, file_path, file_type=file_type) - def buildObject( - self, builderName: str | None = None, builderParameters: Mapping[str, Any] = {} + def build_object( + self, builder_name: str | None = None, builder_parameters: Mapping[str, Any] = {} ) -> None: - if builderName is None: - self._objectAPI.buildObjectFromSettings(self._productIndex) + if builder_name is None: + self._object_api.build_object_from_settings(self._product_index) else: - self._objectAPI.buildObject(self._productIndex, builderName, builderParameters) + self._object_api.build_object(self._product_index, builder_name, builder_parameters) + + def reconstruct_local(self) -> WorkflowProductAPI: + logger.debug(f'Reconstruct: index={self._product_index}') + output_product_index = self._reconstructor_api.reconstruct(self._product_index) - def reconstructLocal(self, outputProductName: str) -> WorkflowProductAPI: - logger.debug(f'Reconstruct: index={self._productIndex}') - outputProductIndex = self._reconstructorAPI.reconstruct( - self._productIndex, outputProductName - ) return ConcreteWorkflowProductAPI( - self._productAPI, - self._scanAPI, - self._probeAPI, - self._objectAPI, - self._reconstructorAPI, + self._product_api, + self._scan_api, + self._probe_api, + self._object_api, + self._reconstructor_api, self._executor, - outputProductIndex, + output_product_index, ) - def reconstructRemote(self) -> None: - logger.debug(f'Execute Workflow: index={self._productIndex}') - self._executor.runFlow(self._productIndex) + def reconstruct_remote(self) -> None: + logger.debug(f'Execute Workflow: index={self._product_index}') + self._executor.run_flow(self._product_index) - def saveProduct(self, filePath: Path, *, fileType: str | None = None) -> None: - self._productAPI.saveProduct(self._productIndex, filePath, fileType=fileType) + def save_product(self, file_path: Path, *, file_type: str | None = None) -> None: + self._product_api.save_product(self._product_index, file_path, file_type=file_type) class ConcreteWorkflowAPI(WorkflowAPI): def __init__( self, - settingsRegistry: SettingsRegistry, - patternsAPI: PatternsAPI, - productAPI: ProductAPI, - scanAPI: ScanAPI, - probeAPI: ProbeAPI, - objectAPI: ObjectAPI, - reconstructorAPI: ReconstructorAPI, + settings_registry: SettingsRegistry, + patterns_api: PatternsAPI, + product_api: ProductAPI, + scan_api: ScanAPI, + probe_api: ProbeAPI, + object_api: ObjectAPI, + reconstructor_api: ReconstructorAPI, executor: WorkflowExecutor, ) -> None: - self._settingsRegistry = settingsRegistry - self._patternsAPI = patternsAPI - self._productAPI = productAPI - self._scanAPI = scanAPI - self._probeAPI = probeAPI - self._objectAPI = objectAPI - self._reconstructorAPI = reconstructorAPI + self._settings_registry = settings_registry + self._patterns_api = patterns_api + self._product_api = product_api + self._scan_api = scan_api + self._probe_api = probe_api + self._object_api = object_api + self._reconstructor_api = reconstructor_api self._executor = executor - def _createProductAPI(self, productIndex: int) -> WorkflowProductAPI: - if productIndex < 0: - raise ValueError(f'Bad product index ({productIndex=})!') + def _create_product_api(self, product_index: int) -> WorkflowProductAPI: + if product_index < 0: + raise ValueError(f'Bad product index ({product_index=})!') return ConcreteWorkflowProductAPI( - self._productAPI, - self._scanAPI, - self._probeAPI, - self._objectAPI, - self._reconstructorAPI, + self._product_api, + self._scan_api, + self._probe_api, + self._object_api, + self._reconstructor_api, self._executor, - productIndex, + product_index, ) - def openPatterns( + def open_patterns( self, - filePath: Path, + file_path: Path, *, - fileType: str | None = None, - cropCenter: CropCenter | None = None, - cropExtent: ImageExtent | None = None, + file_type: str | None = None, + crop_center: CropCenter | None = None, + crop_extent: ImageExtent | None = None, ) -> None: - self._patternsAPI.openPatterns( - filePath, fileType=fileType, cropCenter=cropCenter, cropExtent=cropExtent + self._patterns_api.open_patterns( + file_path, file_type=file_type, crop_center=crop_center, crop_extent=crop_extent ) - def importProcessedPatterns(self, filePath: Path) -> None: - self._patternsAPI.importProcessedPatterns(filePath) + def import_assembled_patterns(self, file_path: Path) -> None: + self._patterns_api.import_assembled_patterns(file_path) - def exportProcessedPatterns(self, filePath: Path) -> None: - self._patternsAPI.exportProcessedPatterns(filePath) + def export_assembled_patterns(self, file_path: Path) -> None: + self._patterns_api.export_assembled_patterns(file_path) - def openProduct(self, filePath: Path, *, fileType: str | None = None) -> WorkflowProductAPI: - productIndex = self._productAPI.openProduct(filePath, fileType=fileType) - return self._createProductAPI(productIndex) + def open_product(self, file_path: Path, *, file_type: str | None = None) -> WorkflowProductAPI: + product_index = self._product_api.open_product(file_path, file_type=file_type) + return self._create_product_api(product_index) - def createProduct( + def create_product( self, name: str, *, comments: str = '', - detectorDistanceInMeters: float | None = None, - probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, - exposureTimeInSeconds: float | None = None, + detector_distance_m: float | None = None, + probe_energy_eV: float | None = None, # noqa: N803 + probe_photon_count: float | None = None, + exposure_time_s: float | None = None, + mass_attenuation_m2_kg: float | None = None, + tomography_angle_deg: float | None = None, ) -> WorkflowProductAPI: - productIndex = self._productAPI.insertNewProduct( + product_index = self._product_api.insert_new_product( name, comments=comments, - detectorDistanceInMeters=detectorDistanceInMeters, - probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, - exposureTimeInSeconds=exposureTimeInSeconds, + detector_distance_m=detector_distance_m, + probe_energy_eV=probe_energy_eV, + probe_photon_count=probe_photon_count, + exposure_time_s=exposure_time_s, + mass_attenuation_m2_kg=mass_attenuation_m2_kg, + tomography_angle_deg=tomography_angle_deg, ) - return self._createProductAPI(productIndex) + return self._create_product_api(product_index) - def saveSettings( - self, filePath: Path, changePathPrefix: PathPrefixChange | None = None + def save_settings( + self, file_path: Path, change_path_prefix: PathPrefixChange | None = None ) -> None: - self._settingsRegistry.saveSettings(filePath, changePathPrefix) + self._settings_registry.save_settings(file_path, change_path_prefix) diff --git a/src/ptychodus/model/workflow/authorizer.py b/src/ptychodus/model/workflow/authorizer.py index 624e8ca7..604335fb 100644 --- a/src/ptychodus/model/workflow/authorizer.py +++ b/src/ptychodus/model/workflow/authorizer.py @@ -7,37 +7,37 @@ class WorkflowAuthorizer: def __init__(self) -> None: super().__init__() - self._authorizeLock = threading.Lock() - self._authorizeCode = str() - self._authorizeURL = 'https://aps.anl.gov' - self.isAuthorizedEvent = threading.Event() - self.isAuthorizedEvent.set() - self.shutdownEvent = threading.Event() + self._authorize_lock = threading.Lock() + self._authorize_code = str() + self._authorize_url = 'https://aps.anl.gov' + self.is_authorized_event = threading.Event() + self.is_authorized_event.set() + self.shutdown_event = threading.Event() @property - def isAuthorized(self) -> bool: - return self.isAuthorizedEvent.is_set() + def is_authorized(self) -> bool: + return self.is_authorized_event.is_set() - def getAuthorizeURL(self) -> str: - with self._authorizeLock: - return self._authorizeURL + def get_authorize_url(self) -> str: + with self._authorize_lock: + return self._authorize_url - def setCodeFromAuthorizeURL(self, code: str) -> None: - with self._authorizeLock: - self._authorizeCode = code - self.isAuthorizedEvent.set() + def set_code_from_authorize_url(self, code: str) -> None: + with self._authorize_lock: + self._authorize_code = code + self.is_authorized_event.set() - def getCodeFromAuthorizeURL(self) -> str: - with self._authorizeLock: - return self._authorizeCode + def get_code_from_authorize_url(self) -> str: + with self._authorize_lock: + return self._authorize_code - def authenticate(self, authorizeURL: str) -> None: - logger.info(f'Authenticate at {authorizeURL}') + def authenticate(self, authorize_url: str) -> None: + logger.info(f'Authenticate at {authorize_url}') - with self._authorizeLock: - self._authorizeURL = authorizeURL - self.isAuthorizedEvent.clear() + with self._authorize_lock: + self._authorize_url = authorize_url + self.is_authorized_event.clear() - while not self.shutdownEvent.is_set(): - if self.isAuthorizedEvent.wait(timeout=1.0): + while not self.shutdown_event.is_set(): + if self.is_authorized_event.wait(timeout=1.0): break diff --git a/src/ptychodus/model/workflow/core.py b/src/ptychodus/model/workflow/core.py index e08eaa82..dd9c5180 100644 --- a/src/ptychodus/model/workflow/core.py +++ b/src/ptychodus/model/workflow/core.py @@ -1,4 +1,3 @@ -from __future__ import annotations from collections.abc import Sequence from datetime import datetime from pathlib import Path @@ -28,106 +27,96 @@ class WorkflowParametersPresenter(Observable, Observer): def __init__( self, settings: WorkflowSettings, - inputDataLocator: DataLocator, - computeDataLocator: DataLocator, - outputDataLocator: OutputDataLocator, + input_data_locator: DataLocator, + compute_data_locator: DataLocator, + output_data_locator: OutputDataLocator, ) -> None: super().__init__() self._settings = settings - self._inputDataLocator = inputDataLocator - self._computeDataLocator = computeDataLocator - self._outputDataLocator = outputDataLocator + self._input_data_locator = input_data_locator + self._compute_data_locator = compute_data_locator + self._output_data_locator = output_data_locator - @classmethod - def createInstance( - cls, - settings: WorkflowSettings, - inputDataLocator: DataLocator, - computeDataLocator: DataLocator, - outputDataLocator: OutputDataLocator, - ) -> WorkflowParametersPresenter: - presenter = cls(settings, inputDataLocator, computeDataLocator, outputDataLocator) - settings.addObserver(presenter) - inputDataLocator.addObserver(presenter) - computeDataLocator.addObserver(presenter) - outputDataLocator.addObserver(presenter) - return presenter + settings.add_observer(self) + input_data_locator.add_observer(self) + compute_data_locator.add_observer(self) + output_data_locator.add_observer(self) - def setInputDataEndpointID(self, endpointID: UUID) -> None: - self._inputDataLocator.setEndpointID(endpointID) + def set_input_data_endpoint_id(self, endpoint_id: UUID) -> None: + self._input_data_locator.set_endpoint_id(endpoint_id) - def getInputDataEndpointID(self) -> UUID: - return self._inputDataLocator.getEndpointID() + def get_input_data_endpoint_id(self) -> UUID: + return self._input_data_locator.get_endpoint_id() - def setInputDataGlobusPath(self, globusPath: str) -> None: - self._inputDataLocator.setGlobusPath(globusPath) + def set_input_data_globus_path(self, globus_path: str) -> None: + self._input_data_locator.set_globus_path(globus_path) - def getInputDataGlobusPath(self) -> str: - return self._inputDataLocator.getGlobusPath() + def get_input_data_globus_path(self) -> str: + return self._input_data_locator.get_globus_path() - def setInputDataPosixPath(self, posixPath: Path) -> None: - self._inputDataLocator.setPosixPath(posixPath) + def set_input_data_posix_path(self, posix_path: Path) -> None: + self._input_data_locator.set_posix_path(posix_path) - def getInputDataPosixPath(self) -> Path: - return self._inputDataLocator.getPosixPath() + def get_input_data_posix_path(self) -> Path: + return self._input_data_locator.get_posix_path() - def setComputeEndpointID(self, endpointID: UUID) -> None: - self._settings.computeEndpointID.setValue(endpointID) + def set_compute_endpoint_id(self, endpoint_id: UUID) -> None: + self._settings.compute_endpoint_id.set_value(endpoint_id) - def getComputeEndpointID(self) -> UUID: - return self._settings.computeEndpointID.getValue() + def get_compute_endpoint_id(self) -> UUID: + return self._settings.compute_endpoint_id.get_value() - def setComputeDataEndpointID(self, endpointID: UUID) -> None: - self._computeDataLocator.setEndpointID(endpointID) + def set_compute_data_endpoint_id(self, endpoint_id: UUID) -> None: + self._compute_data_locator.set_endpoint_id(endpoint_id) - def getComputeDataEndpointID(self) -> UUID: - return self._computeDataLocator.getEndpointID() + def get_compute_data_endpoint_id(self) -> UUID: + return self._compute_data_locator.get_endpoint_id() - def setComputeDataGlobusPath(self, globusPath: str) -> None: - self._computeDataLocator.setGlobusPath(globusPath) + def set_compute_data_globus_path(self, globus_path: str) -> None: + self._compute_data_locator.set_globus_path(globus_path) - def getComputeDataGlobusPath(self) -> str: - return self._computeDataLocator.getGlobusPath() + def get_compute_data_globus_path(self) -> str: + return self._compute_data_locator.get_globus_path() - def setComputeDataPosixPath(self, posixPath: Path) -> None: - self._computeDataLocator.setPosixPath(posixPath) + def set_compute_data_posix_path(self, posix_path: Path) -> None: + self._compute_data_locator.set_posix_path(posix_path) - def getComputeDataPosixPath(self) -> Path: - return self._computeDataLocator.getPosixPath() + def get_compute_data_posix_path(self) -> Path: + return self._compute_data_locator.get_posix_path() - def setRoundTripEnabled(self, enable: bool) -> None: - self._outputDataLocator.setRoundTripEnabled(enable) + def set_round_trip_enabled(self, enable: bool) -> None: + self._output_data_locator.set_round_trip_enabled(enable) - def isRoundTripEnabled(self) -> bool: - return self._outputDataLocator.isRoundTripEnabled() + def is_round_trip_enabled(self) -> bool: + return self._output_data_locator.is_round_trip_enabled() - def setOutputDataEndpointID(self, endpointID: UUID) -> None: - self._outputDataLocator.setEndpointID(endpointID) + def set_output_data_endpoint_id(self, endpoint_id: UUID) -> None: + self._output_data_locator.set_endpoint_id(endpoint_id) - def getOutputDataEndpointID(self) -> UUID: - return self._outputDataLocator.getEndpointID() + def get_output_data_endpoint_id(self) -> UUID: + return self._output_data_locator.get_endpoint_id() - def setOutputDataGlobusPath(self, globusPath: str) -> None: - self._outputDataLocator.setGlobusPath(globusPath) + def set_output_data_globus_path(self, globus_path: str) -> None: + self._output_data_locator.set_globus_path(globus_path) - def getOutputDataGlobusPath(self) -> str: - return self._outputDataLocator.getGlobusPath() + def get_output_data_globus_path(self) -> str: + return self._output_data_locator.get_globus_path() - def setOutputDataPosixPath(self, posixPath: Path) -> None: - self._outputDataLocator.setPosixPath(posixPath) + def set_output_data_posix_path(self, posix_path: Path) -> None: + self._output_data_locator.set_posix_path(posix_path) - def getOutputDataPosixPath(self) -> Path: - return self._outputDataLocator.getPosixPath() + def get_output_data_posix_path(self) -> Path: + return self._output_data_locator.get_posix_path() - def update(self, observable: Observable) -> None: + def _update(self, observable: Observable) -> None: if observable is self._settings: - self.notifyObservers() - elif observable is self._inputDataLocator: - self.notifyObservers() - elif observable is self._computeDataLocator: - self.notifyObservers() - elif observable is self._outputDataLocator: - self.notifyObservers() + self.notify_observers() + elif observable in ( + self._input_data_locator, + self._compute_data_locator, + self._output_data_locator, + ): + self.notify_observers() class WorkflowAuthorizationPresenter: @@ -135,32 +124,35 @@ def __init__(self, authorizer: WorkflowAuthorizer) -> None: self._authorizer = authorizer @property - def isAuthorized(self) -> bool: - return self._authorizer.isAuthorized + def is_authorized(self) -> bool: + return self._authorizer.is_authorized - def getAuthorizeURL(self) -> str: - return self._authorizer.getAuthorizeURL() + def get_authorize_url(self) -> str: + return self._authorizer.get_authorize_url() - def setCodeFromAuthorizeURL(self, code: str) -> None: - self._authorizer.setCodeFromAuthorizeURL(code) + def set_code_from_authorize_url(self, code: str) -> None: + self._authorizer.set_code_from_authorize_url(code) -class WorkflowStatusPresenter: +class WorkflowStatusPresenter(Observable, Observer): def __init__( - self, settings: WorkflowSettings, statusRepository: WorkflowStatusRepository + self, settings: WorkflowSettings, status_repository: WorkflowStatusRepository ) -> None: + super().__init__() self._settings = settings - self._statusRepository = statusRepository + self._status_repository = status_repository + + settings.add_observer(self) - def getRefreshIntervalLimitsInSeconds(self) -> Interval[int]: + def get_refresh_interval_limits_s(self) -> Interval[int]: return Interval[int](10, 86400) - def getRefreshIntervalInSeconds(self) -> int: - limits = self.getRefreshIntervalLimitsInSeconds() - return limits.clamp(self._settings.statusRefreshIntervalInSeconds.getValue()) + def get_refresh_interval_s(self) -> int: + limits = self.get_refresh_interval_limits_s() + return limits.clamp(self._settings.status_refresh_interval_s.get_value()) - def setRefreshIntervalInSeconds(self, seconds: int) -> None: - self._settings.statusRefreshIntervalInSeconds.setValue(seconds) + def set_refresh_interval_s(self, seconds: int) -> None: + self._settings.status_refresh_interval_s.set_value(seconds) @overload def __getitem__(self, index: int) -> WorkflowStatus: ... @@ -169,62 +161,66 @@ def __getitem__(self, index: int) -> WorkflowStatus: ... def __getitem__(self, index: slice) -> Sequence[WorkflowStatus]: ... def __getitem__(self, index: int | slice) -> WorkflowStatus | Sequence[WorkflowStatus]: - return self._statusRepository[index] + return self._status_repository[index] def __len__(self) -> int: - return len(self._statusRepository) + return len(self._status_repository) - def getStatusDateTime(self) -> datetime: - return self._statusRepository.getStatusDateTime() + def get_status_date_time(self) -> datetime: + return self._status_repository.get_status_date_time() - def refreshStatus(self) -> None: - self._statusRepository.refreshStatus() + def refresh_status(self) -> None: + self._status_repository.refresh_status() + + def _update(self, observable: Observable) -> None: + if observable is self._settings: + self.notify_observers() class WorkflowExecutionPresenter: def __init__(self, executor: WorkflowExecutor) -> None: self._executor = executor - def runFlow(self, inputProductIndex: int) -> None: - self._executor.runFlow(inputProductIndex) + def run_flow(self, input_product_index: int) -> None: + self._executor.run_flow(input_product_index) class WorkflowCore: def __init__( self, - settingsRegistry: SettingsRegistry, - patternsAPI: PatternsAPI, - productAPI: ProductAPI, - scanAPI: ScanAPI, - probeAPI: ProbeAPI, - objectAPI: ObjectAPI, - reconstructorAPI: ReconstructorAPI, + settings_registry: SettingsRegistry, + patterns_api: PatternsAPI, + product_api: ProductAPI, + scan_api: ScanAPI, + probe_api: ProbeAPI, + object_api: ObjectAPI, + reconstructor_api: ReconstructorAPI, ) -> None: - self._settings = WorkflowSettings(settingsRegistry) - self._inputDataLocator = SimpleDataLocator(self._settings._settingsGroup, 'Input') - self._computeDataLocator = SimpleDataLocator(self._settings._settingsGroup, 'Compute') - self._outputDataLocator = OutputDataLocator( - self._settings._settingsGroup, 'Output', self._inputDataLocator + self._settings = WorkflowSettings(settings_registry) + self._input_data_locator = SimpleDataLocator(self._settings._group, 'Input') + self._compute_data_locator = SimpleDataLocator(self._settings._group, 'Compute') + self._output_data_locator = OutputDataLocator( + self._settings._group, 'Output', self._input_data_locator ) self._authorizer = WorkflowAuthorizer() - self._statusRepository = WorkflowStatusRepository() + self._status_repository = WorkflowStatusRepository() self._executor = WorkflowExecutor( self._settings, - self._inputDataLocator, - self._computeDataLocator, - self._outputDataLocator, - settingsRegistry, - patternsAPI, - productAPI, + self._input_data_locator, + self._compute_data_locator, + self._output_data_locator, + settings_registry, + patterns_api, + product_api, ) - self.workflowAPI = ConcreteWorkflowAPI( - settingsRegistry, - patternsAPI, - productAPI, - scanAPI, - probeAPI, - objectAPI, - reconstructorAPI, + self.workflow_api = ConcreteWorkflowAPI( + settings_registry, + patterns_api, + product_api, + scan_api, + probe_api, + object_api, + reconstructor_api, self._executor, ) self._thread: threading.Thread | None = None @@ -234,22 +230,22 @@ def __init__( except ModuleNotFoundError: logger.info('Globus not found.') else: - self._thread = GlobusWorkflowThread.createInstance( - self._authorizer, self._statusRepository, self._executor + self._thread = GlobusWorkflowThread.create_instance( + self._authorizer, self._status_repository, self._executor ) - self.parametersPresenter = WorkflowParametersPresenter.createInstance( + self.parameters_presenter = WorkflowParametersPresenter( self._settings, - self._inputDataLocator, - self._computeDataLocator, - self._outputDataLocator, + self._input_data_locator, + self._compute_data_locator, + self._output_data_locator, ) - self.authorizationPresenter = WorkflowAuthorizationPresenter(self._authorizer) - self.statusPresenter = WorkflowStatusPresenter(self._settings, self._statusRepository) - self.executionPresenter = WorkflowExecutionPresenter(self._executor) + self.authorization_presenter = WorkflowAuthorizationPresenter(self._authorizer) + self.status_presenter = WorkflowStatusPresenter(self._settings, self._status_repository) + self.execution_presenter = WorkflowExecutionPresenter(self._executor) @property - def areWorkflowsSupported(self) -> bool: + def is_supported(self) -> bool: return self._thread is not None def start(self) -> None: @@ -262,8 +258,8 @@ def start(self) -> None: def stop(self) -> None: logger.info('Stopping workflow thread...') - self._executor.jobQueue.join() - self._authorizer.shutdownEvent.set() + self._executor.job_queue.join() + self._authorizer.shutdown_event.set() if self._thread: self._thread.join() diff --git a/src/ptychodus/model/workflow/executor.py b/src/ptychodus/model/workflow/executor.py index 594a59fb..dcbfd398 100644 --- a/src/ptychodus/model/workflow/executor.py +++ b/src/ptychodus/model/workflow/executor.py @@ -16,91 +16,91 @@ @dataclass(frozen=True) class WorkflowJob: - flowLabel: str - flowInput: Mapping[str, Any] + flow_label: str + flow_input: Mapping[str, Any] class WorkflowExecutor: def __init__( self, settings: WorkflowSettings, - inputDataLocator: DataLocator, - computeDataLocator: DataLocator, - outputDataLocator: DataLocator, - settingsRegistry: SettingsRegistry, - patternsAPI: PatternsAPI, - productAPI: ProductAPI, + input_data_locator: DataLocator, + compute_data_locator: DataLocator, + output_data_locator: DataLocator, + settings_registry: SettingsRegistry, + patterns_api: PatternsAPI, + product_api: ProductAPI, ) -> None: super().__init__() self._settings = settings - self._inputDataLocator = inputDataLocator - self._computeDataLocator = computeDataLocator - self._outputDataLocator = outputDataLocator - self._productAPI = productAPI - self._settingsRegistry = settingsRegistry - self._patternsAPI = patternsAPI - self.jobQueue: queue.Queue[WorkflowJob] = queue.Queue() - - def runFlow(self, inputProductIndex: int) -> None: - transferSyncLevel = 3 # Copy files if checksums of the source and destination mismatch - ptychodusAction = 'reconstruct' # TODO or 'train' + self._input_data_locator = input_data_locator + self._compute_data_locator = compute_data_locator + self._output_data_locator = output_data_locator + self._product_api = product_api + self._settings_registry = settings_registry + self._patterns_api = patterns_api + self.job_queue: queue.Queue[WorkflowJob] = queue.Queue() + + def run_flow(self, input_product_index: int) -> None: + transfer_sync_level = 3 # Copy files if checksums of the source and destination mismatch + ptychodus_action = 'reconstruct' # TODO or 'train' try: - flowLabel = self._productAPI.getItemName(inputProductIndex) + flow_label = self._product_api.get_item(input_product_index).get_name() except IndexError: - logger.warning(f'Failed access product for flow ({inputProductIndex=})!') + logger.warning(f'Failed access product for flow ({input_product_index=})!') return - inputDataPosixPath = self._inputDataLocator.getPosixPath() / flowLabel - computeDataPosixPath = self._computeDataLocator.getPosixPath() / flowLabel + input_data_posix_path = self._input_data_locator.get_posix_path() / flow_label + compute_data_posix_path = self._compute_data_locator.get_posix_path() / flow_label - inputDataGlobusPath = f'{self._inputDataLocator.getGlobusPath()}/{flowLabel}' - computeDataGlobusPath = f'{self._computeDataLocator.getGlobusPath()}/{flowLabel}' - outputDataGlobusPath = f'{self._outputDataLocator.getGlobusPath()}/{flowLabel}' + input_data_globus_path = f'{self._input_data_locator.get_globus_path()}/{flow_label}' + compute_data_globus_path = f'{self._compute_data_locator.get_globus_path()}/{flow_label}' + output_data_globus_path = f'{self._output_data_locator.get_globus_path()}/{flow_label}' - settingsFile = 'settings.ini' - patternsFile = 'patterns.npz' - inputFile = 'product-in.npz' - outputFile = 'product-out.npz' + settings_file = 'settings.ini' + patterns_file = 'patterns.npz' + input_file = 'product-in.npz' + output_file = 'product-out.npz' try: - inputDataPosixPath.mkdir(mode=0o755, parents=True, exist_ok=True) + input_data_posix_path.mkdir(mode=0o755, parents=True, exist_ok=True) except FileExistsError: logger.warning('Input data POSIX path must be a directory!') return # TODO use workflow API - self._settingsRegistry.saveSettings(inputDataPosixPath / settingsFile) - self._patternsAPI.exportProcessedPatterns(inputDataPosixPath / patternsFile) - self._productAPI.saveProduct( - inputProductIndex, inputDataPosixPath / inputFile, fileType='NPZ' + self._settings_registry.save_settings(input_data_posix_path / settings_file) + self._patterns_api.export_assembled_patterns(input_data_posix_path / patterns_file) + self._product_api.save_product( + input_product_index, input_data_posix_path / input_file, file_type='NPZ' ) - flowInput = { - 'input_data_transfer_source_endpoint_id': str(self._inputDataLocator.getEndpointID()), - 'input_data_transfer_source_path': inputDataGlobusPath, - 'input_data_transfer_destination_endpoint_id': str( - self._computeDataLocator.getEndpointID() + flow_input = { + 'input_data_transfer_source_endpoint': str(self._input_data_locator.get_endpoint_id()), + 'input_data_transfer_source_path': input_data_globus_path, + 'input_data_transfer_destination_endpoint': str( + self._compute_data_locator.get_endpoint_id() ), - 'input_data_transfer_destination_path': computeDataGlobusPath, + 'input_data_transfer_destination_path': compute_data_globus_path, 'input_data_transfer_recursive': True, - 'input_data_transfer_sync_level': transferSyncLevel, - 'compute_endpoint': str(self._settings.computeEndpointID.getValue()), - 'ptychodus_action': ptychodusAction, - 'ptychodus_settings_file': str(computeDataPosixPath / settingsFile), - 'ptychodus_patterns_file': str(computeDataPosixPath / patternsFile), - 'ptychodus_input_file': str(computeDataPosixPath / inputFile), - 'ptychodus_output_file': str(computeDataPosixPath / outputFile), - 'output_data_transfer_source_endpoint_id': str( - self._computeDataLocator.getEndpointID() + 'input_data_transfer_sync_level': transfer_sync_level, + 'compute_endpoint': str(self._settings.compute_endpoint_id.get_value()), + 'ptychodus_action': ptychodus_action, + 'ptychodus_settings_file': str(compute_data_posix_path / settings_file), + 'ptychodus_patterns_file': str(compute_data_posix_path / patterns_file), + 'ptychodus_input_file': str(compute_data_posix_path / input_file), + 'ptychodus_output_file': str(compute_data_posix_path / output_file), + 'output_data_transfer_source_endpoint': str( + self._compute_data_locator.get_endpoint_id() ), - 'output_data_transfer_source_path': f'{computeDataGlobusPath}/{outputFile}', - 'output_data_transfer_destination_endpoint_id': str( - self._outputDataLocator.getEndpointID() + 'output_data_transfer_source_path': f'{compute_data_globus_path}/{output_file}', + 'output_data_transfer_destination_endpoint': str( + self._output_data_locator.get_endpoint_id() ), - 'output_data_transfer_destination_path': f'{outputDataGlobusPath}/{outputFile}', + 'output_data_transfer_destination_path': f'{output_data_globus_path}/{output_file}', 'output_data_transfer_recursive': False, } - input_ = WorkflowJob(flowLabel, flowInput) - self.jobQueue.put(input_) + input_ = WorkflowJob(flow_label, flow_input) + self.job_queue.put(input_) diff --git a/src/ptychodus/model/workflow/globus.py b/src/ptychodus/model/workflow/globus.py index 7df6975e..865f0c19 100644 --- a/src/ptychodus/model/workflow/globus.py +++ b/src/ptychodus/model/workflow/globus.py @@ -35,14 +35,14 @@ def ptychodus_reconstruct(**data: str) -> None: from ptychodus.model import ModelCore action = data['ptychodus_action'] - inputFile = Path(data['ptychodus_input_file']) - outputFile = Path(data['ptychodus_output_file']) - settingsFile = Path(data['ptychodus_settings_file']) - patternsFile = Path(data['ptychodus_patterns_file']) + input_file = Path(data['ptychodus_input_file']) + output_file = Path(data['ptychodus_output_file']) + settings_file = Path(data['ptychodus_settings_file']) + patterns_file = Path(data['ptychodus_patterns_file']) - with ModelCore(settingsFile) as model: - model.workflowAPI.importProcessedPatterns(patternsFile) - model.batchModeExecute(action, inputFile, outputFile) + with ModelCore(settings_file) as model: + model.workflow_api.import_assembled_patterns(patterns_file) + model.batch_mode_execute(action, input_file, output_file) @gladier.generate_flow_definition @@ -60,8 +60,6 @@ class PtychodusReconstruct(gladier.GladierBaseTool): @gladier.generate_flow_definition class PtychodusClient(gladier.GladierBaseClient): client_id = PTYCHODUS_CLIENT_ID - globus_group = '13e5512f-e761-11ec-8a9e-ff9dc0f99d56' - gladier_tools = [ 'gladier_tools.globus.transfer.Transfer:InputData', PtychodusReconstruct, @@ -81,7 +79,7 @@ def authenticate(self, url: str) -> str: return self.get_code() def get_code(self) -> str: - return self._authorizer.getCodeFromAuthorizeURL() + return self._authorizer.get_code_from_authorize_url() class PtychodusClientBuilder(ABC): @@ -93,125 +91,125 @@ def build(self) -> gladier.GladierBaseClient: class NativePtychodusClientBuilder(PtychodusClientBuilder): def __init__(self, authorizer: WorkflowAuthorizer) -> None: super().__init__() - self._authClient = fair_research_login.NativeClient( + self._auth_client = fair_research_login.NativeClient( client_id=PTYCHODUS_CLIENT_ID, app_name='Ptychodus', code_handlers=[CustomCodeHandler(authorizer)], ) - def _requestAuthorization(self, scopes: list[str]) -> ScopeAuthorizerMapping: + def _request_authorization(self, scopes: list[str]) -> ScopeAuthorizerMapping: logger.debug(f'Requested authorization scopes: {pformat(scopes)}') # 'force' is used for any underlying scope changes. For example, if a flow adds transfer # functionality since it was last run, running it again would require a re-login. - self._authClient.login(requested_scopes=scopes, force=True, refresh_tokens=True) - return self._authClient.get_authorizers_by_scope() + self._auth_client.login(requested_scopes=scopes, force=True, refresh_tokens=True) + return self._auth_client.get_authorizers_by_scope() def build(self) -> gladier.GladierBaseClient: - initialAuthorizers: dict[str, AuthorizerTypes] = dict() + initial_authorizers: dict[str, AuthorizerTypes] = dict() try: # Try to use a previous login to avoid a new login flow - initialAuthorizers = self._authClient.get_authorizers_by_scope() + initial_authorizers = self._auth_client.get_authorizers_by_scope() except fair_research_login.LoadError: pass - loginManager = gladier.managers.CallbackLoginManager( - authorizers=initialAuthorizers, - callback=self._requestAuthorization, + login_manager = gladier.managers.CallbackLoginManager( + authorizers=initial_authorizers, + callback=self._request_authorization, ) - return PtychodusClient(login_manager=loginManager) + return PtychodusClient(login_manager=login_manager) class ConfidentialPtychodusClientBuilder(PtychodusClientBuilder): - def __init__(self, clientID: str, clientSecret: str, flowID: str | None) -> None: + def __init__(self, client_id: str, client_secret: str, flow_id: str | None) -> None: super().__init__() - self._authClient = globus_sdk.ConfidentialAppAuthClient( - client_id=clientID, - client_secret=clientSecret, + self._auth_client = globus_sdk.ConfidentialAppAuthClient( + client_id=client_id, + client_secret=client_secret, app_name='Ptychodus', ) - self._flowID = flowID + self._flow_id = flow_id - def _requestAuthorization(self, scopes: list[str]) -> ScopeAuthorizerMapping: + def _request_authorization(self, scopes: list[str]) -> ScopeAuthorizerMapping: logger.debug(f'Requested authorization scopes: {pformat(scopes)}') - response = self._authClient.oauth2_client_credentials_tokens(requested_scopes=scopes) + response = self._auth_client.oauth2_client_credentials_tokens(requested_scopes=scopes) return { scope: globus_sdk.AccessTokenAuthorizer(access_token=tokens['access_token']) for scope, tokens in response.by_scopes.scope_map.items() } def build(self) -> gladier.GladierBaseClient: - initialAuthorizers: dict[str, AuthorizerTypes] = dict() - loginManager = gladier.managers.CallbackLoginManager( - authorizers=initialAuthorizers, - callback=self._requestAuthorization, + initial_authorizers: dict[str, AuthorizerTypes] = dict() + login_manager = gladier.managers.CallbackLoginManager( + authorizers=initial_authorizers, + callback=self._request_authorization, ) - flowsManager = gladier.managers.FlowsManager(flow_id=self._flowID) - return PtychodusClient(login_manager=loginManager, flows_manager=flowsManager) + flows_manager = gladier.managers.FlowsManager(flow_id=self._flow_id) + return PtychodusClient(login_manager=login_manager, flows_manager=flows_manager) class GlobusWorkflowThread(threading.Thread): def __init__( self, authorizer: WorkflowAuthorizer, - statusRepository: WorkflowStatusRepository, + status_repository: WorkflowStatusRepository, executor: WorkflowExecutor, - clientBuilder: PtychodusClientBuilder, + client_builder: PtychodusClientBuilder, ) -> None: super().__init__() self._authorizer = authorizer - self._statusRepository = statusRepository + self._status_repository = status_repository self._executor = executor - self._clientBuilder = clientBuilder + self._client_builder = client_builder logger.info('\tGlobus SDK ' + version('globus-sdk')) logger.info('\tFair Research Login ' + version('fair-research-login')) logger.info('\tGladier ' + version('gladier')) - self.__gladierClient: gladier.GladierBaseClient | None = None + self.__gladier_client: gladier.GladierBaseClient | None = None @classmethod - def createInstance( + def create_instance( cls, authorizer: WorkflowAuthorizer, - statusRepository: WorkflowStatusRepository, + status_repository: WorkflowStatusRepository, executor: WorkflowExecutor, ) -> GlobusWorkflowThread: try: - clientID = os.environ['CLIENT_ID'] + client_id = os.environ['CLIENT_ID'] except KeyError: - clientBuilder: PtychodusClientBuilder = NativePtychodusClientBuilder(authorizer) - return cls(authorizer, statusRepository, executor, clientBuilder) + client_builder: PtychodusClientBuilder = NativePtychodusClientBuilder(authorizer) + return cls(authorizer, status_repository, executor, client_builder) try: - clientSecret = os.environ['CLIENT_SECRET'] + client_secret = os.environ['CLIENT_SECRET'] except KeyError as ex: raise KeyError('CLIENT_ID requires a CLIENT_SECRET environment variable.') from ex try: - flowID = os.environ['FLOW_ID'] + flow_id = os.environ['FLOW_ID'] except KeyError: # This isn't necessarily bad, but CCs like regular users only get one flow # to play with. They probably don't need more than one, but this will ensure # there aren't errors due to tracking mismatch in the Glaider config - flowID = '' + flow_id = '' logger.warning('No flow ID enforced. Recommend setting FLOW_ID environment variable.') - clientBuilder = ConfidentialPtychodusClientBuilder(clientID, clientSecret, flowID) - return cls(authorizer, statusRepository, executor, clientBuilder) + client_builder = ConfidentialPtychodusClientBuilder(client_id, client_secret, flow_id) + return cls(authorizer, status_repository, executor, client_builder) @property - def _gladierClient(self) -> gladier.GladierBaseClient: - if self.__gladierClient is None: - self.__gladierClient = self._clientBuilder.build() + def _gladier_client(self) -> gladier.GladierBaseClient: + if self.__gladier_client is None: + self.__gladier_client = self._client_builder.build() - return self.__gladierClient + return self.__gladier_client - def _getCurrentAction(self, runID: str) -> str: - status = self._gladierClient.get_status(runID) + def _get_current_action(self, run_id: str) -> str: + status = self._gladier_client.get_status(run_id) action = status.get('state_name') if not action: @@ -232,64 +230,64 @@ def _getCurrentAction(self, runID: str) -> str: return action - def _refreshStatus(self) -> None: - statusList: list[WorkflowStatus] = list() - flowsManager = self._gladierClient.flows_manager - flowID = flowsManager.get_flow_id() - flowsClient = flowsManager.flows_client - response = flowsClient.list_runs(filter_flow_id=flowID) - runDictList = response['runs'] + def _refresh_status(self) -> None: + status_list: list[WorkflowStatus] = list() + flows_manager = self._gladier_client.flows_manager + flow_id = flows_manager.get_flow_id() + flows_client = flows_manager.flows_client + response = flows_client.list_runs(filter_flow_id=flow_id) + run_dict_list = response['runs'] while response['has_next_page']: - response = flowsClient.list_runs(filter_flow_id=flowID, marker=response['marker']) - runDictList.extend(response['runs']) + response = flows_client.list_runs(filter_flow_id=flow_id, marker=response['marker']) + run_dict_list.extend(response['runs']) - for runDict in runDictList: - runID = runDict.get('run_id', '') - action = self._getCurrentAction(runID) - startTimeStr = runDict.get('start_time', '') - completionTimeStr = runDict.get('completion_time', '') + for run_dict in run_dict_list: + run_id = run_dict.get('run_id', '') + action = self._get_current_action(run_id) + start_time_str = run_dict.get('start_time', '') + completion_time_str = run_dict.get('completion_time', '') try: - startTime = datetime.fromisoformat(startTimeStr) + start_time = datetime.fromisoformat(start_time_str) except ValueError: - logger.warning(f'Failed to parse startTime "{startTimeStr}"!') - startTime = datetime(1, 1, 1) + logger.warning(f'Failed to parse startTime "{start_time_str}"!') + start_time = datetime(1, 1, 1) try: - completionTime = datetime.fromisoformat(completionTimeStr) + completion_time = datetime.fromisoformat(completion_time_str) except ValueError: - completionTime = None + completion_time = None run = WorkflowStatus( - label=runDict.get('label', ''), - startTime=startTime, - completionTime=completionTime, - status=runDict.get('status', ''), + label=run_dict.get('label', ''), + start_time=start_time, + completion_time=completion_time, + status=run_dict.get('status', ''), action=action, - runID=runID, - runURL=f'https://app.globus.org/runs/{runID}/logs', + run_id=run_id, + run_url=f'https://app.globus.org/runs/{run_id}/logs', ) - statusList.append(run) + status_list.append(run) - self._statusRepository.update(statusList) + self._status_repository.update(status_list) def run(self) -> None: - while not self._authorizer.shutdownEvent.is_set(): - if self._statusRepository.refreshStatusEvent.is_set(): - self._refreshStatus() - self._statusRepository.refreshStatusEvent.clear() + while not self._authorizer.shutdown_event.is_set(): + if self._status_repository.refresh_status_event.is_set(): + self._refresh_status() + self._status_repository.refresh_status_event.clear() try: - input_ = self._executor.jobQueue.get(block=True, timeout=1) + input_ = self._executor.job_queue.get(block=True, timeout=1) except queue.Empty: continue try: - response = self._gladierClient.run_flow( - flow_input={'input': input_.flowInput}, - label=input_.flowLabel, + response = self._gladier_client.run_flow( + flow_input={'input': input_.flow_input}, + label=input_.flow_label, tags=['aps', 'ptychography'], ) except Exception: @@ -297,4 +295,4 @@ def run(self) -> None: else: logger.info(f'Run Flow Response: {json.dumps(response, indent=4)}') finally: - self._executor.jobQueue.task_done() + self._executor.job_queue.task_done() diff --git a/src/ptychodus/model/workflow/locator.py b/src/ptychodus/model/workflow/locator.py index 82ff810a..6f810666 100644 --- a/src/ptychodus/model/workflow/locator.py +++ b/src/ptychodus/model/workflow/locator.py @@ -4,134 +4,134 @@ from uuid import UUID from ptychodus.api.observer import Observable, Observer -from ptychodus.api.parametric import ( - ParameterGroup, -) +from ptychodus.api.parametric import ParameterGroup class DataLocator(ABC, Observable): @abstractmethod - def setEndpointID(self, endpointID: UUID) -> None: + def set_endpoint_id(self, endpoint_id: UUID) -> None: pass @abstractmethod - def getEndpointID(self) -> UUID: + def get_endpoint_id(self) -> UUID: pass @abstractmethod - def setGlobusPath(self, globusPath: str) -> None: + def set_globus_path(self, globus_path: str) -> None: pass @abstractmethod - def getGlobusPath(self) -> str: + def get_globus_path(self) -> str: pass @abstractmethod - def setPosixPath(self, posixPath: Path) -> None: + def set_posix_path(self, posix_path: Path) -> None: pass @abstractmethod - def getPosixPath(self) -> Path: + def get_posix_path(self) -> Path: pass class SimpleDataLocator(DataLocator, Observer): - def __init__(self, group: ParameterGroup, entryPrefix: str) -> None: + def __init__(self, group: ParameterGroup, entry_prefix: str) -> None: super().__init__() - self._endpointID = group.createUUIDParameter(f'{entryPrefix}DataEndpointID', UUID(int=0)) - self._globusPath = group.createStringParameter( - f'{entryPrefix}DataGlobusPath', - f'/~/path/to/{entryPrefix.lower()}/data', + self._endpoint_id = group.create_uuid_parameter( + f'{entry_prefix}DataEndpointID', UUID(int=0) ) - self._posixPath = group.createPathParameter( - f'{entryPrefix}DataPosixPath', - Path(f'/path/to/{entryPrefix.lower()}/data'), + self._globus_path = group.create_string_parameter( + f'{entry_prefix}DataGlobusPath', + f'/~/path/to/{entry_prefix.lower()}/data', + ) + self._posix_path = group.create_path_parameter( + f'{entry_prefix}DataPosixPath', + Path(f'/path/to/{entry_prefix.lower()}/data'), ) - self._endpointID.addObserver(self) - self._globusPath.addObserver(self) - self._posixPath.addObserver(self) + self._endpoint_id.add_observer(self) + self._globus_path.add_observer(self) + self._posix_path.add_observer(self) - def setEndpointID(self, endpointID: UUID) -> None: - self._endpointID.setValue(endpointID) + def set_endpoint_id(self, endpoint_id: UUID) -> None: + self._endpoint_id.set_value(endpoint_id) - def getEndpointID(self) -> UUID: - return self._endpointID.getValue() + def get_endpoint_id(self) -> UUID: + return self._endpoint_id.get_value() - def setGlobusPath(self, globusPath: str) -> None: - self._globusPath.setValue(globusPath) + def set_globus_path(self, globus_path: str) -> None: + self._globus_path.set_value(globus_path) - def getGlobusPath(self) -> str: - return self._globusPath.getValue() + def get_globus_path(self) -> str: + return self._globus_path.get_value() - def setPosixPath(self, posixPath: Path) -> None: - self._posixPath.setValue(posixPath) + def set_posix_path(self, posix_path: Path) -> None: + self._posix_path.set_value(posix_path) - def getPosixPath(self) -> Path: - return self._posixPath.getValue() + def get_posix_path(self) -> Path: + return self._posix_path.get_value() - def update(self, observable: Observable) -> None: - if observable is self._endpointID: - self.notifyObservers() - elif observable is self._globusPath: - self.notifyObservers() - elif observable is self._posixPath: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._endpoint_id: + self.notify_observers() + elif observable is self._globus_path: + self.notify_observers() + elif observable is self._posix_path: + self.notify_observers() class OutputDataLocator(DataLocator, Observer): def __init__( - self, group: ParameterGroup, entryPrefix: str, inputDataLocator: DataLocator + self, group: ParameterGroup, entry_prefix: str, input_data_locator: DataLocator ) -> None: super().__init__() - self._useRoundTrip = group.createBooleanParameter('UseRoundTrip', True) - self._outputDataLocator = SimpleDataLocator(group, entryPrefix) - self._inputDataLocator = inputDataLocator + self._use_round_trip = group.create_boolean_parameter('UseRoundTrip', True) + self._output_data_locator = SimpleDataLocator(group, entry_prefix) + self._input_data_locator = input_data_locator - self._useRoundTrip.addObserver(self) - self._inputDataLocator.addObserver(self) - self._outputDataLocator.addObserver(self) + self._use_round_trip.add_observer(self) + self._input_data_locator.add_observer(self) + self._output_data_locator.add_observer(self) - def setRoundTripEnabled(self, enable: bool) -> None: - self._useRoundTrip.setValue(enable) + def set_round_trip_enabled(self, enable: bool) -> None: + self._use_round_trip.set_value(enable) - def isRoundTripEnabled(self) -> bool: - return self._useRoundTrip.getValue() + def is_round_trip_enabled(self) -> bool: + return self._use_round_trip.get_value() - def setEndpointID(self, endpointID: UUID) -> None: - self._outputDataLocator.setEndpointID(endpointID) + def set_endpoint_id(self, endpoint_id: UUID) -> None: + self._output_data_locator.set_endpoint_id(endpoint_id) - def getEndpointID(self) -> UUID: + def get_endpoint_id(self) -> UUID: return ( - self._inputDataLocator.getEndpointID() - if self._useRoundTrip.getValue() - else self._outputDataLocator.getEndpointID() + self._input_data_locator.get_endpoint_id() + if self._use_round_trip.get_value() + else self._output_data_locator.get_endpoint_id() ) - def setGlobusPath(self, globusPath: str) -> None: - self._outputDataLocator.setGlobusPath(globusPath) + def set_globus_path(self, globus_path: str) -> None: + self._output_data_locator.set_globus_path(globus_path) - def getGlobusPath(self) -> str: + def get_globus_path(self) -> str: return ( - self._inputDataLocator.getGlobusPath() - if self._useRoundTrip.getValue() - else self._outputDataLocator.getGlobusPath() + self._input_data_locator.get_globus_path() + if self._use_round_trip.get_value() + else self._output_data_locator.get_globus_path() ) - def setPosixPath(self, posixPath: Path) -> None: - self._outputDataLocator.setPosixPath(posixPath) + def set_posix_path(self, posix_path: Path) -> None: + self._output_data_locator.set_posix_path(posix_path) - def getPosixPath(self) -> Path: + def get_posix_path(self) -> Path: return ( - self._inputDataLocator.getPosixPath() - if self._useRoundTrip.getValue() - else self._outputDataLocator.getPosixPath() + self._input_data_locator.get_posix_path() + if self._use_round_trip.get_value() + else self._output_data_locator.get_posix_path() ) - def update(self, observable: Observable) -> None: - if observable is self._useRoundTrip: - self.notifyObservers() - elif observable is self._inputDataLocator: - self.notifyObservers() - elif observable is self._outputDataLocator: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._use_round_trip: + self.notify_observers() + elif observable is self._input_data_locator: + self.notify_observers() + elif observable is self._output_data_locator: + self.notify_observers() diff --git a/src/ptychodus/model/workflow/settings.py b/src/ptychodus/model/workflow/settings.py index 7bae1e27..aeda1268 100644 --- a/src/ptychodus/model/workflow/settings.py +++ b/src/ptychodus/model/workflow/settings.py @@ -8,16 +8,16 @@ class WorkflowSettings(Observable, Observer): def __init__(self, registry: SettingsRegistry) -> None: super().__init__() - self._settingsGroup = registry.createGroup('Workflow') - self._settingsGroup.addObserver(self) + self._group = registry.create_group('Workflow') + self._group.add_observer(self) - self.computeEndpointID = self._settingsGroup.createUUIDParameter( + self.compute_endpoint_id = self._group.create_uuid_parameter( 'ComputeEndpointID', UUID(int=0) ) - self.statusRefreshIntervalInSeconds = self._settingsGroup.createIntegerParameter( + self.status_refresh_interval_s = self._group.create_integer_parameter( 'StatusRefreshIntervalInSeconds', 10 ) - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() + def _update(self, observable: Observable) -> None: + if observable is self._group: + self.notify_observers() diff --git a/src/ptychodus/model/workflow/status.py b/src/ptychodus/model/workflow/status.py index dce27227..edc4bd50 100644 --- a/src/ptychodus/model/workflow/status.py +++ b/src/ptychodus/model/workflow/status.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone from typing import overload import threading @@ -8,21 +8,21 @@ @dataclass(frozen=True) class WorkflowStatus: label: str - startTime: datetime - completionTime: datetime | None + start_time: datetime + completion_time: datetime | None status: str action: str - runID: str - runURL: str + run_id: str + run_url: str class WorkflowStatusRepository(Sequence[WorkflowStatus]): def __init__(self) -> None: super().__init__() - self._statusLock = threading.Lock() - self._statusList: list[WorkflowStatus] = list() - self._statusDateTime = datetime.min - self.refreshStatusEvent = threading.Event() + self._status_lock = threading.Lock() + self._status_list: list[WorkflowStatus] = list() + self._status_date_time = datetime.min + self.refresh_status_event = threading.Event() @overload def __getitem__(self, index: int) -> WorkflowStatus: ... @@ -31,22 +31,22 @@ def __getitem__(self, index: int) -> WorkflowStatus: ... def __getitem__(self, index: slice) -> Sequence[WorkflowStatus]: ... def __getitem__(self, index: int | slice) -> WorkflowStatus | Sequence[WorkflowStatus]: - with self._statusLock: - return self._statusList[index] + with self._status_lock: + return self._status_list[index] def __len__(self) -> int: - with self._statusLock: - return len(self._statusList) + with self._status_lock: + return len(self._status_list) - def getStatusDateTime(self) -> datetime: - with self._statusLock: - return self._statusDateTime + def get_status_date_time(self) -> datetime: + with self._status_lock: + return self._status_date_time - def refreshStatus(self) -> None: - self.refreshStatusEvent.set() + def refresh_status(self) -> None: + self.refresh_status_event.set() - def update(self, statusSequence: Sequence[WorkflowStatus]) -> None: - with self._statusLock: - self._statusDateTime = datetime.utcnow() - self._statusList = list(statusSequence) - self._statusList.sort(key=lambda x: x.startTime) + def update(self, status_sequence: Sequence[WorkflowStatus]) -> None: + with self._status_lock: + self._status_date_time = datetime.now(timezone.utc) + self._status_list = list(status_sequence) + self._status_list.sort(key=lambda x: x.start_time) diff --git a/src/ptychodus/plugins/aps2idDiffractionFile.py b/src/ptychodus/plugins/aps2idDiffractionFile.py deleted file mode 100644 index dbb651c4..00000000 --- a/src/ptychodus/plugins/aps2idDiffractionFile.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections.abc import Mapping -from pathlib import Path -import logging -import re - -import h5py - -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.patterns import ( - DiffractionDataset, - DiffractionFileReader, - DiffractionMetadata, - DiffractionPatternArray, - SimpleDiffractionDataset, -) -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.tree import SimpleTreeNode - -from .h5DiffractionFile import H5DiffractionPatternArray - -logger = logging.getLogger(__name__) - - -class APS2IDDiffractionFileReader(DiffractionFileReader): - def _getFileSeries(self, filePath: Path) -> tuple[Mapping[int, Path], str]: - filePathDict: dict[int, Path] = dict() - - digits = re.findall(r'\d+', filePath.stem) - longest_digits = max(digits, key=len) - filePattern = filePath.name.replace(longest_digits, f'(\\d{{{len(longest_digits)}}})') - - for fp in filePath.parent.iterdir(): - z = re.match(filePattern, fp.name) - - if z: - index = int(z.group(1)) - filePathDict[index] = fp - - return filePathDict, filePattern - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - dataPath = '/entry/data/data' - - filePathMapping, filePattern = self._getFileSeries(filePath) - contentsTree = SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) - arrayList: list[DiffractionPatternArray] = list() - - for idx, fp in sorted(filePathMapping.items()): - array = H5DiffractionPatternArray(fp.stem, idx, fp, dataPath) - contentsTree.createChild([array.getLabel(), 'HDF5', str(idx)]) - arrayList.append(array) - - try: - with h5py.File(filePath, 'r') as h5File: - try: - h5data = h5File[dataPath] - except KeyError: - logger.warning(f'File {filePath} is not an APS 2-ID data file.') - else: - numberOfPatternsPerArray, detectorHeight, detectorWidth = h5data.shape - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatternsPerArray, - numberOfPatternsTotal=numberOfPatternsPerArray * len(arrayList), - patternDataType=h5data.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - filePath=filePath.parent / filePattern, - ) - - dataset = SimpleDiffractionDataset(metadata, contentsTree, arrayList) - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - - return dataset - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.diffractionFileReaders.registerPlugin( - APS2IDDiffractionFileReader(), - simpleName='APS_2ID', - displayName='APS 2-ID Diffraction Files (*.h5 *.hdf5)', - ) diff --git a/src/ptychodus/plugins/aps2id_diffraction_file.py b/src/ptychodus/plugins/aps2id_diffraction_file.py new file mode 100644 index 00000000..79ca354f --- /dev/null +++ b/src/ptychodus/plugins/aps2id_diffraction_file.py @@ -0,0 +1,89 @@ +from collections.abc import Mapping +from pathlib import Path +import logging +import re + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + DiffractionPatternArray, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.tree import SimpleTreeNode + +from .h5_diffraction_file import H5DiffractionPatternArray + +logger = logging.getLogger(__name__) + + +class APS2IDDiffractionFileReader(DiffractionFileReader): + def _get_file_series(self, file_path: Path) -> tuple[Mapping[int, Path], str]: + file_path_dict: dict[int, Path] = dict() + + digits = re.findall(r'\d+', file_path.stem) + longest_digits = max(digits, key=len) + file_pattern = file_path.name.replace(longest_digits, f'(\\d{{{len(longest_digits)}}})') + + for fp in file_path.parent.iterdir(): + z = re.match(file_pattern, fp.name) + + if z: + index = int(z.group(1)) + file_path_dict[index] = fp + + return file_path_dict, file_pattern + + def read(self, file_path: Path) -> DiffractionDataset: + file_path_mapping, file_pattern = self._get_file_series(file_path) + data_path = '/entry/data/data' + + with h5py.File(file_path, 'r') as h5_file: + try: + h5data = h5_file[data_path] + except KeyError: + logger.warning(f'File {file_path} is not an APS 2-ID data file.') + return SimpleDiffractionDataset.create_null(file_path) + else: + num_patterns_per_array, detector_height, detector_width = h5data.shape + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns_per_array, + num_patterns_total=num_patterns_per_array * len(file_path_mapping), + pattern_dtype=h5data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + file_path=file_path.parent / file_pattern, + ) + + contents_tree = SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + array_list: list[DiffractionPatternArray] = list() + + for idx, fp in sorted(file_path_mapping.items()): + indexes = numpy.arange(num_patterns_per_array) + idx * num_patterns_per_array + array = H5DiffractionPatternArray(fp.stem, indexes, fp, data_path) + contents_tree.create_child([array.get_label(), 'HDF5', str(idx)]) + array_list.append(array) + + return SimpleDiffractionDataset(metadata, contents_tree, array_list) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.diffraction_file_readers.register_plugin( + APS2IDDiffractionFileReader(), + simple_name='APS_2IDD', + display_name='APS 2-ID-D Files (*.h5 *.hdf5)', + ) + registry.diffraction_file_readers.register_plugin( + APS2IDDiffractionFileReader(), + simple_name='APS_2IDE', + display_name='APS 2-ID-E Microprobe Files (*.h5 *.hdf5)', + ) + registry.diffraction_file_readers.register_plugin( + APS2IDDiffractionFileReader(), + simple_name='APS_BNP', + display_name='APS Bionanoprobe Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/csaxs_diffraction_file.py b/src/ptychodus/plugins/csaxs_diffraction_file.py new file mode 100644 index 00000000..9826f856 --- /dev/null +++ b/src/ptychodus/plugins/csaxs_diffraction_file.py @@ -0,0 +1,84 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry + +from .h5_diffraction_file import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +class CSAXSDiffractionFileReader(DiffractionFileReader): + ONE_MICRON_M: Final[float] = 1e-6 + ONE_MILLIMETER_M: Final[float] = 1e-3 + + def __init__(self) -> None: + self._data_path = '/entry/data/data' + self._tree_builder = H5DiffractionFileTreeBuilder() + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + with h5py.File(file_path, 'r') as h5_file: + contents_tree = self._tree_builder.build(h5_file) + + try: + data = h5_file[self._data_path] + x_pixel_size_um = h5_file['/entry/instrument/eiger_4/x_pixel_size'] + y_pixel_size_um = h5_file['/entry/instrument/eiger_4/y_pixel_size'] + distance_mm = h5_file['/entry/instrument/monochromator/distance'] + energy_keV = h5_file['/entry/instrument/monochromator/energy'] # noqa: N806 + except KeyError: + logger.warning('Unable to load data.') + else: + num_patterns, detector_height, detector_width = data.shape + detector_distance_m = float(distance_mm[()]) * self.ONE_MILLIMETER_M + detector_pixel_geometry = PixelGeometry( + width_m=float(x_pixel_size_um[()]) * self.ONE_MICRON_M, + height_m=float(y_pixel_size_um[()]) * self.ONE_MICRON_M, + ) + probe_energy_eV = 1000 * float(energy_keV[()]) # noqa: N806 + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=data.dtype, + detector_distance_m=abs(detector_distance_m), + detector_extent=ImageExtent(detector_width, detector_height), + detector_pixel_geometry=detector_pixel_geometry, + probe_energy_eV=probe_energy_eV, + file_path=file_path, + ) + + array = H5DiffractionPatternArray( + label=file_path.stem, + indexes=numpy.arange(num_patterns), + file_path=file_path, + data_path=self._data_path, + ) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, [array]) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + + return dataset + + +def register_plugins(registry: PluginRegistry) -> None: + registry.diffraction_file_readers.register_plugin( + CSAXSDiffractionFileReader(), + simple_name='SLS_cSAXS', + display_name='SLS cSAXS Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/cssi_position_file.py b/src/ptychodus/plugins/cssi_position_file.py new file mode 100644 index 00000000..cf9e0c9c --- /dev/null +++ b/src/ptychodus/plugins/cssi_position_file.py @@ -0,0 +1,41 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py + +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint, ScanPointParseError + +logger = logging.getLogger(__name__) + + +class CSSIPositionFileReader(PositionFileReader): + ONE_MILLIMETER_M: Final[float] = 1e-3 + + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + + with h5py.File(file_path, 'r') as h5_file: + try: + h5_positions = h5_file['/exchange/motor_pos'] + except KeyError: + logger.exception('Unable to load scan.') + else: + for idx, row in enumerate(h5_positions): + point = ScanPoint( + idx, + row[0] * self.ONE_MILLIMETER_M, + row[1] * self.ONE_MILLIMETER_M, + ) + point_list.append(point) + + return PositionSequence(point_list) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.position_file_readers.register_plugin( + CSSIPositionFileReader(), + simple_name='APS_CSSI', + display_name='APS CSSI Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/csvObjectFile.py b/src/ptychodus/plugins/csvObjectFile.py deleted file mode 100644 index 69064259..00000000 --- a/src/ptychodus/plugins/csvObjectFile.py +++ /dev/null @@ -1,31 +0,0 @@ -from pathlib import Path - -import numpy - -from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter -from ptychodus.api.plugins import PluginRegistry - - -class CSVObjectFileReader(ObjectFileReader): - def read(self, filePath: Path) -> Object: - array = numpy.genfromtxt(filePath, delimiter=',', dtype='complex') - return Object(array) - - -class CSVObjectFileWriter(ObjectFileWriter): - def write(self, filePath: Path, object_: Object) -> None: - array = object_.array - numpy.savetxt(filePath, array, delimiter=',') - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.objectFileReaders.registerPlugin( - CSVObjectFileReader(), - simpleName='CSV', - displayName='Comma-Separated Values Files (*.csv)', - ) - registry.objectFileWriters.registerPlugin( - CSVObjectFileWriter(), - simpleName='CSV', - displayName='Comma-Separated Values Files (*.csv)', - ) diff --git a/src/ptychodus/plugins/csvProbeFile.py b/src/ptychodus/plugins/csvProbeFile.py deleted file mode 100644 index 81120e6e..00000000 --- a/src/ptychodus/plugins/csvProbeFile.py +++ /dev/null @@ -1,40 +0,0 @@ -from pathlib import Path - -import numpy - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileReader, ProbeFileWriter - - -class CSVProbeFileReader(ProbeFileReader): - def read(self, filePath: Path) -> Probe: - arrayFlat = numpy.genfromtxt(filePath, delimiter=',', dtype='complex') - numberOfModes, remainder = divmod(arrayFlat.shape[0], arrayFlat.shape[1]) - - if remainder != 0: - raise ValueError('Failed to determine probe modes!') - - if numberOfModes > 1: - array = arrayFlat.reshape(numberOfModes, arrayFlat.shape[1], arrayFlat.shape[1]) - - return Probe(array) - - -class CSVProbeFileWriter(ProbeFileWriter): - def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array - arrayFlat = array.reshape(-1, array.shape[-1]) - numpy.savetxt(filePath, arrayFlat, delimiter=',') - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.probeFileReaders.registerPlugin( - CSVProbeFileReader(), - simpleName='CSV', - displayName='Comma-Separated Values Files (*.csv)', - ) - registry.probeFileWriters.registerPlugin( - CSVProbeFileWriter(), - simpleName='CSV', - displayName='Comma-Separated Values Files (*.csv)', - ) diff --git a/src/ptychodus/plugins/csv_object_file.py b/src/ptychodus/plugins/csv_object_file.py new file mode 100644 index 00000000..23559c43 --- /dev/null +++ b/src/ptychodus/plugins/csv_object_file.py @@ -0,0 +1,31 @@ +from pathlib import Path + +import numpy + +from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter +from ptychodus.api.plugins import PluginRegistry + + +class CSVObjectFileReader(ObjectFileReader): + def read(self, file_path: Path) -> Object: + array = numpy.genfromtxt(file_path, delimiter=',', dtype=complex) + return Object(array=array, pixel_geometry=None, center=None) + + +class CSVObjectFileWriter(ObjectFileWriter): + def write(self, file_path: Path, object_: Object) -> None: + array = object_.get_array() + numpy.savetxt(file_path, array, delimiter=',') + + +def register_plugins(registry: PluginRegistry) -> None: + registry.object_file_readers.register_plugin( + CSVObjectFileReader(), + simple_name='CSV', + display_name='Comma-Separated Values Files (*.csv)', + ) + registry.object_file_writers.register_plugin( + CSVObjectFileWriter(), + simple_name='CSV', + display_name='Comma-Separated Values Files (*.csv)', + ) diff --git a/src/ptychodus/plugins/csv_probe_file.py b/src/ptychodus/plugins/csv_probe_file.py new file mode 100644 index 00000000..36eb2658 --- /dev/null +++ b/src/ptychodus/plugins/csv_probe_file.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import numpy + +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence, ProbeFileReader, ProbeFileWriter + + +class CSVProbeFileReader(ProbeFileReader): + def read(self, file_path: Path) -> ProbeSequence: + array_flat = numpy.genfromtxt(file_path, delimiter=',', dtype=complex) + num_modes, remainder = divmod(array_flat.shape[0], array_flat.shape[1]) + + if remainder != 0: + raise ValueError('Failed to determine probe modes!') + + array = array_flat.reshape(num_modes, array_flat.shape[1], array_flat.shape[1]) + return ProbeSequence(array=array, opr_weights=None, pixel_geometry=None) + + +class CSVProbeFileWriter(ProbeFileWriter): + def write(self, file_path: Path, probes: ProbeSequence) -> None: + array = probes.get_probe_no_opr().get_array() + array_flat = array.reshape(-1, array.shape[-1]) + numpy.savetxt(file_path, array_flat, delimiter=',') + + +def register_plugins(registry: PluginRegistry) -> None: + registry.probe_file_readers.register_plugin( + CSVProbeFileReader(), + simple_name='CSV', + display_name='Comma-Separated Values Files (*.csv)', + ) + registry.probe_file_writers.register_plugin( + CSVProbeFileWriter(), + simple_name='CSV', + display_name='Comma-Separated Values Files (*.csv)', + ) diff --git a/src/ptychodus/plugins/cxiDiffractionFile.py b/src/ptychodus/plugins/cxiDiffractionFile.py deleted file mode 100644 index ab629705..00000000 --- a/src/ptychodus/plugins/cxiDiffractionFile.py +++ /dev/null @@ -1,81 +0,0 @@ -from pathlib import Path -import logging - -import h5py - -from ptychodus.api.constants import ELECTRON_VOLT_J -from ptychodus.api.geometry import ImageExtent, PixelGeometry -from ptychodus.api.patterns import ( - DiffractionDataset, - DiffractionFileReader, - DiffractionMetadata, - SimpleDiffractionDataset, -) -from ptychodus.api.plugins import PluginRegistry - -from .h5DiffractionFile import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder - -logger = logging.getLogger(__name__) - - -class CXIDiffractionFileReader(DiffractionFileReader): - def __init__(self) -> None: - self._dataPath = '/entry_1/data_1/data' - self._treeBuilder = H5DiffractionFileTreeBuilder() - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - - try: - with h5py.File(filePath, 'r') as h5File: - contentsTree = self._treeBuilder.build(h5File) - - try: - data = h5File[self._dataPath] - except KeyError: - logger.warning('Unable to load data.') - else: - numberOfPatterns, detectorHeight, detectorWidth = data.shape - - detectorExtent = ImageExtent(detectorWidth, detectorHeight) - detectorDistanceInMeters = float( - h5File['/entry_1/instrument_1/detector_1/distance'][()] - ) - detectorPixelGeometry = PixelGeometry( - float(h5File['/entry_1/instrument_1/detector_1/x_pixel_size'][()]), - float(h5File['/entry_1/instrument_1/detector_1/y_pixel_size'][()]), - ) - probeEnergyInJoules = float(h5File['/entry_1/instrument_1/source_1/energy'][()]) - probeEnergyInElectronVolts = probeEnergyInJoules / ELECTRON_VOLT_J - - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatterns, - numberOfPatternsTotal=numberOfPatterns, - patternDataType=data.dtype, - detectorDistanceInMeters=detectorDistanceInMeters, - detectorExtent=detectorExtent, - detectorPixelGeometry=detectorPixelGeometry, - probeEnergyInElectronVolts=probeEnergyInElectronVolts, - filePath=filePath, - ) - - array = H5DiffractionPatternArray( - label=filePath.stem, - index=0, - filePath=filePath, - dataPath=self._dataPath, - ) - - dataset = SimpleDiffractionDataset(metadata, contentsTree, [array]) - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - - return dataset - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.diffractionFileReaders.registerPlugin( - CXIDiffractionFileReader(), - simpleName='CXI', - displayName='Coherent X-ray Imaging Files (*.cxi)', - ) diff --git a/src/ptychodus/plugins/cxiProbeFile.py b/src/ptychodus/plugins/cxiProbeFile.py deleted file mode 100644 index b3d4fd9a..00000000 --- a/src/ptychodus/plugins/cxiProbeFile.py +++ /dev/null @@ -1,31 +0,0 @@ -from pathlib import Path -import logging - -import h5py -import numpy - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileReader - -logger = logging.getLogger(__name__) - - -class CXIProbeFileReader(ProbeFileReader): - def read(self, filePath: Path) -> Probe: - array = numpy.zeros((0, 0, 0), dtype=complex) - - with h5py.File(filePath, 'r') as h5File: - try: - array = h5File['/entry_1/instrument_1/source_1/illumination'][()] - except KeyError: - logger.warning('Unable to load probe.') - - return Probe(array) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.probeFileReaders.registerPlugin( - CXIProbeFileReader(), - simpleName='CXI', - displayName='Coherent X-ray Imaging Files (*.cxi)', - ) diff --git a/src/ptychodus/plugins/cxiScanFile.py b/src/ptychodus/plugins/cxiScanFile.py deleted file mode 100644 index 88485c7b..00000000 --- a/src/ptychodus/plugins/cxiScanFile.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path -import logging - -import h5py - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint - -logger = logging.getLogger(__name__) - - -class CXIScanFileReader(ScanFileReader): - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - - with h5py.File(filePath, 'r') as h5File: - try: - xyzArray = h5File['/entry_1/data_1/translation'][()] - except KeyError: - logger.exception('Unable to load scan.') - else: - for idx, xyz in enumerate(xyzArray): - try: - x, y, z = xyz - except ValueError: - logger.exception(f'Unable to load scan point {xyz=}.') - else: - point = ScanPoint(idx, x, y) - pointList.append(point) - - return Scan(pointList) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - CXIScanFileReader(), - simpleName='CXI', - displayName='Coherent X-ray Imaging Files (*.cxi)', - ) diff --git a/src/ptychodus/plugins/cxi_file.py b/src/ptychodus/plugins/cxi_file.py new file mode 100644 index 00000000..d7bf7b86 --- /dev/null +++ b/src/ptychodus/plugins/cxi_file.py @@ -0,0 +1,125 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py +import numpy + +from .h5_diffraction_file import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence, ProbeFileReader +from ptychodus.api.product import ELECTRON_VOLT_J +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint +from ptychodus.api.typing import ComplexArrayType + +logger = logging.getLogger(__name__) + + +class CXIDiffractionFileReader(DiffractionFileReader): + def __init__(self) -> None: + self._data_path = '/entry_1/data_1/data' + self._tree_builder = H5DiffractionFileTreeBuilder() + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + with h5py.File(file_path, 'r') as h5_file: + contents_tree = self._tree_builder.build(h5_file) + + try: + data = h5_file[self._data_path] + except KeyError: + logger.warning('Unable to load data.') + else: + num_patterns, detector_height, detector_width = data.shape + + detector_extent = ImageExtent(detector_width, detector_height) + detector_distance_m = float( + h5_file['/entry_1/instrument_1/detector_1/distance'][()] + ) + detector_pixel_geometry = PixelGeometry( + float(h5_file['/entry_1/instrument_1/detector_1/x_pixel_size'][()]), + float(h5_file['/entry_1/instrument_1/detector_1/y_pixel_size'][()]), + ) + probe_energy_J = float(h5_file['/entry_1/instrument_1/source_1/energy'][()]) # noqa: N806 + probe_energy_eV = probe_energy_J / ELECTRON_VOLT_J # noqa: N806 + + # TODO load detector mask; zeros are good pixels + # /entry_1/instrument_1/detector_1/mask Dataset {512, 512} + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=data.dtype, + detector_distance_m=detector_distance_m, + detector_extent=detector_extent, + detector_pixel_geometry=detector_pixel_geometry, + probe_energy_eV=probe_energy_eV, + file_path=file_path, + ) + + array = H5DiffractionPatternArray( + label=file_path.stem, + indexes=numpy.arange(num_patterns), + file_path=file_path, + data_path=self._data_path, + ) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, [array]) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + + return dataset + + +class CXIPositionFileReader(PositionFileReader): + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + + with h5py.File(file_path, 'r') as h5_file: + xyz_m = h5_file['/entry_1/data_1/translation'][()] + + for idx, (x, y, z) in enumerate(xyz_m): + point = ScanPoint(idx, x, y) + point_list.append(point) + + return PositionSequence(point_list) + + +class CXIProbeFileReader(ProbeFileReader): + def read(self, file_path: Path) -> ProbeSequence: + array: ComplexArrayType | None = None + + with h5py.File(file_path, 'r') as h5_file: + array = h5_file['/entry_1/instrument_1/source_1/illumination'][()] + + return ProbeSequence(array=array, opr_weights=None, pixel_geometry=None) + + +def register_plugins(registry: PluginRegistry) -> None: + SIMPLE_NAME: Final[str] = 'CXI' # noqa: N806 + DISPLAY_NAME: Final[str] = 'Coherent X-ray Imaging Files (*.cxi)' # noqa: N806 + + registry.diffraction_file_readers.register_plugin( + CXIDiffractionFileReader(), + simple_name=SIMPLE_NAME, + display_name=DISPLAY_NAME, + ) + registry.position_file_readers.register_plugin( + CXIPositionFileReader(), + simple_name=SIMPLE_NAME, + display_name=DISPLAY_NAME, + ) + registry.probe_file_readers.register_plugin( + CXIProbeFileReader(), + simple_name=SIMPLE_NAME, + display_name=DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/deconvolution.py b/src/ptychodus/plugins/deconvolution.py index 5b3ce3a8..6367ba70 100644 --- a/src/ptychodus/plugins/deconvolution.py +++ b/src/ptychodus/plugins/deconvolution.py @@ -12,45 +12,42 @@ def __call__(self, emap: ElementMap, product: Product) -> ElementMap: class RichardsonLucyDeconvolution(DeconvolutionStrategy): def __call__(self, emap: ElementMap, product: Product) -> ElementMap: - cps = skimage.restoration.richardson_lucy( - emap.counts_per_second, product.probe.getIntensity() - ) + probe_intensity = product.probes.get_probe_no_opr().get_intensity() + cps = skimage.restoration.richardson_lucy(emap.counts_per_second, probe_intensity) return ElementMap(emap.name, cps) class WienerDeconvolution(DeconvolutionStrategy): def __call__(self, emap: ElementMap, product: Product) -> ElementMap: + probe_intensity = product.probes.get_probe_no_opr().get_intensity() balance = 0.05 # TODO - cps = skimage.restoration.wiener( - emap.counts_per_second, product.probe.getIntensity(), balance - ) + cps = skimage.restoration.wiener(emap.counts_per_second, probe_intensity, balance) return ElementMap(emap.name, cps) class UnsupervisedWienerDeconvolution(DeconvolutionStrategy): def __call__(self, emap: ElementMap, product: Product) -> ElementMap: - cps, _ = skimage.restoration.unsupervised_wiener( - emap.counts_per_second, product.probe.getIntensity() - ) + probe_intensity = product.probes.get_probe_no_opr().get_intensity() + cps, _ = skimage.restoration.unsupervised_wiener(emap.counts_per_second, probe_intensity) return ElementMap(emap.name, cps) -def registerPlugins(registry: PluginRegistry) -> None: +def register_plugins(registry: PluginRegistry) -> None: # NOTE See https://scikit-image.org/docs/stable/api/skimage.restoration.html # TODO Implement method from https://doi.org/10.1364/OE.20.018287 - registry.deconvolutionStrategies.registerPlugin( + registry.deconvolution_strategies.register_plugin( IdentityDeconvolution(), - displayName='Identity', + display_name='Identity', ) - registry.deconvolutionStrategies.registerPlugin( + registry.deconvolution_strategies.register_plugin( RichardsonLucyDeconvolution(), - displayName='Richardson-Lucy', + display_name='Richardson-Lucy', ) - registry.deconvolutionStrategies.registerPlugin( + registry.deconvolution_strategies.register_plugin( WienerDeconvolution(), - displayName='Wiener', + display_name='Wiener', ) - registry.deconvolutionStrategies.registerPlugin( + registry.deconvolution_strategies.register_plugin( UnsupervisedWienerDeconvolution(), - displayName='Unsupervised Wiener', + display_name='Unsupervised Wiener', ) diff --git a/src/ptychodus/plugins/delimitedScanFile.py b/src/ptychodus/plugins/delimitedScanFile.py deleted file mode 100644 index 4df148c6..00000000 --- a/src/ptychodus/plugins/delimitedScanFile.py +++ /dev/null @@ -1,81 +0,0 @@ -from pathlib import Path -import csv - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import ( - Scan, - ScanFileReader, - ScanFileWriter, - ScanPoint, - ScanPointParseError, -) - - -class DelimitedScanFileReader(ScanFileReader): - def __init__(self, delimiter: str, swapXY: bool) -> None: - self._delimiter = delimiter - self._swapXY = swapXY - - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - - if self._swapXY: - xcol = 1 - ycol = 0 - else: - xcol = 0 - ycol = 1 - - with filePath.open(newline='') as csvFile: - csvReader = csv.reader(csvFile, delimiter=self._delimiter) - - for idx, row in enumerate(csvReader): - if row[0].startswith('#'): - continue - - if len(row) < 2: - raise ScanPointParseError('Bad number of columns!') - - point = ScanPoint(idx, float(row[xcol]), float(row[ycol])) - pointList.append(point) - - return Scan(pointList) - - -class DelimitedScanFileWriter(ScanFileWriter): - def __init__(self, delimiter: str, swapXY: bool) -> None: - self._delimiter = delimiter - self._swapXY = swapXY - - def write(self, filePath: Path, scan: Scan) -> None: - with filePath.open(mode='wt') as csvFile: - for point in scan: - x = point.positionXInMeters - y = point.positionYInMeters - line = ( - f'{y}{self._delimiter}{x}\n' if self._swapXY else f'{x}{self._delimiter}{y}\n' - ) - csvFile.write(line) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - DelimitedScanFileReader(' ', swapXY=False), - simpleName='TXT', - displayName='Space-Separated Values Files (*.txt)', - ) - registry.scanFileWriters.registerPlugin( - DelimitedScanFileWriter(' ', swapXY=False), - simpleName='TXT', - displayName='Space-Separated Values Files (*.txt)', - ) - registry.scanFileReaders.registerPlugin( - DelimitedScanFileReader(',', swapXY=True), - simpleName='CSV', - displayName='Comma-Separated Values Files (*.csv)', - ) - registry.scanFileWriters.registerPlugin( - DelimitedScanFileWriter(',', swapXY=True), - simpleName='CSV', - displayName='Comma-Separated Values Files (*.csv)', - ) diff --git a/src/ptychodus/plugins/delimited_position_file.py b/src/ptychodus/plugins/delimited_position_file.py new file mode 100644 index 00000000..028201d4 --- /dev/null +++ b/src/ptychodus/plugins/delimited_position_file.py @@ -0,0 +1,81 @@ +from pathlib import Path +import csv + +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.scan import ( + PositionSequence, + PositionFileReader, + PositionFileWriter, + ScanPoint, + ScanPointParseError, +) + + +class DelimitedPositionFileReader(PositionFileReader): + def __init__(self, delimiter: str, swap_xy: bool) -> None: + self._delimiter = delimiter + self._swap_xy = swap_xy + + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + + if self._swap_xy: + xcol = 1 + ycol = 0 + else: + xcol = 0 + ycol = 1 + + with file_path.open(newline='') as csv_file: + csv_reader = csv.reader(csv_file, delimiter=self._delimiter) + + for idx, row in enumerate(csv_reader): + if row[0].startswith('#'): + continue + + if len(row) < 2: + raise ScanPointParseError('Bad number of columns!') + + point = ScanPoint(idx, float(row[xcol]), float(row[ycol])) + point_list.append(point) + + return PositionSequence(point_list) + + +class DelimitedPositionFileWriter(PositionFileWriter): + def __init__(self, delimiter: str, swap_xy: bool) -> None: + self._delimiter = delimiter + self._swap_xy = swap_xy + + def write(self, file_path: Path, positions: PositionSequence) -> None: + with file_path.open(mode='wt') as csv_file: + for point in positions: + x = point.position_x_m + y = point.position_y_m + line = ( + f'{y}{self._delimiter}{x}\n' if self._swap_xy else f'{x}{self._delimiter}{y}\n' + ) + csv_file.write(line) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.position_file_readers.register_plugin( + DelimitedPositionFileReader(' ', swap_xy=False), + simple_name='TXT', + display_name='Space-Separated Values Files (*.txt)', + ) + registry.position_file_writers.register_plugin( + DelimitedPositionFileWriter(' ', swap_xy=False), + simple_name='TXT', + display_name='Space-Separated Values Files (*.txt)', + ) + registry.position_file_readers.register_plugin( + DelimitedPositionFileReader(',', swap_xy=True), + simple_name='CSV', + display_name='Comma-Separated Values Files (*.csv)', + ) + registry.position_file_writers.register_plugin( + DelimitedPositionFileWriter(',', swap_xy=True), + simple_name='CSV', + display_name='Comma-Separated Values Files (*.csv)', + ) diff --git a/src/ptychodus/plugins/fresnelZonePlate.py b/src/ptychodus/plugins/fresnelZonePlate.py deleted file mode 100644 index 3afe0743..00000000 --- a/src/ptychodus/plugins/fresnelZonePlate.py +++ /dev/null @@ -1,21 +0,0 @@ -from ptychodus.api.probe import FresnelZonePlate -from ptychodus.api.plugins import PluginRegistry - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.fresnelZonePlates.registerPlugin( - FresnelZonePlate(160e-6, 70e-9, 60e-6), - displayName='2-ID-D', - ) - registry.fresnelZonePlates.registerPlugin( - FresnelZonePlate(160e-6, 30e-9, 80e-6), - displayName='HXN', - ) - registry.fresnelZonePlates.registerPlugin( - FresnelZonePlate(114.8e-6, 60e-9, 40e-6), - displayName='LYNX', - ) - registry.fresnelZonePlates.registerPlugin( - FresnelZonePlate(180e-6, 50e-9, 60e-6), - displayName='Velociprobe', - ) diff --git a/src/ptychodus/plugins/fresnel_zone_plate.py b/src/ptychodus/plugins/fresnel_zone_plate.py new file mode 100644 index 00000000..0808ac72 --- /dev/null +++ b/src/ptychodus/plugins/fresnel_zone_plate.py @@ -0,0 +1,25 @@ +from ptychodus.api.probe import FresnelZonePlate +from ptychodus.api.plugins import PluginRegistry + + +def register_plugins(registry: PluginRegistry) -> None: + registry.fresnel_zone_plates.register_plugin( + FresnelZonePlate(160e-6, 70e-9, 60e-6), + display_name='2-ID-D', + ) + registry.fresnel_zone_plates.register_plugin( + FresnelZonePlate(160e-6, 30e-9, 80e-6), + display_name='HXN', + ) + registry.fresnel_zone_plates.register_plugin( + FresnelZonePlate(114.8e-6, 60e-9, 40e-6), + display_name='LYNX', + ) + registry.fresnel_zone_plates.register_plugin( + FresnelZonePlate(180e-6, 15e-9, 15e-6), + display_name='PtychoProbe', + ) + registry.fresnel_zone_plates.register_plugin( + FresnelZonePlate(180e-6, 50e-9, 60e-6), + display_name='Velociprobe', + ) diff --git a/src/ptychodus/plugins/h5DiffractionFile.py b/src/ptychodus/plugins/h5DiffractionFile.py deleted file mode 100644 index 8c9a978b..00000000 --- a/src/ptychodus/plugins/h5DiffractionFile.py +++ /dev/null @@ -1,211 +0,0 @@ -from pathlib import Path -import logging - -import h5py -import numpy - -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.patterns import ( - DiffractionPatternArrayType, - DiffractionDataset, - DiffractionFileReader, - DiffractionFileWriter, - DiffractionMetadata, - DiffractionPatternArray, - DiffractionPatternState, - SimpleDiffractionDataset, -) -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.tree import SimpleTreeNode - -logger = logging.getLogger(__name__) - - -class H5DiffractionPatternArray(DiffractionPatternArray): - def __init__(self, label: str, index: int, filePath: Path, dataPath: str) -> None: - super().__init__() - self._label = label - self._index = index - self._state = DiffractionPatternState.UNKNOWN - self._filePath = filePath - self._dataPath = dataPath - - def getLabel(self) -> str: - return self._label - - def getIndex(self) -> int: - return self._index - - def getState(self) -> DiffractionPatternState: - return self._state - - def getData(self) -> DiffractionPatternArrayType: - self._state = DiffractionPatternState.MISSING - - with h5py.File(self._filePath, 'r') as h5File: - try: - item = h5File[self._dataPath] - except KeyError: - raise ValueError(f'Symlink {self._filePath}:{self._dataPath} is broken!') - else: - if isinstance(item, h5py.Dataset): - self._state = DiffractionPatternState.FOUND - else: - raise ValueError(f'Symlink {self._filePath}:{self._dataPath} is not a dataset!') - - data = item[()] - - return data - - -class H5DiffractionFileTreeBuilder: - def _addAttributes( - self, treeNode: SimpleTreeNode, attributeManager: h5py.AttributeManager - ) -> None: - for name, value in attributeManager.items(): - if isinstance(value, str): - itemDetails = f'STRING = "{value}"' - elif isinstance(value, h5py.Empty): - logger.debug(f'Skipping empty attribute {name}.') - else: - stringInfo = h5py.check_string_dtype(value.dtype) - itemDetails = ( - f'STRING = "{value.decode(stringInfo.encoding)}"' - if stringInfo - else f'SCALAR {value.dtype} = {value}' - ) - - treeNode.createChild([str(name), 'Attribute', itemDetails]) - - def createRootNode(self) -> SimpleTreeNode: - return SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) - - def build(self, h5File: h5py.File) -> SimpleTreeNode: - rootNode = self.createRootNode() - unvisited = [(rootNode, h5File)] - - while unvisited: - parentItem, h5Group = unvisited.pop() - - for itemName in h5Group: - itemType = 'Unknown' - itemDetails = '' - h5Item = h5Group.get(itemName, getlink=True) - - treeNode = parentItem.createChild(list()) - - if isinstance(h5Item, h5py.HardLink): - itemType = 'Hard Link' - h5Item = h5Group.get(itemName, getlink=False) - - if isinstance(h5Item, h5py.Group): - itemType = 'Group' - self._addAttributes(treeNode, h5Item.attrs) - unvisited.append((treeNode, h5Item)) - elif isinstance(h5Item, h5py.Dataset): - itemType = 'Dataset' - self._addAttributes(treeNode, h5Item.attrs) - spaceId = h5Item.id.get_space() - - if spaceId.get_simple_extent_type() == h5py.h5s.SCALAR: - value = h5Item[()] - - if isinstance(value, bytes): - itemDetails = value.decode() - else: - stringInfo = h5py.check_string_dtype(value.dtype) - itemDetails = ( - f'STRING = "{value.decode(stringInfo.encoding)}"' - if stringInfo - else f'SCALAR {value.dtype} = {value}' - ) - else: - itemDetails = f'{h5Item.shape} {h5Item.dtype}' - elif isinstance(h5Item, h5py.SoftLink): - itemType = 'Soft Link' - itemDetails = f'{h5Item.path}' - elif isinstance(h5Item, h5py.ExternalLink): - itemType = 'External Link' - itemDetails = f'{h5Item.filename}/{h5Item.path}' - else: - logger.debug(f'Unknown item "{itemName}"') - - treeNode.itemData = [itemName, itemType, itemDetails] - - return rootNode - - -class H5DiffractionFileReader(DiffractionFileReader): - def __init__(self, dataPath: str) -> None: - self._dataPath = dataPath - self._treeBuilder = H5DiffractionFileTreeBuilder() - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - - try: - with h5py.File(filePath, 'r') as h5File: - metadata = DiffractionMetadata.createNullInstance(filePath) - contentsTree = self._treeBuilder.build(h5File) - - try: - data = h5File[self._dataPath] - except KeyError: - logger.warning('Unable to find data.') - else: - numberOfPatterns, detectorHeight, detectorWidth = data.shape - - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatterns, - numberOfPatternsTotal=numberOfPatterns, - patternDataType=data.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - filePath=filePath, - ) - - array = H5DiffractionPatternArray( - label=filePath.stem, - index=0, - filePath=filePath, - dataPath=self._dataPath, - ) - - dataset = SimpleDiffractionDataset(metadata, contentsTree, [array]) - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - - return dataset - - -class H5DiffractionFileWriter(DiffractionFileWriter): - def __init__(self, dataPath: str) -> None: - self._dataPath = dataPath - - def write(self, filePath: Path, dataset: DiffractionDataset) -> None: - data = numpy.concatenate([array.getData() for array in dataset]) - - with h5py.File(filePath, 'w') as h5File: - h5File.create_dataset(self._dataPath, data=data, compression='gzip') - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.diffractionFileReaders.registerPlugin( - H5DiffractionFileReader(dataPath='/entry/data/data'), - simpleName='APS_HXN', - displayName='CNM/APS HXN Diffraction Files (*.h5 *.hdf5)', - ) - registry.diffractionFileReaders.registerPlugin( - H5DiffractionFileReader(dataPath='/entry/measurement/Eiger/data'), - simpleName='NanoMax', - displayName='NanoMax Diffraction Files (*.h5 *.hdf5)', - ) - registry.diffractionFileReaders.registerPlugin( - H5DiffractionFileReader(dataPath='/dp'), - simpleName='PtychoShelves', - displayName='PtychoShelves Diffraction Files (*.h5 *.hdf5)', - ) - registry.diffractionFileWriters.registerPlugin( - H5DiffractionFileWriter(dataPath='/dp'), - simpleName='PtychoShelves', - displayName='PtychoShelves Diffraction Files (*.h5 *.hdf5)', - ) diff --git a/src/ptychodus/plugins/h5ProductFile.py b/src/ptychodus/plugins/h5ProductFile.py deleted file mode 100644 index 4337465b..00000000 --- a/src/ptychodus/plugins/h5ProductFile.py +++ /dev/null @@ -1,151 +0,0 @@ -from pathlib import Path -from typing import Final -import logging - -import h5py - -from ptychodus.api.object import Object -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe -from ptychodus.api.product import ( - Product, - ProductFileReader, - ProductFileWriter, - ProductMetadata, -) -from ptychodus.api.scan import Scan, ScanPoint - -logger = logging.getLogger(__name__) - - -class H5ProductFileIO(ProductFileReader, ProductFileWriter): - SIMPLE_NAME: Final[str] = 'HDF5' - DISPLAY_NAME: Final[str] = 'Ptychodus Product Files (*.h5 *.hdf5)' - - NAME: Final[str] = 'name' - COMMENTS: Final[str] = 'comments' - DETECTOR_OBJECT_DISTANCE: Final[str] = 'detector_object_distance_m' - PROBE_ENERGY: Final[str] = 'probe_energy_eV' - PROBE_PHOTON_FLUX: Final[str] = 'probe_photons_per_s' - EXPOSURE_TIME: Final[str] = 'exposure_time_s' - - PROBE_ARRAY: Final[str] = 'probe' - PROBE_PIXEL_HEIGHT: Final[str] = 'pixel_height_m' - PROBE_PIXEL_WIDTH: Final[str] = 'pixel_width_m' - PROBE_POSITION_INDEXES: Final[str] = 'probe_position_indexes' - PROBE_POSITION_X: Final[str] = 'probe_position_x_m' - PROBE_POSITION_Y: Final[str] = 'probe_position_y_m' - - OBJECT_ARRAY: Final[str] = 'object' - OBJECT_CENTER_X: Final[str] = 'center_x_m' - OBJECT_CENTER_Y: Final[str] = 'center_y_m' - OBJECT_LAYER_DISTANCE: Final[str] = 'object_layer_distance_m' - OBJECT_PIXEL_HEIGHT: Final[str] = 'pixel_height_m' - OBJECT_PIXEL_WIDTH: Final[str] = 'pixel_width_m' - - COSTS_ARRAY: Final[str] = 'costs' - - def read(self, filePath: Path) -> Product: - scanPointList: list[ScanPoint] = list() - - with h5py.File(filePath, 'r') as h5File: - metadata = ProductMetadata( - name=str(h5File.attrs[self.NAME]), - comments=str(h5File.attrs[self.COMMENTS]), - detectorDistanceInMeters=float(h5File.attrs[self.DETECTOR_OBJECT_DISTANCE]), - probeEnergyInElectronVolts=float(h5File.attrs[self.PROBE_ENERGY]), - probePhotonsPerSecond=float(h5File.attrs[self.PROBE_PHOTON_FLUX]), - exposureTimeInSeconds=float(h5File.attrs[self.EXPOSURE_TIME]), - ) - - h5ScanIndexes = h5File[self.PROBE_POSITION_INDEXES] - h5ScanX = h5File[self.PROBE_POSITION_X] - h5ScanY = h5File[self.PROBE_POSITION_Y] - - for idx, x_m, y_m in zip(h5ScanIndexes[()], h5ScanX[()], h5ScanY[()]): - point = ScanPoint(idx, x_m, y_m) - scanPointList.append(point) - - h5Probe = h5File[self.PROBE_ARRAY] - probe = Probe( - array=h5Probe[()], - pixelWidthInMeters=float(h5Probe.attrs[self.PROBE_PIXEL_WIDTH]), - pixelHeightInMeters=float(h5Probe.attrs[self.PROBE_PIXEL_HEIGHT]), - ) - - h5Object = h5File[self.OBJECT_ARRAY] - h5ObjectLayerDistance = h5File[self.OBJECT_LAYER_DISTANCE] - object_ = Object( - array=h5Object[()], - layerDistanceInMeters=h5ObjectLayerDistance[()], - pixelWidthInMeters=float(h5Object.attrs[self.OBJECT_PIXEL_WIDTH]), - pixelHeightInMeters=float(h5Object.attrs[self.OBJECT_PIXEL_HEIGHT]), - centerXInMeters=float(h5Object.attrs[self.OBJECT_CENTER_X]), - centerYInMeters=float(h5Object.attrs[self.OBJECT_CENTER_Y]), - ) - - h5Costs = h5File[self.COSTS_ARRAY] - costs = h5Costs[()] - - return Product( - metadata=metadata, - scan=Scan(scanPointList), - probe=probe, - object_=object_, - costs=costs, - ) - - def write(self, filePath: Path, product: Product) -> None: - scanIndexes: list[int] = list() - scanXInMeters: list[float] = list() - scanYInMeters: list[float] = list() - - for point in product.scan: - scanIndexes.append(point.index) - scanXInMeters.append(point.positionXInMeters) - scanYInMeters.append(point.positionYInMeters) - - with h5py.File(filePath, 'w') as h5File: - metadata = product.metadata - h5File.attrs[self.NAME] = metadata.name - h5File.attrs[self.COMMENTS] = metadata.comments - h5File.attrs[self.DETECTOR_OBJECT_DISTANCE] = metadata.detectorDistanceInMeters - h5File.attrs[self.PROBE_ENERGY] = metadata.probeEnergyInElectronVolts - h5File.attrs[self.PROBE_PHOTON_FLUX] = metadata.probePhotonsPerSecond - h5File.attrs[self.EXPOSURE_TIME] = metadata.exposureTimeInSeconds - - h5File.create_dataset(self.PROBE_POSITION_INDEXES, data=scanIndexes) - h5File.create_dataset(self.PROBE_POSITION_X, data=scanXInMeters) - h5File.create_dataset(self.PROBE_POSITION_Y, data=scanYInMeters) - - probe = product.probe - probeGeometry = probe.getGeometry() - h5Probe = h5File.create_dataset(self.PROBE_ARRAY, data=probe.array) - h5Probe.attrs[self.PROBE_PIXEL_WIDTH] = probeGeometry.pixelWidthInMeters - h5Probe.attrs[self.PROBE_PIXEL_HEIGHT] = probeGeometry.pixelHeightInMeters - - object_ = product.object_ - objectGeometry = object_.getGeometry() - h5Object = h5File.create_dataset(self.OBJECT_ARRAY, data=object_.array) - h5Object.attrs[self.OBJECT_CENTER_X] = objectGeometry.centerXInMeters - h5Object.attrs[self.OBJECT_CENTER_Y] = objectGeometry.centerYInMeters - h5Object.attrs[self.OBJECT_PIXEL_WIDTH] = objectGeometry.pixelWidthInMeters - h5Object.attrs[self.OBJECT_PIXEL_HEIGHT] = objectGeometry.pixelHeightInMeters - h5File.create_dataset(self.OBJECT_LAYER_DISTANCE, data=object_.layerDistanceInMeters) - - h5File.create_dataset(self.COSTS_ARRAY, data=product.costs) - - -def registerPlugins(registry: PluginRegistry) -> None: - h5ProductFileIO = H5ProductFileIO() - - registry.productFileReaders.registerPlugin( - h5ProductFileIO, - simpleName=H5ProductFileIO.SIMPLE_NAME, - displayName=H5ProductFileIO.DISPLAY_NAME, - ) - registry.productFileWriters.registerPlugin( - h5ProductFileIO, - simpleName=H5ProductFileIO.SIMPLE_NAME, - displayName=H5ProductFileIO.DISPLAY_NAME, - ) diff --git a/src/ptychodus/plugins/h5_diffraction_file.py b/src/ptychodus/plugins/h5_diffraction_file.py new file mode 100644 index 00000000..f49af9fd --- /dev/null +++ b/src/ptychodus/plugins/h5_diffraction_file.py @@ -0,0 +1,222 @@ +from pathlib import Path +import logging + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionFileWriter, + DiffractionMetadata, + DiffractionPatternArray, + PatternDataType, + PatternIndexesType, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.tree import SimpleTreeNode + +logger = logging.getLogger(__name__) + + +class H5DiffractionPatternArray(DiffractionPatternArray): + def __init__( + self, label: str, indexes: PatternIndexesType, file_path: Path, data_path: str + ) -> None: + super().__init__() + self._label = label + self._indexes = indexes + self._file_path = file_path + self._data_path = data_path + + def get_label(self) -> str: + return self._label + + def get_indexes(self) -> PatternIndexesType: + return self._indexes + + def get_data(self) -> PatternDataType: + with h5py.File(self._file_path, 'r') as h5_file: + try: + item = h5_file[self._data_path] + except KeyError: + raise ValueError(f'Symlink {self._file_path}:{self._data_path} is broken!') + else: + if not isinstance(item, h5py.Dataset): + raise ValueError( + f'Symlink {self._file_path}:{self._data_path} is not a dataset!' + ) + + data = item[()] + + return data + + +class H5DiffractionFileTreeBuilder: + def _add_attributes( + self, tree_node: SimpleTreeNode, attribute_manager: h5py.AttributeManager + ) -> None: + for name, value in attribute_manager.items(): + if isinstance(value, str): + item_details = f'STRING = "{value}"' + elif isinstance(value, h5py.Empty): + logger.debug(f'Skipping empty attribute {name}.') + elif isinstance(value, numpy.ndarray): + item_details = f'ARRAY = {value}' + else: + string_info = h5py.check_string_dtype(value.dtype) + + if string_info: + item_details = f'STRING = "{value.decode(string_info.encoding)}"' + else: + item_details = f'SCALAR {value.dtype} = {value}' + + tree_node.create_child([str(name), 'Attribute', item_details]) + + def create_root_node(self) -> SimpleTreeNode: + return SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + + def build(self, h5_file: h5py.File) -> SimpleTreeNode: + root_node = self.create_root_node() + unvisited = [(root_node, h5_file)] + + while unvisited: + parent_item, h5_group = unvisited.pop() + + for item_name in h5_group: + item_type = 'Unknown' + item_details = '' + h5_item = h5_group.get(item_name, getlink=True) + + tree_node = parent_item.create_child(list()) + + if isinstance(h5_item, h5py.HardLink): + item_type = 'Hard Link' + h5_item = h5_group.get(item_name, getlink=False) + + if isinstance(h5_item, h5py.Group): + item_type = 'Group' + self._add_attributes(tree_node, h5_item.attrs) + unvisited.append((tree_node, h5_item)) + elif isinstance(h5_item, h5py.Dataset): + item_type = 'Dataset' + self._add_attributes(tree_node, h5_item.attrs) + space_id = h5_item.id.get_space() + + if space_id.get_simple_extent_type() == h5py.h5s.SCALAR: + value = h5_item[()] + + if isinstance(value, bytes): + item_details = value.decode() + elif isinstance(value, numpy.ndarray): + item_details = f'STRING = {h5_item.asstr()}' + else: + string_info = h5py.check_string_dtype(value.dtype) + + if string_info: + item_details = ( + f'STRING = "{value.decode(string_info.encoding)}"' + ) + else: + item_details = f'SCALAR {value.dtype} = {value}' + elif h5_item.size == 1: + value = h5_item[()] + item_details = f'DATASET {value.dtype} = {value}' + else: + item_details = f'{h5_item.shape} {h5_item.dtype}' + elif isinstance(h5_item, h5py.SoftLink): + item_type = 'Soft Link' + item_details = f'{h5_item.path}' + elif isinstance(h5_item, h5py.ExternalLink): + item_type = 'External Link' + item_details = f'{h5_item.filename}/{h5_item.path}' + else: + logger.debug(f'Unknown item "{item_name}"') + + tree_node.item_data = [item_name, item_type, item_details] + + return root_node + + +class H5DiffractionFileReader(DiffractionFileReader): + def __init__(self, data_path: str) -> None: + self._data_path = data_path + self._tree_builder = H5DiffractionFileTreeBuilder() + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + with h5py.File(file_path, 'r') as h5_file: + metadata = DiffractionMetadata.create_null(file_path) + contents_tree = self._tree_builder.build(h5_file) + + try: + data = h5_file[self._data_path] + except KeyError: + logger.warning('Unable to find data.') + return dataset + + num_patterns, detector_height, detector_width = data.shape + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + file_path=file_path, + ) + + array = H5DiffractionPatternArray( + label=file_path.stem, + indexes=numpy.arange(num_patterns), + file_path=file_path, + data_path=self._data_path, + ) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, [array]) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + + return dataset + + +class H5DiffractionFileWriter(DiffractionFileWriter): + def __init__(self, data_path: str) -> None: + self._data_path = data_path + + def write(self, file_path: Path, dataset: DiffractionDataset) -> None: + data = numpy.concatenate([array.get_data() for array in dataset]) + + with h5py.File(file_path, 'w') as h5_file: + h5_file.create_dataset(self._data_path, data=data, compression='gzip') + + +def register_plugins(registry: PluginRegistry) -> None: + registry.diffraction_file_readers.register_plugin( + H5DiffractionFileReader(data_path='/exchange/data'), + simple_name='APS_CSSI', + display_name='APS CSSI Files (*.h5 *.hdf5)', + ) + registry.diffraction_file_readers.register_plugin( + H5DiffractionFileReader(data_path='/entry/data/data'), + simple_name='APS_HXN', + display_name='CNM/APS HXN Files (*.h5 *.hdf5)', + ) + registry.diffraction_file_readers.register_plugin( + H5DiffractionFileReader(data_path='/entry/measurement/Eiger/data'), + simple_name='MAX_IV_NanoMax', + display_name='MAX IV NanoMax Files (*.h5 *.hdf5)', + ) + registry.diffraction_file_readers.register_plugin( + H5DiffractionFileReader(data_path='/dp'), + simple_name='PtychoShelves', + display_name='PtychoShelves Files (*.h5 *.hdf5)', + ) + registry.diffraction_file_writers.register_plugin( + H5DiffractionFileWriter(data_path='/dp'), + simple_name='PtychoShelves', + display_name='PtychoShelves Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/h5_product_file.py b/src/ptychodus/plugins/h5_product_file.py new file mode 100644 index 00000000..83135a13 --- /dev/null +++ b/src/ptychodus/plugins/h5_product_file.py @@ -0,0 +1,204 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object, ObjectCenter +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ( + Product, + ProductFileReader, + ProductFileWriter, + ProductMetadata, +) +from ptychodus.api.scan import PositionSequence, ScanPoint + +logger = logging.getLogger(__name__) + + +class H5ProductFileIO(ProductFileReader, ProductFileWriter): + SIMPLE_NAME: Final[str] = 'HDF5' + DISPLAY_NAME: Final[str] = 'Ptychodus Product Files (*.h5 *.hdf5)' + + NAME: Final[str] = 'name' + COMMENTS: Final[str] = 'comments' + DETECTOR_OBJECT_DISTANCE: Final[str] = 'detector_object_distance_m' + PROBE_ENERGY: Final[str] = 'probe_energy_eV' + PROBE_PHOTON_COUNT: Final[str] = 'probe_photon_count' + EXPOSURE_TIME: Final[str] = 'exposure_time_s' + MASS_ATTENUATION: Final[str] = 'mass_attenuation_m2_kg' + TOMOGRAPHY_ANGLE: Final[str] = 'tomography_angle_deg' + + PROBE_ARRAY: Final[str] = 'probe' + OPR_WEIGHTS: Final[str] = 'opr_weights' + PROBE_PIXEL_HEIGHT: Final[str] = 'pixel_height_m' + PROBE_PIXEL_WIDTH: Final[str] = 'pixel_width_m' + PROBE_POSITION_INDEXES: Final[str] = 'probe_position_indexes' + PROBE_POSITION_X: Final[str] = 'probe_position_x_m' + PROBE_POSITION_Y: Final[str] = 'probe_position_y_m' + + OBJECT_ARRAY: Final[str] = 'object' + OBJECT_CENTER_X: Final[str] = 'center_x_m' + OBJECT_CENTER_Y: Final[str] = 'center_y_m' + OBJECT_LAYER_SPACING: Final[str] = 'object_layer_spacing_m' + OBJECT_PIXEL_HEIGHT: Final[str] = 'pixel_height_m' + OBJECT_PIXEL_WIDTH: Final[str] = 'pixel_width_m' + + COSTS_ARRAY: Final[str] = 'costs' + + def read(self, file_path: Path) -> Product: + point_list: list[ScanPoint] = list() + + with h5py.File(file_path, 'r') as h5_file: + probe_photon_count = 0.0 + + try: + probe_photon_count = float(h5_file.attrs[self.PROBE_PHOTON_COUNT]) + except KeyError: + logger.debug('Probe photon count not found.') + + mass_attenuation_m2_kg = 0.0 + + try: + mass_attenuation_m2_kg = float(h5_file.attrs[self.MASS_ATTENUATION]) + except KeyError: + logger.debug('Mass attenuation not found.') + + tomography_angle_deg = 0.0 + + try: + tomography_angle_deg = float(h5_file.attrs[self.TOMOGRAPHY_ANGLE]) + except KeyError: + logger.debug('Tomography angle not found.') + + metadata = ProductMetadata( + name=str(h5_file.attrs[self.NAME]), + comments=str(h5_file.attrs[self.COMMENTS]), + detector_distance_m=float(h5_file.attrs[self.DETECTOR_OBJECT_DISTANCE]), + probe_energy_eV=float(h5_file.attrs[self.PROBE_ENERGY]), + probe_photon_count=probe_photon_count, + exposure_time_s=float(h5_file.attrs[self.EXPOSURE_TIME]), + mass_attenuation_m2_kg=mass_attenuation_m2_kg, + tomography_angle_deg=tomography_angle_deg, + ) + + h5_scan_indexes = h5_file[self.PROBE_POSITION_INDEXES] + h5_scan_x = h5_file[self.PROBE_POSITION_X] + h5_scan_y = h5_file[self.PROBE_POSITION_Y] + + for idx, x_m, y_m in zip(h5_scan_indexes[()], h5_scan_x[()], h5_scan_y[()]): + point = ScanPoint(idx, x_m, y_m) + point_list.append(point) + + h5_probe = h5_file[self.PROBE_ARRAY] + probe_pixel_geometry = PixelGeometry( + width_m=float(h5_probe.attrs[self.PROBE_PIXEL_WIDTH]), + height_m=float(h5_probe.attrs[self.PROBE_PIXEL_HEIGHT]), + ) + + try: + opr_weights = h5_probe.attrs[self.OPR_WEIGHTS] + except KeyError: + logger.debug('OPR weights not found.') + opr_weights = None + + probe = ProbeSequence( + array=h5_probe[()], + opr_weights=opr_weights, + pixel_geometry=probe_pixel_geometry, + ) + + h5_object = h5_file[self.OBJECT_ARRAY] + object_pixel_geometry = PixelGeometry( + width_m=float(h5_object.attrs[self.OBJECT_PIXEL_WIDTH]), + height_m=float(h5_object.attrs[self.OBJECT_PIXEL_HEIGHT]), + ) + object_center = ObjectCenter( + position_x_m=float(h5_object.attrs[self.OBJECT_CENTER_X]), + position_y_m=float(h5_object.attrs[self.OBJECT_CENTER_Y]), + ) + h5_object_layer_spacing = h5_file[self.OBJECT_LAYER_SPACING] + object_ = Object( + array=h5_object[()], + pixel_geometry=object_pixel_geometry, + center=object_center, + layer_spacing_m=h5_object_layer_spacing[()], + ) + + h5_costs = h5_file[self.COSTS_ARRAY] + costs = h5_costs[()] + + return Product( + metadata=metadata, + positions=PositionSequence(point_list), + probes=probe, + object_=object_, + costs=costs, + ) + + def write(self, file_path: Path, product: Product) -> None: + scan_indexes: list[int] = list() + scan_x_m: list[float] = list() + scan_y_m: list[float] = list() + + for point in product.positions: + scan_indexes.append(point.index) + scan_x_m.append(point.position_x_m) + scan_y_m.append(point.position_y_m) + + with h5py.File(file_path, 'w') as h5_file: + metadata = product.metadata + h5_file.attrs[self.NAME] = metadata.name + h5_file.attrs[self.COMMENTS] = metadata.comments + h5_file.attrs[self.DETECTOR_OBJECT_DISTANCE] = metadata.detector_distance_m + h5_file.attrs[self.PROBE_ENERGY] = metadata.probe_energy_eV + h5_file.attrs[self.PROBE_PHOTON_COUNT] = metadata.probe_photon_count + h5_file.attrs[self.EXPOSURE_TIME] = metadata.exposure_time_s + h5_file.attrs[self.MASS_ATTENUATION] = metadata.mass_attenuation_m2_kg + + h5_file.create_dataset(self.PROBE_POSITION_INDEXES, data=scan_indexes) + h5_file.create_dataset(self.PROBE_POSITION_X, data=scan_x_m) + h5_file.create_dataset(self.PROBE_POSITION_Y, data=scan_y_m) + + probe = product.probes + h5_probe = h5_file.create_dataset(self.PROBE_ARRAY, data=probe.get_array()) + + try: + opr_weights = probe.get_opr_weights() + except ValueError: + pass + else: + h5_file.create_dataset(self.OPR_WEIGHTS, data=opr_weights) + + probe_pixel_geometry = probe.get_pixel_geometry() + h5_probe.attrs[self.PROBE_PIXEL_WIDTH] = probe_pixel_geometry.width_m + h5_probe.attrs[self.PROBE_PIXEL_HEIGHT] = probe_pixel_geometry.height_m + + object_ = product.object_ + object_geometry = object_.get_geometry() + h5_object = h5_file.create_dataset(self.OBJECT_ARRAY, data=object_.get_array()) + h5_object.attrs[self.OBJECT_CENTER_X] = object_geometry.center_x_m + h5_object.attrs[self.OBJECT_CENTER_Y] = object_geometry.center_y_m + h5_object.attrs[self.OBJECT_PIXEL_WIDTH] = object_geometry.pixel_width_m + h5_object.attrs[self.OBJECT_PIXEL_HEIGHT] = object_geometry.pixel_height_m + h5_file.create_dataset(self.OBJECT_LAYER_SPACING, data=object_.layer_spacing_m) + + h5_file.create_dataset(self.COSTS_ARRAY, data=product.costs) + + +def register_plugins(registry: PluginRegistry) -> None: + h5_product_file_io = H5ProductFileIO() + + registry.register_product_file_reader_with_adapters( + h5_product_file_io, + simple_name=H5ProductFileIO.SIMPLE_NAME, + display_name=H5ProductFileIO.DISPLAY_NAME, + ) + registry.product_file_writers.register_plugin( + h5_product_file_io, + simple_name=H5ProductFileIO.SIMPLE_NAME, + display_name=H5ProductFileIO.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/isn_diffraction_file.py b/src/ptychodus/plugins/isn_diffraction_file.py new file mode 100644 index 00000000..2094aeaf --- /dev/null +++ b/src/ptychodus/plugins/isn_diffraction_file.py @@ -0,0 +1,100 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + DiffractionPatternArray, + PatternDataType, + PatternIndexesType, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry + +from .h5_diffraction_file import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +class ISNDiffractionFileReader(DiffractionFileReader): + ONE_MILLIMETER_M: Final[float] = 1.0e-3 + + def __init__(self) -> None: + self._tree_builder = H5DiffractionFileTreeBuilder() + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + with h5py.File(file_path, 'r') as h5_file: + metadata = DiffractionMetadata.create_null(file_path) + contents_tree = self._tree_builder.build(h5_file) + array_list: list[DiffractionPatternArray] = [] + dataset = SimpleDiffractionDataset(metadata, contents_tree, array_list) + + try: + configs = h5_file['configs'] + num_patterns_per_array = int(configs['num_images'][()]) + num_patterns_total = 50 * num_patterns_per_array # TODO generalize + detector_distance_mm = float(configs['det_dist_mm']) + detector_width = int(configs['det_size_x'][()]) + detector_height = int(configs['det_size_y'][()]) + pixel_width_m = float(configs['pix_size_x'][()]) + pixel_height_m = float(configs['pix_size_y'][()]) + probe_energy_eV = float(configs['photon_energy_eV'][()]) # noqa: N806 + except KeyError: + logger.warning('Unable to find metadata.') + return dataset + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns_per_array, + num_patterns_total=num_patterns_total, + pattern_dtype=numpy.dtype('u4'), + detector_distance_m=self.ONE_MILLIMETER_M * detector_distance_mm, + detector_extent=ImageExtent(detector_width, detector_height), + detector_pixel_geometry=PixelGeometry(pixel_width_m, pixel_height_m), + probe_energy_eV=probe_energy_eV, + file_path=file_path, + ) + + try: + ptycho = h5_file['PTYCHO'] + except KeyError: + logger.warning('Unable to find data.') + return dataset + + for name, h5_item in sorted(ptycho.items()): + h5_item = ptycho.get(name, getlink=True) + + if isinstance(h5_item, h5py.ExternalLink): + offset = len(array_list) * metadata.num_patterns_per_array + data_path = '/entry/data/data' # TODO str(h5_item.path) + array = H5DiffractionPatternArray( + label=name, + indexes=numpy.arange(num_patterns_per_array) + offset, + file_path=file_path.parent / h5_item.filename, + data_path=data_path, + ) + array_list.append(array) + else: + logger.debug(f'Skipping "{name}": not an external link.') + + dataset = SimpleDiffractionDataset(metadata, contents_tree, array_list) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + + return dataset + + +def register_plugins(registry: PluginRegistry) -> None: + registry.diffraction_file_readers.register_plugin( + ISNDiffractionFileReader(), + simple_name='APS_ISN', + display_name='APS ISN Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/lclsFileReaders.py b/src/ptychodus/plugins/lclsFileReaders.py deleted file mode 100644 index 57fcbbd1..00000000 --- a/src/ptychodus/plugins/lclsFileReaders.py +++ /dev/null @@ -1,173 +0,0 @@ -from pathlib import Path -from typing import Final -import logging - -import h5py -import numpy -import tables - -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.patterns import ( - DiffractionPatternArrayType, - DiffractionDataset, - DiffractionFileReader, - DiffractionMetadata, - DiffractionPatternArray, - DiffractionPatternState, - SimpleDiffractionDataset, -) -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint - -from .h5DiffractionFile import H5DiffractionFileTreeBuilder - -logger = logging.getLogger(__name__) - - -class PyTablesDiffractionPatternArray(DiffractionPatternArray): - def __init__(self, label: str, index: int, filePath: Path, dataPath: str) -> None: - super().__init__() - self._label = label - self._index = index - self._state = DiffractionPatternState.UNKNOWN - self._filePath = filePath - self._dataPath = dataPath - - def getLabel(self) -> str: - return self._label - - def getIndex(self) -> int: - return self._index - - def getState(self) -> DiffractionPatternState: - return self._state - - def getData(self) -> DiffractionPatternArrayType: - self._state = DiffractionPatternState.MISSING - - with tables.open_file(self._filePath, mode='r') as h5file: - try: - item = h5file.get_node(self._dataPath) - except tables.NoSuchNodeError: - raise ValueError(f'Symlink {self._filePath}:{self._dataPath} is broken!') - else: - if isinstance(item, tables.EArray): - self._state = DiffractionPatternState.FOUND - else: - raise ValueError( - f'Symlink {self._filePath}:{self._dataPath} is not a tables File!' - ) - - data = item[:] - - return data - - -class LCLSDiffractionFileReader(DiffractionFileReader): - def __init__(self) -> None: - self._dataPath = '/jungfrau1M/image_img' - self._treeBuilder = H5DiffractionFileTreeBuilder() - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - metadata = DiffractionMetadata.createNullInstance(filePath) - - try: - with tables.open_file(filePath, mode='r') as h5File: - try: - data = h5File.get_node(self._dataPath) - except tables.NoSuchNodeError: - logger.debug('Unable to find data.') - else: - data_shape = h5File.root.jungfrau1M.image_img.shape - numberOfPatterns, detectorHeight, detectorWidth = data_shape - - array = PyTablesDiffractionPatternArray( - label=filePath.stem, - index=0, - filePath=filePath, - dataPath=self._dataPath, - ) - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatterns, - numberOfPatternsTotal=numberOfPatterns, - patternDataType=data.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - filePath=filePath, - ) - - with h5py.File(filePath, 'r') as h5File: - contentsTree = self._treeBuilder.build(h5File) - - dataset = SimpleDiffractionDataset(metadata, contentsTree, [array]) - except OSError: - logger.debug(f'Unable to read file "{filePath}".') - - return dataset - - -class LCLSScanFileReader(ScanFileReader): - MICRONS_TO_METERS: Final[float] = 1e-6 - - def __init__( - self, - tomographyAngleInDegrees: float, - ipm2LowThreshold: float, - ipm2HighThreshold: float, - ) -> None: - self._tomographyAngleInDegrees = tomographyAngleInDegrees - self._ipm2LowThreshold = ipm2LowThreshold - self._ipm2HighThreshold = ipm2HighThreshold - - def read(self, filePath: Path) -> Scan: - scanPointList: list[ScanPoint] = list() - - with tables.open_file(filePath, mode='r') as h5file: - try: - # piezo stage positions are in microns - pi_x = h5file.get_node('/lmc/ch03')[:] - pi_y = h5file.get_node('/lmc/ch04')[:] - pi_z = h5file.get_node('/lmc/ch05')[:] - - # ipm2 is used for filtering and normalizing the data - ipm2 = h5file.get_node('/ipm2/sum')[:] - except tables.NoSuchNodeError: - logger.exception('Unable to load scan.') - else: - # vertical coordinate is always pi_z - ycoords = -pi_z * self.MICRONS_TO_METERS - - # horizontal coordinate may be a combination of pi_x and pi_y - tomographyAngleInRadians = numpy.deg2rad(self._tomographyAngleInDegrees) - cosAngle = numpy.cos(tomographyAngleInRadians) - sinAngle = numpy.sin(tomographyAngleInRadians) - xcoords = (cosAngle * pi_x + sinAngle * pi_y) * self.MICRONS_TO_METERS - - for index, (ipm, x, y) in enumerate(zip(ipm2, xcoords, ycoords)): - if self._ipm2LowThreshold <= ipm and ipm < self._ipm2HighThreshold: - if numpy.isfinite(x) and numpy.isfinite(y): - point = ScanPoint(index, x, y) - scanPointList.append(point) - else: - logger.debug(f'Filtered scan point {index=} {ipm=}.') - - return Scan(scanPointList) - - -def registerPlugins(registry: PluginRegistry) -> None: - SIMPLE_NAME: Final[str] = 'LCLS_XPP' - - registry.diffractionFileReaders.registerPlugin( - LCLSDiffractionFileReader(), - simpleName=SIMPLE_NAME, - displayName='LCLS XPP Diffraction Files (*.h5 *.hdf5)', - ) - registry.scanFileReaders.registerPlugin( - LCLSScanFileReader( - tomographyAngleInDegrees=180.0, - ipm2LowThreshold=2500.0, - ipm2HighThreshold=6000.0, - ), - simpleName=SIMPLE_NAME, - displayName='LCLS XPP Scan Files (*.h5 *.hdf5)', - ) diff --git a/src/ptychodus/plugins/lcls_file_readers.py b/src/ptychodus/plugins/lcls_file_readers.py new file mode 100644 index 00000000..272e6b18 --- /dev/null +++ b/src/ptychodus/plugins/lcls_file_readers.py @@ -0,0 +1,167 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py +import numpy +import tables + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + DiffractionPatternArray, + PatternDataType, + PatternIndexesType, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint + +from .h5_diffraction_file import H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +class PyTablesDiffractionPatternArray(DiffractionPatternArray): + def __init__(self, label: str, num_patterns: int, file_path: Path, data_path: str) -> None: + super().__init__() + self._label = label + self._indexes = numpy.arange(num_patterns) + self._file_path = file_path + self._data_path = data_path + + def get_label(self) -> str: + return self._label + + def get_indexes(self) -> PatternIndexesType: + return self._indexes + + def get_data(self) -> PatternDataType: + with tables.open_file(self._file_path, mode='r') as h5_file: + try: + item = h5_file.get_node(self._data_path) + except tables.NoSuchNodeError: + raise ValueError(f'Symlink {self._file_path}:{self._data_path} is broken!') + else: + if not isinstance(item, tables.EArray): + raise ValueError( + f'Symlink {self._file_path}:{self._data_path} is not a tables File!' + ) + + data = item[:] + + return data + + +class LCLSDiffractionFileReader(DiffractionFileReader): + def __init__(self) -> None: + self._data_path = '/jungfrau1M/image_img' + self._tree_builder = H5DiffractionFileTreeBuilder() + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + metadata = DiffractionMetadata.create_null(file_path) + + try: + with tables.open_file(file_path, mode='r') as h5_file: + try: + data = h5_file.get_node(self._data_path) + except tables.NoSuchNodeError: + logger.debug('Unable to find data.') + return dataset + + data_shape = h5_file.root.jungfrau1M.image_img.shape + num_patterns, detector_height, detector_width = data_shape + + array = PyTablesDiffractionPatternArray( + label=file_path.stem, + num_patterns=num_patterns, + file_path=file_path, + data_path=self._data_path, + ) + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + file_path=file_path, + ) + + with h5py.File(file_path, 'r') as h5_file: + contents_tree = self._tree_builder.build(h5_file) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, [array]) + except OSError: + logger.debug(f'Unable to read file "{file_path}".') + + return dataset + + +class LCLSPositionFileReader(PositionFileReader): + MICRONS_TO_METERS: Final[float] = 1e-6 + + def __init__( + self, + tomography_angle_deg: float, + ipm2_low_threshold: float, + ipm2_high_threshold: float, + ) -> None: + self._tomography_angle_deg = tomography_angle_deg + self._ipm2_low_threshold = ipm2_low_threshold + self._ipm2_high_threshold = ipm2_high_threshold + + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + + with tables.open_file(file_path, mode='r') as h5_file: + try: + # piezo stage positions are in microns + pi_x = h5_file.get_node('/lmc/ch03')[:] + pi_y = h5_file.get_node('/lmc/ch04')[:] + pi_z = h5_file.get_node('/lmc/ch05')[:] + + # ipm2 is used for filtering and normalizing the data + ipm2 = h5_file.get_node('/ipm2/sum')[:] + except tables.NoSuchNodeError: + logger.exception('Unable to load scan.') + else: + # vertical coordinate is always pi_z + ycoords = -pi_z * self.MICRONS_TO_METERS + + # horizontal coordinate may be a combination of pi_x and pi_y + tomography_angle_rad = numpy.deg2rad(self._tomography_angle_deg) + cos_angle = numpy.cos(tomography_angle_rad) + sin_angle = numpy.sin(tomography_angle_rad) + xcoords = (cos_angle * pi_x + sin_angle * pi_y) * self.MICRONS_TO_METERS + + for index, (ipm, x, y) in enumerate(zip(ipm2, xcoords, ycoords)): + if self._ipm2_low_threshold <= ipm and ipm < self._ipm2_high_threshold: + if numpy.isfinite(x) and numpy.isfinite(y): + point = ScanPoint(index, x, y) + point_list.append(point) + else: + logger.debug(f'Filtered scan point {index=} {ipm=}.') + + return PositionSequence(point_list) + + +def register_plugins(registry: PluginRegistry) -> None: + SIMPLE_NAME: Final[str] = 'LCLS_XPP' # noqa: N806 + DISPLAY_NAME: Final[str] = 'LCLS XPP Files (*.h5 *.hdf5)' # noqa: N806 + + registry.diffraction_file_readers.register_plugin( + LCLSDiffractionFileReader(), + simple_name=SIMPLE_NAME, + display_name=DISPLAY_NAME, + ) + registry.position_file_readers.register_plugin( + LCLSPositionFileReader( + tomography_angle_deg=180.0, + ipm2_low_threshold=2500.0, + ipm2_high_threshold=6000.0, + ), + simple_name=SIMPLE_NAME, + display_name=DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/lynxDiffractionFile.py b/src/ptychodus/plugins/lynxDiffractionFile.py deleted file mode 100644 index 477dc6d6..00000000 --- a/src/ptychodus/plugins/lynxDiffractionFile.py +++ /dev/null @@ -1,68 +0,0 @@ -from pathlib import Path -import logging - -import h5py - -from ptychodus.api.geometry import ImageExtent, PixelGeometry -from ptychodus.api.patterns import ( - DiffractionDataset, - DiffractionFileReader, - DiffractionMetadata, - SimpleDiffractionDataset, -) -from ptychodus.api.plugins import PluginRegistry - -from .h5DiffractionFile import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder - -logger = logging.getLogger(__name__) - - -class LYNXDiffractionFileReader(DiffractionFileReader): - def __init__(self) -> None: - self._dataPath = '/entry/data/eiger_4' - self._treeBuilder = H5DiffractionFileTreeBuilder() - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - - try: - with h5py.File(filePath, 'r') as h5File: - contentsTree = self._treeBuilder.build(h5File) - - try: - data = h5File[self._dataPath] - pixelSize = float(data.attrs['Pixel_size'].item()) - except KeyError: - logger.warning('Unable to load data.') - else: - numberOfPatterns, detectorHeight, detectorWidth = data.shape - - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatterns, - numberOfPatternsTotal=numberOfPatterns, - patternDataType=data.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - detectorPixelGeometry=PixelGeometry(pixelSize, pixelSize), - filePath=filePath, - ) - - array = H5DiffractionPatternArray( - label=filePath.stem, - index=0, - filePath=filePath, - dataPath=self._dataPath, - ) - - dataset = SimpleDiffractionDataset(metadata, contentsTree, [array]) - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - - return dataset - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.diffractionFileReaders.registerPlugin( - LYNXDiffractionFileReader(), - simpleName='LYNX', - displayName='LYNX Diffraction Files (*.h5 *.hdf5)', - ) diff --git a/src/ptychodus/plugins/lynxSoftGlueZynqScanFile.py b/src/ptychodus/plugins/lynxSoftGlueZynqScanFile.py deleted file mode 100644 index e17a3721..00000000 --- a/src/ptychodus/plugins/lynxSoftGlueZynqScanFile.py +++ /dev/null @@ -1,84 +0,0 @@ -from pathlib import Path -from typing import Final -import csv -import logging - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint, ScanPointParseError - -logger = logging.getLogger(__name__) - - -class LYNXSoftGlueZynqScanFileReader(ScanFileReader): - SIMPLE_NAME: Final[str] = 'LYNXSoftGlueZynq' - MICRONS_TO_METERS: Final[float] = 1.0e-6 - - EXPECTED_HEADER_RAW: Final[list[str]] = [ - 'DataPoint', - 'x_st_fzp', - 'y_st_fzp', - 'ckUser_Clk_Count', - 'Detector_Count', - ] - - EXPECTED_HEADER_PROCESSED: Final[list[str]] = [ - 'Detector_Count', - 'Average_x_st_fzp', - 'Stdev_x_st_fzp', - 'Average_y_st_fzp', - 'Stdev_y_st_fzp', - ] - - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - scanName = self.SIMPLE_NAME - - with filePath.open(newline='') as csvFile: - csvReader = csv.reader(csvFile, delimiter=' ') - csvIterator = iter(csvReader) - - titleRow = next(csvIterator) - - try: - scanName = ' '.join(titleRow).split(',', maxsplit=1)[0] - except IndexError: - raise ScanPointParseError('Bad scan name!') - - columnHeaderRow = next(csvIterator) - - if columnHeaderRow == LYNXSoftGlueZynqScanFileReader.EXPECTED_HEADER_RAW: - logger.debug(f'Reading raw scan positions for "{scanName}"...') - X = 1 - Y = 2 - DETECTOR_COUNT = 4 - elif columnHeaderRow == LYNXSoftGlueZynqScanFileReader.EXPECTED_HEADER_PROCESSED: - logger.debug(f'Reading processed scan positions for "{scanName}"...') - DETECTOR_COUNT = 0 - X = 1 - Y = 3 - else: - raise ScanPointParseError('Bad header!') - - for row in csvIterator: - if row[0].startswith('#'): - continue - - if len(row) != len(columnHeaderRow): - raise ScanPointParseError('Bad number of columns!') - - point = ScanPoint( - int(row[DETECTOR_COUNT]), - -float(row[X]) * LYNXSoftGlueZynqScanFileReader.MICRONS_TO_METERS, - -float(row[Y]) * LYNXSoftGlueZynqScanFileReader.MICRONS_TO_METERS, - ) - pointList.append(point) - - return Scan(pointList) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - LYNXSoftGlueZynqScanFileReader(), - simpleName=LYNXSoftGlueZynqScanFileReader.SIMPLE_NAME, - displayName='LYNX SoftGlueZynq Scan Files (*.dat)', - ) diff --git a/src/ptychodus/plugins/lynx_diffraction_file.py b/src/ptychodus/plugins/lynx_diffraction_file.py new file mode 100644 index 00000000..2ade02f9 --- /dev/null +++ b/src/ptychodus/plugins/lynx_diffraction_file.py @@ -0,0 +1,69 @@ +from pathlib import Path +import logging + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry + +from .h5_diffraction_file import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +class LYNXDiffractionFileReader(DiffractionFileReader): + def __init__(self) -> None: + self._data_path = '/entry/data/eiger_4' + self._tree_builder = H5DiffractionFileTreeBuilder() + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + with h5py.File(file_path, 'r') as h5_file: + contents_tree = self._tree_builder.build(h5_file) + + try: + data = h5_file[self._data_path] + pixel_size = float(data.attrs['Pixel_size'].item()) + except KeyError: + logger.warning('Unable to load data.') + else: + num_patterns, detector_height, detector_width = data.shape + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + detector_pixel_geometry=PixelGeometry(pixel_size, pixel_size), + file_path=file_path, + ) + + array = H5DiffractionPatternArray( + label=file_path.stem, + indexes=numpy.arange(num_patterns), + file_path=file_path, + data_path=self._data_path, + ) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, [array]) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + + return dataset + + +def register_plugins(registry: PluginRegistry) -> None: + registry.diffraction_file_readers.register_plugin( + LYNXDiffractionFileReader(), + simple_name='APS_LYNX', + display_name='APS LYNX Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/lynxOrchestraScanFile.py b/src/ptychodus/plugins/lynx_orchestra_position_file.py similarity index 50% rename from src/ptychodus/plugins/lynxOrchestraScanFile.py rename to src/ptychodus/plugins/lynx_orchestra_position_file.py index d3507421..c0c191b3 100644 --- a/src/ptychodus/plugins/lynxOrchestraScanFile.py +++ b/src/ptychodus/plugins/lynx_orchestra_position_file.py @@ -4,13 +4,14 @@ import logging from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint, ScanPointParseError +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint, ScanPointParseError logger = logging.getLogger(__name__) -class LYNXOrchestraScanFileReader(ScanFileReader): - SIMPLE_NAME: Final[str] = 'LYNXOrchestra' +class LYNXOrchestraPositionFileReader(PositionFileReader): + SIMPLE_NAME: Final[str] = 'APS_LYNX_Orchestra' + DISPLAY_NAME: Final[str] = 'APS LYNX Orchestra Files (*.dat)' MICRONS_TO_METERS: Final[float] = 1.0e-6 DATA_POINT_COLUMN: Final[int] = 0 X_COLUMN: Final[int] = 3 @@ -37,37 +38,37 @@ class LYNXOrchestraScanFileReader(ScanFileReader): 'Stdev_cap5', ] - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - scanName = self.SIMPLE_NAME + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + scan_name = self.SIMPLE_NAME - with filePath.open(newline='') as csvFile: - csvReader = csv.reader(csvFile, delimiter=' ', skipinitialspace=True) - csvIterator = iter(csvReader) + with file_path.open(newline='') as csv_file: + csv_reader = csv.reader(csv_file, delimiter=' ', skipinitialspace=True) + csv_iterator = iter(csv_reader) - titleRow = next(csvIterator) + title_row = next(csv_iterator) try: - scanName = ' '.join(titleRow).split(',', maxsplit=1)[0] + scan_name = ' '.join(title_row).split(',', maxsplit=1)[0] except IndexError: raise ScanPointParseError('Bad scan name!') - columnHeaderRow = next(csvIterator) + column_header_row = next(csv_iterator) - if columnHeaderRow == LYNXOrchestraScanFileReader.EXPECTED_HEADER: - logger.debug(f'Reading scan positions for "{scanName}"...') + if column_header_row == LYNXOrchestraPositionFileReader.EXPECTED_HEADER: + logger.debug(f'Reading scan positions for "{scan_name}"...') else: raise ScanPointParseError( 'Bad LYNX Orchestra header!\n' - f'Expected: {LYNXOrchestraScanFileReader.EXPECTED_HEADER}\n' - f'Found: {columnHeaderRow}\n' + f'Expected: {LYNXOrchestraPositionFileReader.EXPECTED_HEADER}\n' + f'Found: {column_header_row}\n' ) - for row in csvIterator: + for row in csv_iterator: if row[0].startswith('#'): continue - if len(row) != len(columnHeaderRow): + if len(row) != len(column_header_row): raise ScanPointParseError('Bad number of columns!') point = ScanPoint( @@ -75,14 +76,14 @@ def read(self, filePath: Path) -> Scan: -float(row[self.X_COLUMN]) * self.MICRONS_TO_METERS, -float(row[self.Y_COLUMN]) * self.MICRONS_TO_METERS, ) - pointList.append(point) + point_list.append(point) - return Scan(pointList) + return PositionSequence(point_list) -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - LYNXOrchestraScanFileReader(), - simpleName=LYNXOrchestraScanFileReader.SIMPLE_NAME, - displayName='LYNX Orchestra Scan Files (*.dat)', +def register_plugins(registry: PluginRegistry) -> None: + registry.position_file_readers.register_plugin( + LYNXOrchestraPositionFileReader(), + simple_name=LYNXOrchestraPositionFileReader.SIMPLE_NAME, + display_name=LYNXOrchestraPositionFileReader.DISPLAY_NAME, ) diff --git a/src/ptychodus/plugins/lynx_sgz_position_file.py b/src/ptychodus/plugins/lynx_sgz_position_file.py new file mode 100644 index 00000000..694a21f6 --- /dev/null +++ b/src/ptychodus/plugins/lynx_sgz_position_file.py @@ -0,0 +1,85 @@ +from pathlib import Path +from typing import Final +import csv +import logging + +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint, ScanPointParseError + +logger = logging.getLogger(__name__) + + +class LYNXSoftGlueZynqPositionFileReader(PositionFileReader): + SIMPLE_NAME: Final[str] = 'APS_LYNX_SoftGlueZynq' + DISPLAY_NAME: Final[str] = 'APS LYNX SoftGlueZynq Files (*.dat)' + MICRONS_TO_METERS: Final[float] = 1.0e-6 + + EXPECTED_HEADER_RAW: Final[list[str]] = [ + 'DataPoint', + 'x_st_fzp', + 'y_st_fzp', + 'ckUser_Clk_Count', + 'Detector_Count', + ] + + EXPECTED_HEADER_PROCESSED: Final[list[str]] = [ + 'Detector_Count', + 'Average_x_st_fzp', + 'Stdev_x_st_fzp', + 'Average_y_st_fzp', + 'Stdev_y_st_fzp', + ] + + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + scan_name = self.SIMPLE_NAME + + with file_path.open(newline='') as csv_file: + csv_reader = csv.reader(csv_file, delimiter=' ') + csv_iterator = iter(csv_reader) + + title_row = next(csv_iterator) + + try: + scan_name = ' '.join(title_row).split(',', maxsplit=1)[0] + except IndexError: + raise ScanPointParseError('Bad scan name!') + + column_header_row = next(csv_iterator) + + if column_header_row == LYNXSoftGlueZynqPositionFileReader.EXPECTED_HEADER_RAW: + logger.debug(f'Reading raw scan positions for "{scan_name}"...') + X = 1 # noqa: N806 + Y = 2 # noqa: N806 + DETECTOR_COUNT = 4 # noqa: N806 + elif column_header_row == LYNXSoftGlueZynqPositionFileReader.EXPECTED_HEADER_PROCESSED: + logger.debug(f'Reading processed scan positions for "{scan_name}"...') + DETECTOR_COUNT = 0 # noqa: N806 + X = 1 # noqa: N806 + Y = 3 # noqa: N806 + else: + raise ScanPointParseError('Bad header!') + + for row in csv_iterator: + if row[0].startswith('#'): + continue + + if len(row) != len(column_header_row): + raise ScanPointParseError('Bad number of columns!') + + point = ScanPoint( + int(row[DETECTOR_COUNT]), + -float(row[X]) * LYNXSoftGlueZynqPositionFileReader.MICRONS_TO_METERS, + -float(row[Y]) * LYNXSoftGlueZynqPositionFileReader.MICRONS_TO_METERS, + ) + point_list.append(point) + + return PositionSequence(point_list) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.position_file_readers.register_plugin( + LYNXSoftGlueZynqPositionFileReader(), + simple_name=LYNXSoftGlueZynqPositionFileReader.SIMPLE_NAME, + display_name=LYNXSoftGlueZynqPositionFileReader.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/matObjectFile.py b/src/ptychodus/plugins/matObjectFile.py deleted file mode 100644 index aaaa9a92..00000000 --- a/src/ptychodus/plugins/matObjectFile.py +++ /dev/null @@ -1,49 +0,0 @@ -from pathlib import Path - -import numpy -import scipy.io - -from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter -from ptychodus.api.plugins import PluginRegistry - - -class MATObjectFileReader(ObjectFileReader): - def read(self, filePath: Path) -> Object: - matDict = scipy.io.loadmat(filePath) - array = matDict['object'] - - if array.ndim == 3: - # array[width, height, num_layers] - array = array.transpose(2, 0, 1) - - try: - p = matDict['p'][0, 0] - multi_slice_param = p['multi_slice_param'][0, 0] - layerDistanceInMeters = numpy.squeeze(multi_slice_param['z_distance']) - except ValueError: - object_ = Object(array) - else: - object_ = Object(array, layerDistanceInMeters) - - return object_ - - -class MATObjectFileWriter(ObjectFileWriter): - def write(self, filePath: Path, object_: Object) -> None: - array = object_.array - matDict = {'object': array.transpose(1, 2, 0)} - # TODO layer distance to p.z_distance - scipy.io.savemat(filePath, matDict) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.objectFileReaders.registerPlugin( - MATObjectFileReader(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) - registry.objectFileWriters.registerPlugin( - MATObjectFileWriter(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) diff --git a/src/ptychodus/plugins/matProbeFile.py b/src/ptychodus/plugins/matProbeFile.py deleted file mode 100644 index c72613bb..00000000 --- a/src/ptychodus/plugins/matProbeFile.py +++ /dev/null @@ -1,42 +0,0 @@ -from pathlib import Path - -import scipy.io - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileReader, ProbeFileWriter - - -class MATProbeFileReader(ProbeFileReader): - def read(self, filePath: Path) -> Probe: - matDict = scipy.io.loadmat(filePath) - array = matDict['probe'] - - if array.ndim == 4: - # array[width, height, num_shared_modes, num_varying_modes] - array = array[..., 0] - - if array.ndim == 3: - # array[width, height, num_shared_modes] - array = array.transpose(2, 0, 1) - - return Probe(array) - - -class MATProbeFileWriter(ProbeFileWriter): - def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array - matDict = {'probe': array.transpose(1, 2, 0)} - scipy.io.savemat(filePath, matDict) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.probeFileReaders.registerPlugin( - MATProbeFileReader(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) - registry.probeFileWriters.registerPlugin( - MATProbeFileWriter(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) diff --git a/src/ptychodus/plugins/mdaScanFile.py b/src/ptychodus/plugins/mda_position_file.py similarity index 73% rename from src/ptychodus/plugins/mdaScanFile.py rename to src/ptychodus/plugins/mda_position_file.py index 8ae4728b..8185f204 100644 --- a/src/ptychodus/plugins/mdaScanFile.py +++ b/src/ptychodus/plugins/mda_position_file.py @@ -11,9 +11,10 @@ import yaml import numpy +import numpy.typing from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint T = TypeVar('T') @@ -298,11 +299,11 @@ class MDAScanData: @classmethod def read( - cls, fp: typing.BinaryIO, scanHeader: MDAScanHeader, scanInfo: MDAScanInfo + cls, fp: typing.BinaryIO, scan_header: MDAScanHeader, scan_info: MDAScanInfo ) -> MDAScanData: - npts = scanHeader.num_requested_points - np = scanInfo.num_positioners - nd = scanInfo.num_detectors + npts = scan_header.num_requested_points + np = scan_info.num_positioners + nd = scan_info.num_detectors unpacker = xdrlib.Unpacker(fp.read(8 * np * npts)) readback_lol = [unpacker.unpack_farray(npts, unpacker.unpack_double) for p in range(np)] @@ -358,7 +359,7 @@ def __str__(self) -> str: class MDAProcessVariable(Generic[T]): name: str description: str - epicsType: EpicsType + epics_type: EpicsType unit: str value: T @@ -366,7 +367,7 @@ def to_mapping(self) -> Mapping[str, Any]: return { 'name': self.name, 'description': self.description, - 'epicsType': self.epicsType.name, + 'epicsType': self.epics_type.name, 'unit': self.unit, 'value': self.value, } @@ -380,42 +381,42 @@ class MDAFile: @staticmethod def _read_pv(unpacker: xdrlib.Unpacker) -> MDAProcessVariable[typing.Any]: - pvName = read_counted_string(unpacker) - pvDesc = read_counted_string(unpacker) - pvType = EpicsType(unpacker.unpack_int()) + pv_name = read_counted_string(unpacker) + pv_desc = read_counted_string(unpacker) + pv_type = EpicsType(unpacker.unpack_int()) - if pvType == EpicsType.DBR_STRING: - valueStr = read_counted_string(unpacker) - return MDAProcessVariable[str](pvName, pvDesc, pvType, str(), valueStr) + if pv_type == EpicsType.DBR_STRING: + value_str = read_counted_string(unpacker) + return MDAProcessVariable[str](pv_name, pv_desc, pv_type, str(), value_str) count = unpacker.unpack_int() - pvUnit = read_counted_string(unpacker) - - if pvType == EpicsType.DBR_CTRL_CHAR: - valueChar = unpacker.unpack_fstring(count).decode() - valueChar = valueChar.split('\x00', 1)[0] # treat as null-terminated string - return MDAProcessVariable[str](pvName, pvDesc, pvType, pvUnit, valueChar) - elif pvType == EpicsType.DBR_CTRL_SHORT: - valueShort = unpacker.unpack_farray(count, unpacker.unpack_int) - return MDAProcessVariable[list[int]](pvName, pvDesc, pvType, pvUnit, valueShort) - elif pvType == EpicsType.DBR_CTRL_LONG: - valueLong = unpacker.unpack_farray(count, unpacker.unpack_int) - return MDAProcessVariable[list[int]](pvName, pvDesc, pvType, pvUnit, valueLong) - elif pvType == EpicsType.DBR_CTRL_FLOAT: - valueFloat = unpacker.unpack_farray(count, unpacker.unpack_float) - return MDAProcessVariable[list[float]](pvName, pvDesc, pvType, pvUnit, valueFloat) - elif pvType == EpicsType.DBR_CTRL_DOUBLE: - valueDouble = unpacker.unpack_farray(count, unpacker.unpack_double) - return MDAProcessVariable[list[float]](pvName, pvDesc, pvType, pvUnit, valueDouble) - - return MDAProcessVariable[str](pvName, pvDesc, pvType, pvUnit, str()) + pv_unit = read_counted_string(unpacker) + + if pv_type == EpicsType.DBR_CTRL_CHAR: + value_char = unpacker.unpack_fstring(count).decode() + value_char = value_char.split('\x00', 1)[0] # treat as null-terminated string + return MDAProcessVariable[str](pv_name, pv_desc, pv_type, pv_unit, value_char) + elif pv_type == EpicsType.DBR_CTRL_SHORT: + value_short = unpacker.unpack_farray(count, unpacker.unpack_int) + return MDAProcessVariable[list[int]](pv_name, pv_desc, pv_type, pv_unit, value_short) + elif pv_type == EpicsType.DBR_CTRL_LONG: + value_long = unpacker.unpack_farray(count, unpacker.unpack_int) + return MDAProcessVariable[list[int]](pv_name, pv_desc, pv_type, pv_unit, value_long) + elif pv_type == EpicsType.DBR_CTRL_FLOAT: + value_float = unpacker.unpack_farray(count, unpacker.unpack_float) + return MDAProcessVariable[list[float]](pv_name, pv_desc, pv_type, pv_unit, value_float) + elif pv_type == EpicsType.DBR_CTRL_DOUBLE: + value_double = unpacker.unpack_farray(count, unpacker.unpack_double) + return MDAProcessVariable[list[float]](pv_name, pv_desc, pv_type, pv_unit, value_double) + + return MDAProcessVariable[str](pv_name, pv_desc, pv_type, pv_unit, str()) @classmethod - def read(cls, filePath: Path) -> MDAFile: + def read(cls, file_path: Path) -> MDAFile: extra_pvs: list[MDAProcessVariable[Any]] = list() try: - with filePath.open(mode='rb') as fp: + with file_path.open(mode='rb') as fp: header = MDAHeader.read(fp) scan = MDAScan.read(fp) @@ -443,15 +444,16 @@ def __str__(self) -> str: return yaml.safe_dump(self.to_mapping(), sort_keys=False) -class MDAScanFileReader(ScanFileReader): - MICRONS_TO_METERS: Final[float] = 1.0e-6 +class MDAPositionFileReader(PositionFileReader): + def __init__(self, scale_to_meters: float) -> None: + self._scale_to_meters = scale_to_meters - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() - mdaFile = MDAFile.read(filePath) + mda_file = MDAFile.read(file_path) - yscan = mdaFile.scan + yscan = mda_file.scan yarray = yscan.data.readback_array[0, :] for y, xscan in zip(yarray, yscan.lower_scans): @@ -459,56 +461,66 @@ def read(self, filePath: Path) -> Scan: for x in xarray: point = ScanPoint( - index=len(pointList), - positionXInMeters=x * self.MICRONS_TO_METERS, - positionYInMeters=y * self.MICRONS_TO_METERS, + index=len(point_list), + position_x_m=float(x) * self._scale_to_meters, + position_y_m=float(y) * self._scale_to_meters, ) - pointList.append(point) + point_list.append(point) - return Scan(pointList) + return PositionSequence(point_list) -class HXNScanFileReader(ScanFileReader): +class HXNPositionFileReader(PositionFileReader): MICRONS_TO_METERS: Final[float] = 1.0e-6 - def read(self, filePath: Path) -> Scan: - pointList = list() + def read(self, file_path: Path) -> PositionSequence: + point_list = list() - mdaFile = MDAFile.read(filePath) + mda_file = MDAFile.read(file_path) - xarray = mdaFile.scan.data.readback_array[0, :] - yarray = mdaFile.scan.data.readback_array[1, :] + xarray = mda_file.scan.data.readback_array[0, :] + yarray = mda_file.scan.data.readback_array[1, :] for idx, (x, y) in enumerate(zip(xarray, yarray)): point = ScanPoint( index=idx, - positionXInMeters=x * self.MICRONS_TO_METERS, - positionYInMeters=y * self.MICRONS_TO_METERS, + position_x_m=float(x) * self.MICRONS_TO_METERS, + position_y_m=float(y) * self.MICRONS_TO_METERS, ) - pointList.append(point) + point_list.append(point) - return Scan(pointList) + return PositionSequence(point_list) -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - MDAScanFileReader(), - simpleName='MDA', - displayName='EPICS MDA Files (*.mda)', +def register_plugins(registry: PluginRegistry) -> None: + registry.position_file_readers.register_plugin( + MDAPositionFileReader(scale_to_meters=1.0e-6), + simple_name='MDA', + display_name='EPICS MDA Files (*.mda)', + ) + registry.position_file_readers.register_plugin( + MDAPositionFileReader(scale_to_meters=1.0e-3), + simple_name='APS_2IDD', + display_name='APS 2-ID-D Files (*.mda)', + ) + registry.position_file_readers.register_plugin( + MDAPositionFileReader(scale_to_meters=1.0e-3), + simple_name='APS_2IDE', + display_name='APS 2-ID-E Microprobe Files (*.mda)', ) - registry.scanFileReaders.registerPlugin( - MDAScanFileReader(), - simpleName='APS_2ID', - displayName='APS 2-ID MDA Files (*.mda)', + registry.position_file_readers.register_plugin( + MDAPositionFileReader(scale_to_meters=1.0e-6), + simple_name='APS_BNP', + display_name='APS Bionanoprobe Files (*.h5 *.hdf5)', ) - registry.scanFileReaders.registerPlugin( - HXNScanFileReader(), - simpleName='APS_HXN', - displayName='CNM/APS HXN Scan Files (*.mda)', + registry.position_file_readers.register_plugin( + HXNPositionFileReader(), + simple_name='CNM_APS_HXN', + display_name='CNM/APS HXN Files (*.mda)', ) if __name__ == '__main__': - filePath = Path(sys.argv[1]) - mdaFile = MDAFile.read(filePath) - print(mdaFile) + file_path = Path(sys.argv[1]) + mda_file = MDAFile.read(file_path) + print(mda_file) diff --git a/src/ptychodus/plugins/nanoMaxScanFile.py b/src/ptychodus/plugins/nanoMaxScanFile.py deleted file mode 100644 index 35e9e3c1..00000000 --- a/src/ptychodus/plugins/nanoMaxScanFile.py +++ /dev/null @@ -1,47 +0,0 @@ -from pathlib import Path -from typing import Final -import logging - -import h5py - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint, ScanPointParseError - -logger = logging.getLogger(__name__) - - -class NanoMaxScanFileReader(ScanFileReader): - MICRONS_TO_METERS: Final[float] = 1.0e-6 - - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - - with h5py.File(filePath, 'r') as h5File: - try: - positionX = h5File['/entry/measurement/pseudo/x'][()] - positionY = h5File['/entry/measurement/pseudo/y'][()] - except KeyError: - logger.exception('Unable to load scan.') - else: - if positionX.shape == positionY.shape: - logger.debug(f'Coordinate arrays have shape {positionX.shape}.') - else: - raise ScanPointParseError('Coordinate array shape mismatch!') - - for idx, (x, y) in enumerate(zip(positionX, positionY)): - point = ScanPoint( - idx, - x * self.MICRONS_TO_METERS, - y * self.MICRONS_TO_METERS, - ) - pointList.append(point) - - return Scan(pointList) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - NanoMaxScanFileReader(), - simpleName='NanoMax', - displayName='NanoMax DiffractionEndStation Scan Files (*.h5 *.hdf5)', - ) diff --git a/src/ptychodus/plugins/nanomax_position_file.py b/src/ptychodus/plugins/nanomax_position_file.py new file mode 100644 index 00000000..6e3788d4 --- /dev/null +++ b/src/ptychodus/plugins/nanomax_position_file.py @@ -0,0 +1,47 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py + +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint, ScanPointParseError + +logger = logging.getLogger(__name__) + + +class NanoMaxPositionFileReader(PositionFileReader): + MICRONS_TO_METERS: Final[float] = 1.0e-6 + + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + + with h5py.File(file_path, 'r') as h5_file: + try: + position_x = h5_file['/entry/measurement/pseudo/x'][()] + position_y = h5_file['/entry/measurement/pseudo/y'][()] + except KeyError: + logger.exception('Unable to load scan.') + else: + if position_x.shape == position_y.shape: + logger.debug(f'Coordinate arrays have shape {position_x.shape}.') + else: + raise ScanPointParseError('Coordinate array shape mismatch!') + + for idx, (x, y) in enumerate(zip(position_x, position_y)): + point = ScanPoint( + idx, + x * self.MICRONS_TO_METERS, + y * self.MICRONS_TO_METERS, + ) + point_list.append(point) + + return PositionSequence(point_list) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.position_file_readers.register_plugin( + NanoMaxPositionFileReader(), + simple_name='MAX_IV_NanoMAX', + display_name='MAX IV NanoMAX DiffractionEndStation Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/neXus/__init__.py b/src/ptychodus/plugins/neXus/__init__.py deleted file mode 100644 index 6fba6d4a..00000000 --- a/src/ptychodus/plugins/neXus/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from ptychodus.api.plugins import PluginRegistry - -from .neXusDiffractionFile import NeXusDiffractionFileReader -from .velociprobeScanFile import VelociprobeScanFileReader - - -def registerPlugins(registry: PluginRegistry) -> None: - neXusFileReader = NeXusDiffractionFileReader() - - registry.diffractionFileReaders.registerPlugin( - neXusFileReader, - simpleName='NeXus', - displayName='NeXus Master Files (*.h5 *.hdf5)', - ) - registry.scanFileReaders.registerPlugin( - VelociprobeScanFileReader.createLaserInterferometerInstance(neXusFileReader), - simpleName='VelociprobeLaserInterferometer', - displayName='Velociprobe Scan Files - Laser Interferometer (*.txt)', - ) - registry.scanFileReaders.registerPlugin( - VelociprobeScanFileReader.createPositionEncoderInstance(neXusFileReader), - simpleName='VelociprobePositionEncoder', - displayName='Velociprobe Scan Files - Position Encoder (*.txt)', - ) diff --git a/src/ptychodus/plugins/neXus/neXusDiffractionFile.py b/src/ptychodus/plugins/neXus/neXusDiffractionFile.py deleted file mode 100644 index f5a56ec1..00000000 --- a/src/ptychodus/plugins/neXus/neXusDiffractionFile.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import annotations -from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from pathlib import Path -from typing import overload -import logging - -import h5py - -from ptychodus.api.geometry import ImageExtent, PixelGeometry -from ptychodus.api.patterns import ( - CropCenter, - DiffractionDataset, - DiffractionFileReader, - DiffractionMetadata, - DiffractionPatternArray, - SimpleDiffractionDataset, -) -from ptychodus.api.tree import SimpleTreeNode - -from ..h5DiffractionFile import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class DataGroup: - arrayList: list[DiffractionPatternArray] = field(default_factory=list) - - @classmethod - def read(cls, group: h5py.Group) -> DataGroup: - arrayList: list[DiffractionPatternArray] = list() - masterFilePath = Path(group.file.filename) - - for name, h5Item in group.items(): - h5Item = group.get(name, getlink=True) - - if isinstance(h5Item, h5py.ExternalLink): - filePath = masterFilePath.parent / h5Item.filename - dataPath = str(h5Item.path) - # TODO use entry/data/data/image_nr_{low,high} - array = H5DiffractionPatternArray(name, len(arrayList), filePath, dataPath) - arrayList.append(array) - - return cls(arrayList) - - def __iter__(self) -> Iterator[DiffractionPatternArray]: - return iter(self.arrayList) - - @overload - def __getitem__(self, index: int) -> DiffractionPatternArray: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[DiffractionPatternArray]: ... - - def __getitem__( - self, index: int | slice - ) -> DiffractionPatternArray | Sequence[DiffractionPatternArray]: - return self.arrayList[index] - - def __len__(self) -> int: - return len(self.arrayList) - - -@dataclass(frozen=True) -class DetectorSpecificGroup: - nimages: int - ntrigger: int - photonEnergyInElectronVolts: float - xPixelsInDetector: int - yPixelsInDetector: int - - @property - def numberOfPatternsTotal(self) -> int: - return max(self.nimages, self.ntrigger) - - @classmethod - def read(cls, group: h5py.Group) -> DetectorSpecificGroup: - nimages = group['nimages'] - ntrigger = group['ntrigger'] - photonEnergy = group['photon_energy'] - assert photonEnergy.attrs['units'] == b'eV' - xPixelsInDetector = group['x_pixels_in_detector'] - yPixelsInDetector = group['y_pixels_in_detector'] - return cls( - int(nimages[()]), - int(ntrigger[()]), - float(photonEnergy[()]), - int(xPixelsInDetector[()]), - int(yPixelsInDetector[()]), - ) - - -@dataclass(frozen=True) -class DetectorGroup: - detectorSpecific: DetectorSpecificGroup - detectorDistanceInMeters: float - beamCenterXInPixels: int - beamCenterYInPixels: int - bitDepthReadout: int - xPixelSizeInMeters: float - yPixelSizeInMeters: float - - @classmethod - def read(cls, group: h5py.Group) -> DetectorGroup: - detectorSpecific = DetectorSpecificGroup.read(group['detectorSpecific']) - h5DetectorDistance = group['detector_distance'] - assert h5DetectorDistance.attrs['units'] == b'm' - h5BeamCenterX = group['beam_center_x'] - assert h5BeamCenterX.attrs['units'] == b'pixel' - h5BeamCenterY = group['beam_center_y'] - assert h5BeamCenterY.attrs['units'] == b'pixel' - h5BitDepthReadout = group['bit_depth_readout'] - h5XPixelSize = group['x_pixel_size'] - assert h5XPixelSize.attrs['units'] == b'm' - h5YPixelSize = group['y_pixel_size'] - assert h5YPixelSize.attrs['units'] == b'm' - return cls( - detectorSpecific, - float(h5DetectorDistance[()]), - int(h5BeamCenterX[()]), - int(h5BeamCenterY[()]), - int(h5BitDepthReadout[()]), - float(h5XPixelSize[()]), - float(h5YPixelSize[()]), - ) - - -@dataclass(frozen=True) -class InstrumentGroup: - detector: DetectorGroup - - @classmethod - def read(cls, group: h5py.Group) -> InstrumentGroup: - detector = DetectorGroup.read(group['detector']) - return cls(detector) - - -@dataclass(frozen=True) -class GoniometerGroup: - chiDeg: float - - @classmethod - def read(cls, group: h5py.Group) -> GoniometerGroup: - chiItem = group['chi'] - chiSpace = chiItem.id.get_space() - - assert chiItem.attrs['units'] == b'degree' - - if chiSpace.get_simple_extent_type() == h5py.h5s.SCALAR: - chiDeg = float(chiItem[()]) - elif isinstance(chiItem, h5py.Dataset): - chiDeg = float(chiItem[0]) - else: - raise ValueError('Failed to read goniometer angle (chi)!') - - return cls(chiDeg) - - -@dataclass(frozen=True) -class SampleGroup: - goniometer: GoniometerGroup - - @classmethod - def read(cls, group: h5py.Group) -> SampleGroup: - goniometer = GoniometerGroup.read(group['goniometer']) - return cls(goniometer) - - -@dataclass(frozen=True) -class EntryGroup: - data: DataGroup - instrument: InstrumentGroup - sample: SampleGroup - - @classmethod - def read(cls, group: h5py.Group) -> EntryGroup: - data = DataGroup.read(group['data']) - instrument = InstrumentGroup.read(group['instrument']) - sample = SampleGroup.read(group['sample']) - return cls(data, instrument, sample) - - -class NeXusDiffractionDataset(DiffractionDataset): - def __init__( - self, - metadata: DiffractionMetadata, - contentsTree: SimpleTreeNode, - entry: EntryGroup, - ) -> None: - self._metadata = metadata - self._contentsTree = contentsTree - self._entry = entry - - def getMetadata(self) -> DiffractionMetadata: - return self._metadata - - def getContentsTree(self) -> SimpleTreeNode: - return self._contentsTree - - @overload - def __getitem__(self, index: int) -> DiffractionPatternArray: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[DiffractionPatternArray]: ... - - def __getitem__( - self, index: int | slice - ) -> DiffractionPatternArray | Sequence[DiffractionPatternArray]: - return self._entry.data[index] - - def __len__(self) -> int: - return len(self._entry.data) - - -class NeXusDiffractionFileReader(DiffractionFileReader): - def __init__(self) -> None: - super().__init__() - self._treeBuilder = H5DiffractionFileTreeBuilder() - self.stageRotationInDegrees = 0.0 # TODO This is a hack; remove when able! - - def read(self, filePath: Path) -> DiffractionDataset: - dataset: DiffractionDataset = SimpleDiffractionDataset.createNullInstance(filePath) - - try: - with h5py.File(filePath, 'r') as h5File: - metadata = DiffractionMetadata.createNullInstance(filePath) - contentsTree = self._treeBuilder.build(h5File) - - try: - entry = EntryGroup.read(h5File['entry']) - h5Dataset = h5File['/entry/data/data_000001'] - except KeyError: - logger.warning(f'File {filePath} is not a NeXus data file.') - else: - detector = entry.instrument.detector - detectorPixelGeometry = PixelGeometry( - detector.xPixelSizeInMeters, - detector.yPixelSizeInMeters, - ) - cropCenter = CropCenter( - detector.beamCenterXInPixels, - detector.beamCenterYInPixels, - ) - - detectorSpecific = detector.detectorSpecific - detectorExtent = ImageExtent( - detectorSpecific.xPixelsInDetector, - detectorSpecific.yPixelsInDetector, - ) - probeEnergyInElectronVolts = detectorSpecific.photonEnergyInElectronVolts - - metadata = DiffractionMetadata( - numberOfPatternsPerArray=h5Dataset.shape[0], - numberOfPatternsTotal=detectorSpecific.numberOfPatternsTotal, - patternDataType=h5Dataset.dtype, - detectorDistanceInMeters=detector.detectorDistanceInMeters, - detectorExtent=detectorExtent, - detectorPixelGeometry=detectorPixelGeometry, - detectorBitDepth=detector.bitDepthReadout, - cropCenter=cropCenter, - probeEnergyInElectronVolts=probeEnergyInElectronVolts, - filePath=filePath, - ) - - dataset = NeXusDiffractionDataset(metadata, contentsTree, entry) - - # vvv TODO This is a hack; remove when able! vvv - self.stageRotationInDegrees = entry.sample.goniometer.chiDeg - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - - return dataset diff --git a/src/ptychodus/plugins/neXus/velociprobeScanFile.py b/src/ptychodus/plugins/neXus/velociprobeScanFile.py deleted file mode 100644 index 898214b1..00000000 --- a/src/ptychodus/plugins/neXus/velociprobeScanFile.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations -from enum import IntEnum -from pathlib import Path -from typing import Final -import csv - -import numpy - -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint, ScanPointParseError -from .neXusDiffractionFile import NeXusDiffractionFileReader - -__all__ = [ - 'VelociprobeScanFileReader', -] - - -class VelociprobeScanFileColumn(IntEnum): - X = 1 - LASER_INTERFEROMETER_Y = 2 - POSITION_ENCODER_Y = 5 - TRIGGER = 7 - - -class VelociprobeScanFileReader(ScanFileReader): - NANOMETERS_TO_METERS: Final[float] = 1.0e-9 - - def __init__(self, neXusReader: NeXusDiffractionFileReader, yColumn: int) -> None: - self._neXusReader = neXusReader - self._yColumn = yColumn - - @classmethod - def createLaserInterferometerInstance( - cls, neXusReader: NeXusDiffractionFileReader - ) -> VelociprobeScanFileReader: - return cls(neXusReader, VelociprobeScanFileColumn.LASER_INTERFEROMETER_Y) - - @classmethod - def createPositionEncoderInstance( - cls, neXusReader: NeXusDiffractionFileReader - ) -> VelociprobeScanFileReader: - return cls(neXusReader, VelociprobeScanFileColumn.POSITION_ENCODER_Y) - - def _applyTransform(self, scan: Scan) -> Scan: - stageRotationInRadians = numpy.deg2rad(self._neXusReader.stageRotationInDegrees) - stageRotationCosine = numpy.cos(stageRotationInRadians) - - xMean = sum(p.positionXInMeters for p in scan) / len(scan) - yMean = sum(p.positionYInMeters for p in scan) / len(scan) - pointList: list[ScanPoint] = list() - - for untransformedPoint in scan: - point = ScanPoint( - untransformedPoint.index, - (untransformedPoint.positionXInMeters - xMean) * stageRotationCosine, - (untransformedPoint.positionYInMeters - yMean), - ) - pointList.append(point) - - return Scan(pointList) - - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - minimumColumnCount = max(col.value for col in VelociprobeScanFileColumn) + 1 - - with filePath.open(newline='') as csvFile: - csvReader = csv.reader(csvFile, delimiter=',') - - for row in csvReader: - if row[0].startswith('#'): - continue - - if len(row) < minimumColumnCount: - raise ScanPointParseError('Bad number of columns!') - - trigger = int(row[VelociprobeScanFileColumn.TRIGGER]) - x_nm = int(row[VelociprobeScanFileColumn.X]) - y_nm = int(row[self._yColumn]) - - if self._yColumn == VelociprobeScanFileColumn.POSITION_ENCODER_Y: - y_nm = -y_nm - - point = ScanPoint( - trigger, - x_nm * self.NANOMETERS_TO_METERS, - y_nm * self.NANOMETERS_TO_METERS, - ) - pointList.append(point) - - rawScan = Scan(pointList) - - return self._applyTransform(rawScan) diff --git a/src/ptychodus/plugins/nexus/__init__.py b/src/ptychodus/plugins/nexus/__init__.py new file mode 100644 index 00000000..fcfcd4c2 --- /dev/null +++ b/src/ptychodus/plugins/nexus/__init__.py @@ -0,0 +1,24 @@ +from ptychodus.api.plugins import PluginRegistry + +from .nexus_diffraction_file import NeXusDiffractionFileReader +from .velociprobe_position_file import VelociprobePositionFileReader + + +def register_plugins(registry: PluginRegistry) -> None: + nexus_file_reader = NeXusDiffractionFileReader() + + registry.diffraction_file_readers.register_plugin( + nexus_file_reader, + simple_name='NeXus', + display_name='NeXus Master Files (*.h5 *.hdf5)', + ) + registry.position_file_readers.register_plugin( + VelociprobePositionFileReader.create_laser_interferometer_instance(nexus_file_reader), + simple_name='APS_Velociprobe-LI', + display_name='APS Velociprobe Files - Laser Interferometer (*.txt)', + ) + registry.position_file_readers.register_plugin( + VelociprobePositionFileReader.create_position_encoder_instance(nexus_file_reader), + simple_name='APS_Velociprobe-PE', + display_name='APS Velociprobe Files - Position Encoder (*.txt)', + ) diff --git a/src/ptychodus/plugins/nexus/nexus_diffraction_file.py b/src/ptychodus/plugins/nexus/nexus_diffraction_file.py new file mode 100644 index 00000000..9dc62319 --- /dev/null +++ b/src/ptychodus/plugins/nexus/nexus_diffraction_file.py @@ -0,0 +1,286 @@ +from __future__ import annotations +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import overload +import logging + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + CropCenter, + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + DiffractionPatternArray, + SimpleDiffractionDataset, +) +from ptychodus.api.tree import SimpleTreeNode + +from ..h5_diffraction_file import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class DataGroup: + array_list: list[DiffractionPatternArray] = field(default_factory=list) + + @classmethod + def read(cls, group: h5py.Group, num_patterns_per_array: int) -> DataGroup: + array_list: list[DiffractionPatternArray] = list() + master_file_path = Path(group.file.filename) + + for name, h5_item in sorted(group.items()): + h5_item = group.get(name, getlink=True) + + if isinstance(h5_item, h5py.ExternalLink): + array = H5DiffractionPatternArray( + label=name, + indexes=numpy.arange(num_patterns_per_array) + + len(array_list) * num_patterns_per_array, + file_path=master_file_path.parent / h5_item.filename, + data_path=str(h5_item.path), + ) + array_list.append(array) + + return cls(array_list) + + def __iter__(self) -> Iterator[DiffractionPatternArray]: + return iter(self.array_list) + + @overload + def __getitem__(self, index: int) -> DiffractionPatternArray: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[DiffractionPatternArray]: ... + + def __getitem__( + self, index: int | slice + ) -> DiffractionPatternArray | Sequence[DiffractionPatternArray]: + return self.array_list[index] + + def __len__(self) -> int: + return len(self.array_list) + + +@dataclass(frozen=True) +class DetectorSpecificGroup: + nimages: int + ntrigger: int + photon_energy_eV: float # noqa: N815 + x_pixels_in_detector: int + y_pixels_in_detector: int + + @property + def num_patterns_total(self) -> int: + return max(self.nimages, self.ntrigger) + + @classmethod + def read(cls, group: h5py.Group) -> DetectorSpecificGroup: + nimages = group['nimages'] + ntrigger = group['ntrigger'] + photon_energy = group['photon_energy'] + assert photon_energy.attrs['units'] == b'eV' + x_pixels_in_detector = group['x_pixels_in_detector'] + y_pixels_in_detector = group['y_pixels_in_detector'] + return cls( + int(nimages[()]), + int(ntrigger[()]), + float(photon_energy[()]), + int(x_pixels_in_detector[()]), + int(y_pixels_in_detector[()]), + ) + + +@dataclass(frozen=True) +class DetectorGroup: + detector_specific: DetectorSpecificGroup + detector_distance_m: float + beam_center_x_px: int + beam_center_y_px: int + bit_depth_readout: int + x_pixel_size_m: float + y_pixel_size_m: float + + @classmethod + def read(cls, group: h5py.Group) -> DetectorGroup: + detector_specific = DetectorSpecificGroup.read(group['detectorSpecific']) + h5_detector_distance = group['detector_distance'] + assert h5_detector_distance.attrs['units'] == b'm' + h5_beam_center_x = group['beam_center_x'] + assert h5_beam_center_x.attrs['units'] == b'pixel' + h5_beam_center_y = group['beam_center_y'] + assert h5_beam_center_y.attrs['units'] == b'pixel' + h5_bit_depth_readout = group['bit_depth_readout'] + h5_x_pixel_size = group['x_pixel_size'] + assert h5_x_pixel_size.attrs['units'] == b'm' + h5_y_pixel_size = group['y_pixel_size'] + assert h5_y_pixel_size.attrs['units'] == b'm' + return cls( + detector_specific, + float(h5_detector_distance[()]), + int(h5_beam_center_x[()]), + int(h5_beam_center_y[()]), + int(h5_bit_depth_readout[()]), + float(h5_x_pixel_size[()]), + float(h5_y_pixel_size[()]), + ) + + +@dataclass(frozen=True) +class InstrumentGroup: + detector: DetectorGroup + + @classmethod + def read(cls, group: h5py.Group) -> InstrumentGroup: + detector = DetectorGroup.read(group['detector']) + return cls(detector) + + +@dataclass(frozen=True) +class GoniometerGroup: + chi_deg: float + + @classmethod + def read(cls, group: h5py.Group) -> GoniometerGroup: + chi_item = group['chi'] + chi_space = chi_item.id.get_space() + + assert chi_item.attrs['units'] == b'degree' + + if chi_space.get_simple_extent_type() == h5py.h5s.SCALAR: + chi_deg = float(chi_item[()]) + elif isinstance(chi_item, h5py.Dataset): + chi_deg = float(chi_item[0]) + else: + raise ValueError('Failed to read goniometer angle (chi)!') + + return cls(chi_deg) + + +@dataclass(frozen=True) +class SampleGroup: + goniometer: GoniometerGroup + + @classmethod + def read(cls, group: h5py.Group) -> SampleGroup: + goniometer = GoniometerGroup.read(group['goniometer']) + return cls(goniometer) + + +@dataclass(frozen=True) +class EntryGroup: + data: DataGroup + instrument: InstrumentGroup + sample: SampleGroup + + @classmethod + def read(cls, group: h5py.Group, num_patterns_per_array: int) -> EntryGroup: + data = DataGroup.read(group['data'], num_patterns_per_array) + instrument = InstrumentGroup.read(group['instrument']) + sample = SampleGroup.read(group['sample']) + return cls(data, instrument, sample) + + +class NeXusDiffractionDataset(DiffractionDataset): + def __init__( + self, + metadata: DiffractionMetadata, + contents_tree: SimpleTreeNode, + entry: EntryGroup, + ) -> None: + self._metadata = metadata + self._contents_tree = contents_tree + self._entry = entry + + def get_metadata(self) -> DiffractionMetadata: + return self._metadata + + def get_contents_tree(self) -> SimpleTreeNode: + return self._contents_tree + + @overload + def __getitem__(self, index: int) -> DiffractionPatternArray: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[DiffractionPatternArray]: ... + + def __getitem__( + self, index: int | slice + ) -> DiffractionPatternArray | Sequence[DiffractionPatternArray]: + return self._entry.data[index] + + def __len__(self) -> int: + return len(self._entry.data) + + +class NeXusDiffractionFileReader(DiffractionFileReader): + def __init__(self) -> None: + super().__init__() + self._tree_builder = H5DiffractionFileTreeBuilder() + self.stage_rotation_deg = 0.0 # TODO This is a hack; remove when able! + + def read(self, file_path: Path) -> DiffractionDataset: + dataset: DiffractionDataset = SimpleDiffractionDataset.create_null(file_path) + + try: + with h5py.File(file_path, 'r') as h5_file: + metadata = DiffractionMetadata.create_null(file_path) + contents_tree = self._tree_builder.build(h5_file) + + try: + h5_dataset = h5_file['/entry/data/data_000001'] + except KeyError: + logger.error(f'File {file_path} is not a NeXus data file.') + raise + + num_patterns_per_array = h5_dataset.shape[0] + pattern_dtype = h5_dataset.dtype + + try: + entry = EntryGroup.read(h5_file['entry'], num_patterns_per_array) + except KeyError: + logger.error(f'File {file_path} is not a NeXus data file.') + raise + + detector = entry.instrument.detector + detector_pixel_geometry = PixelGeometry( + detector.x_pixel_size_m, + detector.y_pixel_size_m, + ) + crop_center = CropCenter( + detector.beam_center_x_px, + detector.beam_center_y_px, + ) + + detector_specific = detector.detector_specific + detector_extent = ImageExtent( + detector_specific.x_pixels_in_detector, + detector_specific.y_pixels_in_detector, + ) + probe_energy_eV = detector_specific.photon_energy_eV # noqa: N806 + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns_per_array, + num_patterns_total=detector_specific.num_patterns_total, + pattern_dtype=pattern_dtype, + detector_distance_m=detector.detector_distance_m, + detector_extent=detector_extent, + detector_pixel_geometry=detector_pixel_geometry, + detector_bit_depth=detector.bit_depth_readout, + crop_center=crop_center, + probe_energy_eV=probe_energy_eV, + file_path=file_path, + ) + + dataset = NeXusDiffractionDataset(metadata, contents_tree, entry) + + # vvv TODO This is a hack; remove when able! vvv + self.stage_rotation_deg = entry.sample.goniometer.chi_deg + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + + return dataset diff --git a/src/ptychodus/plugins/nexus/velociprobe_position_file.py b/src/ptychodus/plugins/nexus/velociprobe_position_file.py new file mode 100644 index 00000000..8e1dbee0 --- /dev/null +++ b/src/ptychodus/plugins/nexus/velociprobe_position_file.py @@ -0,0 +1,91 @@ +from __future__ import annotations +from enum import IntEnum +from pathlib import Path +from typing import Final +import csv + +import numpy + +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint, ScanPointParseError +from .nexus_diffraction_file import NeXusDiffractionFileReader + +__all__ = [ + 'VelociprobePositionFileReader', +] + + +class VelociprobePositionFileColumn(IntEnum): + X = 1 + LASER_INTERFEROMETER_Y = 2 + POSITION_ENCODER_Y = 5 + TRIGGER = 7 + + +class VelociprobePositionFileReader(PositionFileReader): + NANOMETERS_TO_METERS: Final[float] = 1.0e-9 + + def __init__(self, nexus_reader: NeXusDiffractionFileReader, y_column: int) -> None: + self._nexus_reader = nexus_reader + self._y_column = y_column + + @classmethod + def create_laser_interferometer_instance( + cls, nexus_reader: NeXusDiffractionFileReader + ) -> VelociprobePositionFileReader: + return cls(nexus_reader, VelociprobePositionFileColumn.LASER_INTERFEROMETER_Y) + + @classmethod + def create_position_encoder_instance( + cls, nexus_reader: NeXusDiffractionFileReader + ) -> VelociprobePositionFileReader: + return cls(nexus_reader, VelociprobePositionFileColumn.POSITION_ENCODER_Y) + + def _apply_transform(self, positions: PositionSequence) -> PositionSequence: + stage_rotation_rad = numpy.deg2rad(self._nexus_reader.stage_rotation_deg) + stage_rotation_cos = numpy.cos(stage_rotation_rad) + + x_mean = sum(p.position_x_m for p in positions) / len(positions) + y_mean = sum(p.position_y_m for p in positions) / len(positions) + point_list: list[ScanPoint] = list() + + for untransformed_point in positions: + point = ScanPoint( + untransformed_point.index, + (untransformed_point.position_x_m - x_mean) * stage_rotation_cos, + (untransformed_point.position_y_m - y_mean), + ) + point_list.append(point) + + return PositionSequence(point_list) + + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + minimum_column_count = max(col.value for col in VelociprobePositionFileColumn) + 1 + + with file_path.open(newline='') as csv_file: + csv_reader = csv.reader(csv_file, delimiter=',') + + for row in csv_reader: + if row[0].startswith('#'): + continue + + if len(row) < minimum_column_count: + raise ScanPointParseError('Bad number of columns!') + + trigger = int(row[VelociprobePositionFileColumn.TRIGGER]) + x_nm = int(row[VelociprobePositionFileColumn.X]) + y_nm = int(row[self._y_column]) + + if self._y_column == VelociprobePositionFileColumn.POSITION_ENCODER_Y: + y_nm = -y_nm + + point = ScanPoint( + trigger, + x_nm * self.NANOMETERS_TO_METERS, + y_nm * self.NANOMETERS_TO_METERS, + ) + point_list.append(point) + + raw_positions = PositionSequence(point_list) + + return self._apply_transform(raw_positions) diff --git a/src/ptychodus/plugins/npyDiffractionFile.py b/src/ptychodus/plugins/npyDiffractionFile.py deleted file mode 100644 index 6b0736f3..00000000 --- a/src/ptychodus/plugins/npyDiffractionFile.py +++ /dev/null @@ -1,89 +0,0 @@ -from pathlib import Path -from typing import Final -import logging - -import numpy - -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.patterns import ( - DiffractionDataset, - DiffractionFileReader, - DiffractionFileWriter, - DiffractionMetadata, - DiffractionPatternState, - SimpleDiffractionDataset, - SimpleDiffractionPatternArray, -) -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.tree import SimpleTreeNode - -logger = logging.getLogger(__name__) - - -class NPYDiffractionFileIO(DiffractionFileReader, DiffractionFileWriter): - SIMPLE_NAME: Final[str] = 'NPY' - DISPLAY_NAME: Final[str] = 'NumPy Binary Files (*.npy)' - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - - try: - data = numpy.load(filePath) - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - else: - if data.ndim == 2: - data = data[numpy.newaxis, :, :] - - numberOfPatterns, detectorHeight, detectorWidth = data.shape - - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatterns, - numberOfPatternsTotal=numberOfPatterns, - patternDataType=data.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - filePath=filePath, - ) - - contentsTree = SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) - contentsTree.createChild( - [filePath.stem, type(data).__name__, f'{data.dtype}{data.shape}'] - ) - - array = SimpleDiffractionPatternArray( - label=filePath.stem, - index=0, - data=data, - state=DiffractionPatternState.FOUND, - ) - - dataset = SimpleDiffractionDataset(metadata, contentsTree, [array]) - - return dataset - - def write(self, filePath: Path, dataset: DiffractionDataset) -> None: - patterns = list() - - for array in dataset: - arrayData = array.getData() - - if arrayData.size > 0: - patterns.append(arrayData) - - data = numpy.concatenate(patterns) - numpy.save(filePath, data) - - -def registerPlugins(registry: PluginRegistry) -> None: - npyDiffractionFileIO = NPYDiffractionFileIO() - - registry.diffractionFileReaders.registerPlugin( - npyDiffractionFileIO, - simpleName=NPYDiffractionFileIO.SIMPLE_NAME, - displayName=NPYDiffractionFileIO.DISPLAY_NAME, - ) - registry.diffractionFileWriters.registerPlugin( - npyDiffractionFileIO, - simpleName=NPYDiffractionFileIO.SIMPLE_NAME, - displayName=NPYDiffractionFileIO.DISPLAY_NAME, - ) diff --git a/src/ptychodus/plugins/npyObjectFile.py b/src/ptychodus/plugins/npyObjectFile.py deleted file mode 100644 index 7389b277..00000000 --- a/src/ptychodus/plugins/npyObjectFile.py +++ /dev/null @@ -1,31 +0,0 @@ -from pathlib import Path - -import numpy - -from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter -from ptychodus.api.plugins import PluginRegistry - - -class NPYObjectFileReader(ObjectFileReader): - def read(self, filePath: Path) -> Object: - array = numpy.load(filePath) - return Object(array) - - -class NPYObjectFileWriter(ObjectFileWriter): - def write(self, filePath: Path, object_: Object) -> None: - array = object_.array - numpy.save(filePath, array) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.objectFileReaders.registerPlugin( - NPYObjectFileReader(), - simpleName='NPY', - displayName='NumPy Binary Files (*.npy)', - ) - registry.objectFileWriters.registerPlugin( - NPYObjectFileWriter(), - simpleName='NPY', - displayName='NumPy Binary Files (*.npy)', - ) diff --git a/src/ptychodus/plugins/npyProbeFile.py b/src/ptychodus/plugins/npyProbeFile.py deleted file mode 100644 index b5c58fe0..00000000 --- a/src/ptychodus/plugins/npyProbeFile.py +++ /dev/null @@ -1,31 +0,0 @@ -from pathlib import Path - -import numpy - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileReader, ProbeFileWriter - - -class NPYProbeFileReader(ProbeFileReader): - def read(self, filePath: Path) -> Probe: - array = numpy.load(filePath) - return Probe(array) - - -class NPYProbeFileWriter(ProbeFileWriter): - def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array - numpy.save(filePath, array) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.probeFileReaders.registerPlugin( - NPYProbeFileReader(), - simpleName='NPY', - displayName='NumPy Binary Files (*.npy)', - ) - registry.probeFileWriters.registerPlugin( - NPYProbeFileWriter(), - simpleName='NPY', - displayName='NumPy Binary Files (*.npy)', - ) diff --git a/src/ptychodus/plugins/npy_diffraction_file.py b/src/ptychodus/plugins/npy_diffraction_file.py new file mode 100644 index 00000000..02e86a49 --- /dev/null +++ b/src/ptychodus/plugins/npy_diffraction_file.py @@ -0,0 +1,79 @@ +from pathlib import Path +from typing import Final +import logging + +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionFileWriter, + DiffractionMetadata, + SimpleDiffractionDataset, + SimpleDiffractionPatternArray, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.tree import SimpleTreeNode + +logger = logging.getLogger(__name__) + + +class NPYDiffractionFileIO(DiffractionFileReader, DiffractionFileWriter): + SIMPLE_NAME: Final[str] = 'NPY' + DISPLAY_NAME: Final[str] = 'NumPy Binary Files (*.npy)' + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + data = numpy.load(file_path) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + else: + if data.ndim == 2: + data = data[numpy.newaxis, :, :] + + num_patterns, detector_height, detector_width = data.shape + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + file_path=file_path, + ) + + contents_tree = SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + contents_tree.create_child( + [file_path.stem, type(data).__name__, f'{data.dtype}{data.shape}'] + ) + + array = SimpleDiffractionPatternArray( + label=file_path.stem, + indexes=numpy.arange(num_patterns), + data=data, + ) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, [array]) + + return dataset + + def write(self, file_path: Path, dataset: DiffractionDataset) -> None: + patterns = numpy.concatenate([array.get_data() for array in dataset]) + numpy.save(file_path, patterns) + + +def register_plugins(registry: PluginRegistry) -> None: + npy_diffraction_file_io = NPYDiffractionFileIO() + + registry.diffraction_file_readers.register_plugin( + npy_diffraction_file_io, + simple_name=NPYDiffractionFileIO.SIMPLE_NAME, + display_name=NPYDiffractionFileIO.DISPLAY_NAME, + ) + registry.diffraction_file_writers.register_plugin( + npy_diffraction_file_io, + simple_name=NPYDiffractionFileIO.SIMPLE_NAME, + display_name=NPYDiffractionFileIO.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/npy_object_file.py b/src/ptychodus/plugins/npy_object_file.py new file mode 100644 index 00000000..9194034f --- /dev/null +++ b/src/ptychodus/plugins/npy_object_file.py @@ -0,0 +1,31 @@ +from pathlib import Path + +import numpy + +from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter +from ptychodus.api.plugins import PluginRegistry + + +class NPYObjectFileReader(ObjectFileReader): + def read(self, file_path: Path) -> Object: + array = numpy.load(file_path) + return Object(array=array, pixel_geometry=None, center=None) + + +class NPYObjectFileWriter(ObjectFileWriter): + def write(self, file_path: Path, object_: Object) -> None: + array = object_.get_array() + numpy.save(file_path, array) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.object_file_readers.register_plugin( + NPYObjectFileReader(), + simple_name='NPY', + display_name='NumPy Binary Files (*.npy)', + ) + registry.object_file_writers.register_plugin( + NPYObjectFileWriter(), + simple_name='NPY', + display_name='NumPy Binary Files (*.npy)', + ) diff --git a/src/ptychodus/plugins/npy_probe_file.py b/src/ptychodus/plugins/npy_probe_file.py new file mode 100644 index 00000000..6a2c617f --- /dev/null +++ b/src/ptychodus/plugins/npy_probe_file.py @@ -0,0 +1,31 @@ +from pathlib import Path + +import numpy + +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence, ProbeFileReader, ProbeFileWriter + + +class NPYProbeFileReader(ProbeFileReader): + def read(self, file_path: Path) -> ProbeSequence: + array = numpy.load(file_path) + return ProbeSequence(array=array, opr_weights=None, pixel_geometry=None) + + +class NPYProbeFileWriter(ProbeFileWriter): + def write(self, file_path: Path, probes: ProbeSequence) -> None: + array = probes.get_probe_no_opr().get_array() + numpy.save(file_path, array) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.probe_file_readers.register_plugin( + NPYProbeFileReader(), + simple_name='NPY', + display_name='NumPy Binary Files (*.npy)', + ) + registry.probe_file_writers.register_plugin( + NPYProbeFileWriter(), + simple_name='NPY', + display_name='NumPy Binary Files (*.npy)', + ) diff --git a/src/ptychodus/plugins/npzDiffractionFile.py b/src/ptychodus/plugins/npzDiffractionFile.py deleted file mode 100644 index 26c3fd08..00000000 --- a/src/ptychodus/plugins/npzDiffractionFile.py +++ /dev/null @@ -1,100 +0,0 @@ -from pathlib import Path -from typing import Any, Final -import logging - -import numpy - -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.patterns import ( - DiffractionDataset, - DiffractionFileReader, - DiffractionFileWriter, - DiffractionMetadata, - DiffractionPatternState, - SimpleDiffractionDataset, - SimpleDiffractionPatternArray, -) -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.tree import SimpleTreeNode - -logger = logging.getLogger(__name__) - - -class NPZDiffractionFileIO(DiffractionFileReader, DiffractionFileWriter): - SIMPLE_NAME: Final[str] = 'NPZ' - DISPLAY_NAME: Final[str] = 'NumPy Zipped Archive (*.npz)' - - INDEXES: Final[str] = 'indexes' # TODO include indexes - PATTERNS: Final[str] = 'patterns' - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - - try: - contents = numpy.load(filePath) - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - return dataset - - try: - patterns = contents[self.PATTERNS] - except KeyError: - logger.warning(f'Failed to read patterns in "{filePath}".') - return dataset - - numberOfPatterns, detectorHeight, detectorWidth = patterns.shape - - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatterns, - numberOfPatternsTotal=numberOfPatterns, - patternDataType=patterns.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - filePath=filePath, - ) - - contentsTree = SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) - contentsTree.createChild( - [ - filePath.stem, - type(patterns).__name__, - f'{patterns.dtype}{patterns.shape}', - ] - ) - - array = SimpleDiffractionPatternArray( - label=filePath.stem, - index=0, - data=patterns, - state=DiffractionPatternState.FOUND, - ) - - return SimpleDiffractionDataset(metadata, contentsTree, [array]) - - def write(self, filePath: Path, dataset: DiffractionDataset) -> None: - patterns = list() - - for array in dataset: - arrayData = array.getData() - - if arrayData.size > 0: - patterns.append(arrayData) - - contents: dict[str, Any] = dict() - # TODO contents[self.INDEXES] = numpy.array(dataset.getAssembledIndexes()), - contents[self.PATTERNS] = numpy.concatenate(patterns) - numpy.savez(filePath, **contents) - - -def registerPlugins(registry: PluginRegistry) -> None: - npzDiffractionFileIO = NPZDiffractionFileIO() - - registry.diffractionFileReaders.registerPlugin( - npzDiffractionFileIO, - simpleName=NPZDiffractionFileIO.SIMPLE_NAME, - displayName=NPZDiffractionFileIO.DISPLAY_NAME, - ) - registry.diffractionFileWriters.registerPlugin( - npzDiffractionFileIO, - simpleName=NPZDiffractionFileIO.SIMPLE_NAME, - displayName=NPZDiffractionFileIO.DISPLAY_NAME, - ) diff --git a/src/ptychodus/plugins/npzProductFile.py b/src/ptychodus/plugins/npzProductFile.py deleted file mode 100644 index 848a346d..00000000 --- a/src/ptychodus/plugins/npzProductFile.py +++ /dev/null @@ -1,200 +0,0 @@ -from pathlib import Path -from typing import Any, Final - -import numpy - -from ptychodus.api.object import Object, ObjectFileReader -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileReader -from ptychodus.api.product import ( - Product, - ProductFileReader, - ProductFileWriter, - ProductMetadata, -) -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint - - -class NPZProductFileIO(ProductFileReader, ProductFileWriter): - SIMPLE_NAME: Final[str] = 'NPZ' - DISPLAY_NAME: Final[str] = 'NumPy Zipped Archive (*.npz)' - - NAME: Final[str] = 'name' - COMMENTS: Final[str] = 'comments' - DETECTOR_OBJECT_DISTANCE: Final[str] = 'detector_object_distance_m' - PROBE_ENERGY: Final[str] = 'probe_energy_eV' - PROBE_PHOTON_FLUX: Final[str] = 'probe_photons_per_s' - EXPOSURE_TIME: Final[str] = 'exposure_time_s' - - PROBE_ARRAY: Final[str] = 'probe' - PROBE_PIXEL_HEIGHT: Final[str] = 'probe_pixel_height_m' - PROBE_PIXEL_WIDTH: Final[str] = 'probe_pixel_width_m' - PROBE_POSITION_INDEXES: Final[str] = 'probe_position_indexes' - PROBE_POSITION_X: Final[str] = 'probe_position_x_m' - PROBE_POSITION_Y: Final[str] = 'probe_position_y_m' - - OBJECT_ARRAY: Final[str] = 'object' - OBJECT_CENTER_X: Final[str] = 'object_center_x_m' - OBJECT_CENTER_Y: Final[str] = 'object_center_y_m' - OBJECT_LAYER_DISTANCE: Final[str] = 'object_layer_distance_m' - OBJECT_PIXEL_HEIGHT: Final[str] = 'object_pixel_height_m' - OBJECT_PIXEL_WIDTH: Final[str] = 'object_pixel_width_m' - - COSTS_ARRAY: Final[str] = 'costs' - - def read(self, filePath: Path) -> Product: - with numpy.load(filePath) as npzFile: - metadata = ProductMetadata( - name=str(npzFile[self.NAME]), - comments=str(npzFile[self.COMMENTS]), - detectorDistanceInMeters=float(npzFile[self.DETECTOR_OBJECT_DISTANCE]), - probeEnergyInElectronVolts=float(npzFile[self.PROBE_ENERGY]), - probePhotonsPerSecond=float(npzFile[self.PROBE_PHOTON_FLUX]), - exposureTimeInSeconds=float(npzFile[self.EXPOSURE_TIME]), - ) - - scanIndexes = npzFile[self.PROBE_POSITION_INDEXES] - scanXInMeters = npzFile[self.PROBE_POSITION_X] - scanYInMeters = npzFile[self.PROBE_POSITION_Y] - - probe = Probe( - array=npzFile[self.PROBE_ARRAY], - pixelWidthInMeters=float(npzFile[self.PROBE_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[self.PROBE_PIXEL_HEIGHT]), - ) - - object_ = Object( - array=npzFile[self.OBJECT_ARRAY], - layerDistanceInMeters=npzFile[self.OBJECT_LAYER_DISTANCE], - pixelWidthInMeters=float(npzFile[self.OBJECT_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[self.OBJECT_PIXEL_HEIGHT]), - centerXInMeters=float(npzFile[self.OBJECT_CENTER_X]), - centerYInMeters=float(npzFile[self.OBJECT_CENTER_Y]), - ) - - costs = npzFile[self.COSTS_ARRAY] - - scanPointList: list[ScanPoint] = list() - - for idx, x_m, y_m in zip(scanIndexes, scanXInMeters, scanYInMeters): - point = ScanPoint(idx, x_m, y_m) - scanPointList.append(point) - - return Product( - metadata=metadata, - scan=Scan(scanPointList), - probe=probe, - object_=object_, - costs=costs, - ) - - def write(self, filePath: Path, product: Product) -> None: - contents: dict[str, Any] = dict() - scanIndexes: list[int] = list() - scanXInMeters: list[float] = list() - scanYInMeters: list[float] = list() - - for point in product.scan: - scanIndexes.append(point.index) - scanXInMeters.append(point.positionXInMeters) - scanYInMeters.append(point.positionYInMeters) - - metadata = product.metadata - contents[self.NAME] = metadata.name - contents[self.COMMENTS] = metadata.comments - contents[self.DETECTOR_OBJECT_DISTANCE] = metadata.detectorDistanceInMeters - contents[self.PROBE_ENERGY] = metadata.probeEnergyInElectronVolts - contents[self.PROBE_PHOTON_FLUX] = metadata.probePhotonsPerSecond - contents[self.EXPOSURE_TIME] = metadata.exposureTimeInSeconds - - contents[self.PROBE_POSITION_INDEXES] = scanIndexes - contents[self.PROBE_POSITION_X] = scanXInMeters - contents[self.PROBE_POSITION_Y] = scanYInMeters - - probe = product.probe - probeGeometry = probe.getGeometry() - contents[self.PROBE_ARRAY] = probe.array - contents[self.PROBE_PIXEL_WIDTH] = probeGeometry.pixelWidthInMeters - contents[self.PROBE_PIXEL_HEIGHT] = probeGeometry.pixelHeightInMeters - - object_ = product.object_ - objectGeometry = object_.getGeometry() - contents[self.OBJECT_ARRAY] = object_.array - contents[self.OBJECT_CENTER_X] = objectGeometry.centerXInMeters - contents[self.OBJECT_CENTER_Y] = objectGeometry.centerYInMeters - contents[self.OBJECT_PIXEL_WIDTH] = objectGeometry.pixelWidthInMeters - contents[self.OBJECT_PIXEL_HEIGHT] = objectGeometry.pixelHeightInMeters - contents[self.OBJECT_LAYER_DISTANCE] = object_.layerDistanceInMeters - - contents[self.COSTS_ARRAY] = product.costs - - numpy.savez(filePath, **contents) - - -class NPZScanFileReader(ScanFileReader): - def read(self, filePath: Path) -> Scan: - with numpy.load(filePath) as npzFile: - scanIndexes = npzFile[NPZProductFileIO.PROBE_POSITION_INDEXES] - scanXInMeters = npzFile[NPZProductFileIO.PROBE_POSITION_X] - scanYInMeters = npzFile[NPZProductFileIO.PROBE_POSITION_Y] - - scanPointList: list[ScanPoint] = list() - - for idx, x_m, y_m in zip(scanIndexes, scanXInMeters, scanYInMeters): - point = ScanPoint(idx, x_m, y_m) - scanPointList.append(point) - - return Scan(scanPointList) - - -class NPZProbeFileReader(ProbeFileReader): - def read(self, filePath: Path) -> Probe: - with numpy.load(filePath) as npzFile: - return Probe( - array=npzFile[NPZProductFileIO.PROBE_ARRAY], - pixelWidthInMeters=float(npzFile[NPZProductFileIO.PROBE_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[NPZProductFileIO.PROBE_PIXEL_HEIGHT]), - ) - - -class NPZObjectFileReader(ObjectFileReader): - def read(self, filePath: Path) -> Object: - with numpy.load(filePath) as npzFile: - return Object( - array=npzFile[NPZProductFileIO.OBJECT_ARRAY], - layerDistanceInMeters=npzFile[NPZProductFileIO.OBJECT_LAYER_DISTANCE], - pixelWidthInMeters=float(npzFile[NPZProductFileIO.OBJECT_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[NPZProductFileIO.OBJECT_PIXEL_HEIGHT]), - centerXInMeters=float(npzFile[NPZProductFileIO.OBJECT_CENTER_X]), - centerYInMeters=float(npzFile[NPZProductFileIO.OBJECT_CENTER_Y]), - ) - - -def registerPlugins(registry: PluginRegistry) -> None: - npzProductFileIO = NPZProductFileIO() - - registry.productFileReaders.registerPlugin( - npzProductFileIO, - simpleName=NPZProductFileIO.SIMPLE_NAME, - displayName=NPZProductFileIO.DISPLAY_NAME, - ) - registry.productFileWriters.registerPlugin( - npzProductFileIO, - simpleName=NPZProductFileIO.SIMPLE_NAME, - displayName=NPZProductFileIO.DISPLAY_NAME, - ) - registry.scanFileReaders.registerPlugin( - NPZScanFileReader(), - simpleName=NPZProductFileIO.SIMPLE_NAME, - displayName=NPZProductFileIO.DISPLAY_NAME, - ) - registry.probeFileReaders.registerPlugin( - NPZProbeFileReader(), - simpleName=NPZProductFileIO.SIMPLE_NAME, - displayName=NPZProductFileIO.DISPLAY_NAME, - ) - registry.objectFileReaders.registerPlugin( - NPZObjectFileReader(), - simpleName=NPZProductFileIO.SIMPLE_NAME, - displayName=NPZProductFileIO.DISPLAY_NAME, - ) diff --git a/src/ptychodus/plugins/npz_diffraction_file.py b/src/ptychodus/plugins/npz_diffraction_file.py new file mode 100644 index 00000000..111e3575 --- /dev/null +++ b/src/ptychodus/plugins/npz_diffraction_file.py @@ -0,0 +1,97 @@ +from pathlib import Path +from typing import Final +import logging + +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionFileWriter, + DiffractionMetadata, + SimpleDiffractionDataset, + SimpleDiffractionPatternArray, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.tree import SimpleTreeNode + +logger = logging.getLogger(__name__) + + +class NPZDiffractionFileIO(DiffractionFileReader, DiffractionFileWriter): + SIMPLE_NAME: Final[str] = 'NPZ' + DISPLAY_NAME: Final[str] = 'NumPy Zipped Archive (*.npz)' + + INDEXES: Final[str] = 'indexes' + PATTERNS: Final[str] = 'patterns' + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + contents = numpy.load(file_path) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + return dataset + + try: + patterns = contents[self.PATTERNS] + except KeyError: + logger.warning(f'Failed to read patterns in "{file_path}".') + return dataset + + num_patterns, detector_height, detector_width = patterns.shape + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=patterns.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + file_path=file_path, + ) + + try: + indexes = contents[self.INDEXES] + except KeyError: + logger.warning(f'Failed to read indexes in "{file_path}".') + indexes = numpy.arange(num_patterns) + + contents_tree = SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + contents_tree.create_child( + [ + file_path.stem, + type(patterns).__name__, + f'{patterns.dtype}{patterns.shape}', + ] + ) + + array = SimpleDiffractionPatternArray( + label=file_path.stem, + indexes=indexes, + data=patterns, + ) + + return SimpleDiffractionDataset(metadata, contents_tree, [array]) + + def write(self, file_path: Path, dataset: DiffractionDataset) -> None: + contents = { + self.INDEXES: numpy.concatenate([array.get_indexes() for array in dataset]), + self.PATTERNS: numpy.concatenate([array.get_data() for array in dataset]), + } + numpy.savez(file_path, **contents) + + +def register_plugins(registry: PluginRegistry) -> None: + npz_diffraction_file_io = NPZDiffractionFileIO() + + registry.diffraction_file_readers.register_plugin( + npz_diffraction_file_io, + simple_name=NPZDiffractionFileIO.SIMPLE_NAME, + display_name=NPZDiffractionFileIO.DISPLAY_NAME, + ) + registry.diffraction_file_writers.register_plugin( + npz_diffraction_file_io, + simple_name=NPZDiffractionFileIO.SIMPLE_NAME, + display_name=NPZDiffractionFileIO.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/npz_product_file.py b/src/ptychodus/plugins/npz_product_file.py new file mode 100644 index 00000000..290d32ff --- /dev/null +++ b/src/ptychodus/plugins/npz_product_file.py @@ -0,0 +1,202 @@ +from pathlib import Path +from typing import Any, Final +import logging + +import numpy + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object, ObjectCenter, ObjectFileReader +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence, ProbeFileReader +from ptychodus.api.product import ( + Product, + ProductFileReader, + ProductFileWriter, + ProductMetadata, +) +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint + +logger = logging.getLogger(__name__) + + +class NPZProductFileIO(ProductFileReader, ProductFileWriter): + SIMPLE_NAME: Final[str] = 'NPZ' + DISPLAY_NAME: Final[str] = 'Ptychodus NumPy Zipped Archive (*.npz)' + + NAME: Final[str] = 'name' + COMMENTS: Final[str] = 'comments' + DETECTOR_OBJECT_DISTANCE: Final[str] = 'detector_object_distance_m' + PROBE_ENERGY: Final[str] = 'probe_energy_eV' + PROBE_PHOTON_COUNT: Final[str] = 'probe_photon_count' + EXPOSURE_TIME: Final[str] = 'exposure_time_s' + MASS_ATTENUATION: Final[str] = 'mass_attenuation_m2_kg' + TOMOGRAPHY_ANGLE: Final[str] = 'tomography_angle_deg' + + PROBE_ARRAY: Final[str] = 'probe' + OPR_WEIGHTS: Final[str] = 'opr_weights' + PROBE_PIXEL_HEIGHT: Final[str] = 'probe_pixel_height_m' + PROBE_PIXEL_WIDTH: Final[str] = 'probe_pixel_width_m' + PROBE_POSITION_INDEXES: Final[str] = 'probe_position_indexes' + PROBE_POSITION_X: Final[str] = 'probe_position_x_m' + PROBE_POSITION_Y: Final[str] = 'probe_position_y_m' + + OBJECT_ARRAY: Final[str] = 'object' + OBJECT_CENTER_X: Final[str] = 'object_center_x_m' + OBJECT_CENTER_Y: Final[str] = 'object_center_y_m' + OBJECT_LAYER_SPACING: Final[str] = 'object_layer_spacing_m' + OBJECT_PIXEL_HEIGHT: Final[str] = 'object_pixel_height_m' + OBJECT_PIXEL_WIDTH: Final[str] = 'object_pixel_width_m' + + COSTS_ARRAY: Final[str] = 'costs' + + def read(self, file_path: Path) -> Product: + with numpy.load(file_path) as npz_file: + probe_photon_count = 0.0 + + try: + probe_photon_count = float(npz_file[self.PROBE_PHOTON_COUNT]) + except KeyError: + logger.debug('Probe photon count not found.') + + mass_attenuation_m2_kg = 0.0 + + try: + mass_attenuation_m2_kg = float(npz_file[self.MASS_ATTENUATION]) + except KeyError: + logger.debug('Mass attenuation not found.') + + tomography_angle_deg = 0.0 + + try: + tomography_angle_deg = float(npz_file[self.TOMOGRAPHY_ANGLE]) + except KeyError: + logger.debug('Tomography angle not found.') + + metadata = ProductMetadata( + name=str(npz_file[self.NAME]), + comments=str(npz_file[self.COMMENTS]), + detector_distance_m=float(npz_file[self.DETECTOR_OBJECT_DISTANCE]), + probe_energy_eV=float(npz_file[self.PROBE_ENERGY]), + probe_photon_count=probe_photon_count, + exposure_time_s=float(npz_file[self.EXPOSURE_TIME]), + mass_attenuation_m2_kg=mass_attenuation_m2_kg, + tomography_angle_deg=tomography_angle_deg, + ) + + scan_indexes = npz_file[self.PROBE_POSITION_INDEXES] + scan_x_m = npz_file[self.PROBE_POSITION_X] + scan_y_m = npz_file[self.PROBE_POSITION_Y] + + probe_pixel_geometry = PixelGeometry( + width_m=float(npz_file[self.PROBE_PIXEL_WIDTH]), + height_m=float(npz_file[self.PROBE_PIXEL_HEIGHT]), + ) + + try: + opr_weights = npz_file[self.OPR_WEIGHTS] + except KeyError: + logger.debug('OPR weights not found.') + opr_weights = None + + probe = ProbeSequence( + array=npz_file[self.PROBE_ARRAY], + opr_weights=opr_weights, + pixel_geometry=probe_pixel_geometry, + ) + + object_pixel_geometry = PixelGeometry( + width_m=float(npz_file[self.OBJECT_PIXEL_WIDTH]), + height_m=float(npz_file[self.OBJECT_PIXEL_HEIGHT]), + ) + object_center = ObjectCenter( + position_x_m=float(npz_file[self.OBJECT_CENTER_X]), + position_y_m=float(npz_file[self.OBJECT_CENTER_Y]), + ) + object_ = Object( + array=npz_file[self.OBJECT_ARRAY], + pixel_geometry=object_pixel_geometry, + center=object_center, + layer_spacing_m=npz_file[self.OBJECT_LAYER_SPACING], + ) + + costs = npz_file[self.COSTS_ARRAY] + + point_list: list[ScanPoint] = list() + + for idx, x_m, y_m in zip(scan_indexes, scan_x_m, scan_y_m): + point = ScanPoint(idx, x_m, y_m) + point_list.append(point) + + return Product( + metadata=metadata, + positions=PositionSequence(point_list), + probes=probe, + object_=object_, + costs=costs, + ) + + def write(self, file_path: Path, product: Product) -> None: + contents: dict[str, Any] = dict() + scan_indexes: list[int] = list() + scan_x_m: list[float] = list() + scan_y_m: list[float] = list() + + for point in product.positions: + scan_indexes.append(point.index) + scan_x_m.append(point.position_x_m) + scan_y_m.append(point.position_y_m) + + metadata = product.metadata + contents[self.NAME] = metadata.name + contents[self.COMMENTS] = metadata.comments + contents[self.DETECTOR_OBJECT_DISTANCE] = metadata.detector_distance_m + contents[self.PROBE_ENERGY] = metadata.probe_energy_eV + contents[self.PROBE_PHOTON_COUNT] = metadata.probe_photon_count + contents[self.EXPOSURE_TIME] = metadata.exposure_time_s + contents[self.MASS_ATTENUATION] = metadata.mass_attenuation_m2_kg + + contents[self.PROBE_POSITION_INDEXES] = scan_indexes + contents[self.PROBE_POSITION_X] = scan_x_m + contents[self.PROBE_POSITION_Y] = scan_y_m + + probe = product.probes + contents[self.PROBE_ARRAY] = probe.get_array() + + try: + opr_weights = probe.get_opr_weights() + except ValueError: + pass + else: + contents[self.OPR_WEIGHTS] = opr_weights + + probe_pixel_geometry = probe.get_pixel_geometry() + contents[self.PROBE_PIXEL_WIDTH] = probe_pixel_geometry.width_m + contents[self.PROBE_PIXEL_HEIGHT] = probe_pixel_geometry.height_m + + object_ = product.object_ + object_geometry = object_.get_geometry() + contents[self.OBJECT_ARRAY] = object_.get_array() + contents[self.OBJECT_CENTER_X] = object_geometry.center_x_m + contents[self.OBJECT_CENTER_Y] = object_geometry.center_y_m + contents[self.OBJECT_PIXEL_WIDTH] = object_geometry.pixel_width_m + contents[self.OBJECT_PIXEL_HEIGHT] = object_geometry.pixel_height_m + contents[self.OBJECT_LAYER_SPACING] = object_.layer_spacing_m + + contents[self.COSTS_ARRAY] = product.costs + + numpy.savez(file_path, **contents) + + +def register_plugins(registry: PluginRegistry) -> None: + npz_product_file_io = NPZProductFileIO() + + registry.register_product_file_reader_with_adapters( + npz_product_file_io, + simple_name=NPZProductFileIO.SIMPLE_NAME, + display_name=NPZProductFileIO.DISPLAY_NAME, + ) + registry.product_file_writers.register_plugin( + npz_product_file_io, + simple_name=NPZProductFileIO.SIMPLE_NAME, + display_name=NPZProductFileIO.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/nsls2_diffraction_file.py b/src/ptychodus/plugins/nsls2_diffraction_file.py new file mode 100644 index 00000000..cbad2da8 --- /dev/null +++ b/src/ptychodus/plugins/nsls2_diffraction_file.py @@ -0,0 +1,76 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry + +from .h5_diffraction_file import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +class NSLSIIDiffractionFileReader(DiffractionFileReader): + SIMPLE_NAME: Final[str] = 'NSLS-II' + DISPLAY_NAME: Final[str] = 'NSLS-II Files (*.mat)' + ONE_MICRON_M: Final[float] = 1e-6 + + def __init__(self) -> None: + self._data_path = 'det_data' + self._tree_builder = H5DiffractionFileTreeBuilder() + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + try: + with h5py.File(file_path, 'r') as h5_file: + contents_tree = self._tree_builder.build(h5_file) + + try: + data = h5_file[self._data_path] + except KeyError: + logger.warning('Unable to load data.') + else: + num_patterns, detector_height, detector_width = data.shape + pixel_size_m = ( + float(numpy.squeeze(h5_file['det_pixel_size'][()])) * self.ONE_MICRON_M + ) + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + detector_pixel_geometry=PixelGeometry(pixel_size_m, pixel_size_m), + file_path=file_path, + ) + + array = H5DiffractionPatternArray( + label=file_path.stem, + indexes=numpy.arange(num_patterns), + file_path=file_path, + data_path=self._data_path, + ) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, [array]) + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + + return dataset + + +def register_plugins(registry: PluginRegistry) -> None: + registry.diffraction_file_readers.register_plugin( + NSLSIIDiffractionFileReader(), + simple_name=NSLSIIDiffractionFileReader.SIMPLE_NAME, + display_name=NSLSIIDiffractionFileReader.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/nsls2_product_file.py b/src/ptychodus/plugins/nsls2_product_file.py new file mode 100644 index 00000000..ac41879a --- /dev/null +++ b/src/ptychodus/plugins/nsls2_product_file.py @@ -0,0 +1,77 @@ +from pathlib import Path +from typing import Final, Sequence + +import h5py + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import Product, ProductFileReader, ProductMetadata +from ptychodus.api.scan import PositionSequence, ScanPoint + + +class NSLSIIProductFileReader(ProductFileReader): + SIMPLE_NAME: Final[str] = 'NSLS-II' + DISPLAY_NAME: Final[str] = 'NSLS-II Product Files (*.mat)' + ONE_MICRON_M: Final[float] = 1e-6 + + def read(self, file_path: Path) -> Product: + point_list: list[ScanPoint] = list() + + with h5py.File(file_path, 'r') as h5_file: + detector_distance_m = float(h5_file['det_dist'][()]) * self.ONE_MICRON_M + probe_energy_eV = 1000 * float(h5_file['energy'][()]) # noqa: N806 + + metadata = ProductMetadata( + name=file_path.stem, + comments='', + detector_distance_m=detector_distance_m, + probe_energy_eV=probe_energy_eV, + probe_photon_count=0.0, # not included in file + exposure_time_s=0.0, # not included in file + mass_attenuation_m2_kg=0.0, # not included in file + tomography_angle_deg=0.0, # not included in file + ) + + pixel_width_m = h5_file['img_pixel_size_x'][()] + pixel_height_m = h5_file['img_pixel_size_y'][()] + pixel_geometry = PixelGeometry(width_m=pixel_width_m, height_m=pixel_height_m) + positions_m = h5_file['pos_xy'][()].T * self.ONE_MICRON_M + + for index, _xy in enumerate(positions_m): + point = ScanPoint( + index=index, + position_x_m=_xy[1], + position_y_m=_xy[2], + ) + point_list.append(point) + + probe_array = h5_file['prb'][()].astype(complex) + probes = ProbeSequence( + array=probe_array, opr_weights=None, pixel_geometry=pixel_geometry + ) + + object_array = h5_file['obj'][()].astype(complex) + object_ = Object( + array=object_array, + pixel_geometry=pixel_geometry, + center=None, + ) + costs: Sequence[float] = list() + + return Product( + metadata=metadata, + positions=PositionSequence(point_list), + probes=probes, + object_=object_, + costs=costs, + ) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.register_product_file_reader_with_adapters( + NSLSIIProductFileReader(), + simple_name=NSLSIIProductFileReader.SIMPLE_NAME, + display_name=NSLSIIProductFileReader.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/objectPhaseCentering.py b/src/ptychodus/plugins/objectPhaseCentering.py deleted file mode 100644 index 10c481df..00000000 --- a/src/ptychodus/plugins/objectPhaseCentering.py +++ /dev/null @@ -1,35 +0,0 @@ -import numpy - -from ptychodus.api.object import ObjectArrayType, ObjectPhaseCenteringStrategy -from ptychodus.api.plugins import PluginRegistry - - -class IdentityPhaseCenteringStrategy(ObjectPhaseCenteringStrategy): - def __call__(self, array: ObjectArrayType) -> ObjectArrayType: - return array - - -class CenterBoxMeanPhaseCenteringStrategy(ObjectPhaseCenteringStrategy): - def __call__(self, array: ObjectArrayType) -> ObjectArrayType: - oneThirdHeight = array.shape[-2] // 3 - oneThirdWidth = array.shape[-1] // 3 - - amplitude = numpy.absolute(array) - phase = numpy.angle(array) - - centerBoxMeanPhase = phase[ - oneThirdHeight : oneThirdHeight * 2, oneThirdWidth : oneThirdWidth * 2 - ].mean() - - return amplitude * numpy.exp(1j * (phase - centerBoxMeanPhase)) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.objectPhaseCenteringStrategies.registerPlugin( - IdentityPhaseCenteringStrategy(), - displayName='Identity', - ) - registry.objectPhaseCenteringStrategies.registerPlugin( - CenterBoxMeanPhaseCenteringStrategy(), - displayName='Center Box Mean', - ) diff --git a/src/ptychodus/plugins/ptychoShelvesProductFile.py b/src/ptychodus/plugins/ptychoShelvesProductFile.py deleted file mode 100644 index c3f08036..00000000 --- a/src/ptychodus/plugins/ptychoShelvesProductFile.py +++ /dev/null @@ -1,141 +0,0 @@ -from pathlib import Path -from typing import Final, Sequence - -import scipy.io - -from ptychodus.api.constants import ( - ELECTRON_VOLT_J, - LIGHT_SPEED_M_PER_S, - PLANCK_CONSTANT_J_PER_HZ, -) -from ptychodus.api.object import Object, ObjectArrayType, ObjectFileWriter -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileWriter -from ptychodus.api.product import Product, ProductFileReader, ProductMetadata -from ptychodus.api.propagator import WavefieldArrayType -from ptychodus.api.scan import Scan, ScanPoint - - -class MATProductFileReader(ProductFileReader): - SIMPLE_NAME: Final[str] = 'PtychoShelves' - DISPLAY_NAME: Final[str] = 'PtychoShelves Files (*.mat)' - - def _load_probe_array(self, probeMatrix: WavefieldArrayType) -> WavefieldArrayType: - if probeMatrix.ndim == 4: - # probeMatrix[width, height, num_shared_modes, num_varying_modes] - # TODO support spatially varying probe modes - probeMatrix = probeMatrix[..., 0] - - if probeMatrix.ndim == 3: - # probeMatrix[width, height, num_shared_modes] - probeMatrix = probeMatrix - - return probeMatrix.transpose(2, 0, 1) - - def _load_object_array(self, objectMatrix: ObjectArrayType) -> ObjectArrayType: - if objectMatrix.ndim == 3: - # objectMatrix[width, height, num_layers] - objectMatrix = objectMatrix.transpose(2, 0, 1) - - return objectMatrix - - def read(self, filePath: Path) -> Product: - scanPointList: list[ScanPoint] = list() - - hc_eVm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S / ELECTRON_VOLT_J - matDict = scipy.io.loadmat(filePath, simplify_cells=True) - p_struct = matDict['p'] - probe_energy_eV = hc_eVm / p_struct['lambda'] - - metadata = ProductMetadata( - name=filePath.stem, - comments='', - detectorDistanceInMeters=0.0, # not included in file - probeEnergyInElectronVolts=probe_energy_eV, - probePhotonsPerSecond=0.0, # not included in file - exposureTimeInSeconds=0.0, # not included in file - ) - - dx_spec = p_struct['dx_spec'] - pixel_width_m = dx_spec[0] - pixel_height_m = dx_spec[1] - - outputs_struct = matDict['outputs'] - probe_positions = outputs_struct['probe_positions'] - - for idx, pos_px in enumerate(probe_positions): - point = ScanPoint( - idx, - pos_px[0] * pixel_width_m, - pos_px[1] * pixel_height_m, - ) - scanPointList.append(point) - - probe = Probe( - self._load_probe_array(matDict['probe']), - pixelWidthInMeters=pixel_width_m, - pixelHeightInMeters=pixel_height_m, - ) - - layer_distance_m: Sequence[float] | None = None - - try: - multi_slice_param = p_struct['multi_slice_param'] - except KeyError: - pass - else: - try: - z_distance = multi_slice_param['z_distance'] - except KeyError: - pass - else: - layer_distance_m = z_distance.tolist() - - object_ = Object( - self._load_object_array(matDict['object']), - layer_distance_m, - pixelWidthInMeters=pixel_width_m, - pixelHeightInMeters=pixel_height_m, - ) - costs = outputs_struct['fourier_error_out'] - - return Product( - metadata=metadata, - scan=Scan(scanPointList), - probe=probe, - object_=object_, - costs=costs, - ) - - -class MATObjectFileWriter(ObjectFileWriter): - def write(self, filePath: Path, object_: Object) -> None: - array = object_.array - matDict = {'object': array.transpose(1, 2, 0)} - # TODO layer distance to p.z_distance - scipy.io.savemat(filePath, matDict) - - -class MATProbeFileWriter(ProbeFileWriter): - def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array - matDict = {'probe': array.transpose(1, 2, 0)} - scipy.io.savemat(filePath, matDict) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.productFileReaders.registerPlugin( - MATProductFileReader(), - simpleName=MATProductFileReader.SIMPLE_NAME, - displayName=MATProductFileReader.DISPLAY_NAME, - ) - registry.probeFileWriters.registerPlugin( - MATProbeFileWriter(), - simpleName=MATProductFileReader.SIMPLE_NAME, - displayName=MATProductFileReader.DISPLAY_NAME, - ) - registry.objectFileWriters.registerPlugin( - MATObjectFileWriter(), - simpleName=MATProductFileReader.SIMPLE_NAME, - displayName=MATProductFileReader.DISPLAY_NAME, - ) diff --git a/src/ptychodus/plugins/ptychoShelvesScanFile.py b/src/ptychodus/plugins/ptychoShelvesScanFile.py deleted file mode 100644 index 2759f6eb..00000000 --- a/src/ptychodus/plugins/ptychoShelvesScanFile.py +++ /dev/null @@ -1,44 +0,0 @@ -from pathlib import Path -import logging - -import h5py -import numpy - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint, ScanPointParseError - -logger = logging.getLogger(__name__) - - -class PtychoShelvesScanFileReader(ScanFileReader): - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - - try: - with h5py.File(filePath, 'r') as h5File: - try: - ppX = numpy.squeeze(h5File['/ppX']) - ppY = numpy.squeeze(h5File['/ppY']) - except KeyError: - logger.warning('Unable to find data.') - else: - if ppX.shape == ppY.shape: - logger.debug(f'Coordinate arrays have shape {ppX.shape}.') - else: - raise ScanPointParseError('Coordinate array shape mismatch!') - - for idx, (x, y) in enumerate(zip(ppX, ppY)): - point = ScanPoint(idx, x, y) - pointList.append(point) - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - - return Scan(pointList) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - PtychoShelvesScanFileReader(), - simpleName='PtychoShelves', - displayName='PtychoShelves Scan Position Files (*.h5 *.hdf5)', - ) diff --git a/src/ptychodus/plugins/ptychoshelves_position_file.py b/src/ptychodus/plugins/ptychoshelves_position_file.py new file mode 100644 index 00000000..b2cade9e --- /dev/null +++ b/src/ptychodus/plugins/ptychoshelves_position_file.py @@ -0,0 +1,41 @@ +from pathlib import Path +import logging + +import h5py +import numpy + +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.scan import PositionSequence, PositionFileReader, ScanPoint, ScanPointParseError + +logger = logging.getLogger(__name__) + + +class PtychoShelvesPositionFileReader(PositionFileReader): + def read(self, file_path: Path) -> PositionSequence: + point_list: list[ScanPoint] = list() + + with h5py.File(file_path, 'r') as h5_file: + try: + pp_x = numpy.squeeze(h5_file['/ppX']) + pp_y = numpy.squeeze(h5_file['/ppY']) + except KeyError: + logger.warning('Unable to find data.') + else: + if pp_x.shape == pp_y.shape: + logger.debug(f'Coordinate arrays have shape {pp_x.shape}.') + else: + raise ScanPointParseError('Coordinate array shape mismatch!') + + for idx, (x, y) in enumerate(zip(pp_x, pp_y)): + point = ScanPoint(idx, x, y) + point_list.append(point) + + return PositionSequence(point_list) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.position_file_readers.register_plugin( + PtychoShelvesPositionFileReader(), + simple_name='PtychoShelves', + display_name='PtychoShelves Files (*.h5 *.hdf5)', + ) diff --git a/src/ptychodus/plugins/ptychoshelves_product_file.py b/src/ptychodus/plugins/ptychoshelves_product_file.py new file mode 100644 index 00000000..0ad66c88 --- /dev/null +++ b/src/ptychodus/plugins/ptychoshelves_product_file.py @@ -0,0 +1,115 @@ +from pathlib import Path +from typing import Final, Sequence + +import numpy +import scipy.io + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import ( + ELECTRON_VOLT_J, + LIGHT_SPEED_M_PER_S, + PLANCK_CONSTANT_J_PER_HZ, + Product, + ProductFileReader, + ProductMetadata, +) +from ptychodus.api.scan import PositionSequence, ScanPoint + + +class PtychoShelvesProductFileReader(ProductFileReader): + SIMPLE_NAME: Final[str] = 'PtychoShelves' + DISPLAY_NAME: Final[str] = 'PtychoShelves Files (*.mat)' + + def read(self, file_path: Path) -> Product: + point_list: list[ScanPoint] = list() + + hc_eVm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S / ELECTRON_VOLT_J # noqa: N806 + mat_dict = scipy.io.loadmat(file_path, simplify_cells=True) + p_struct = mat_dict['p'] + probe_energy_eV = hc_eVm / p_struct['lambda'] # noqa: N806 + + metadata = ProductMetadata( + name=file_path.stem, + comments='', + detector_distance_m=0.0, # not included in file + probe_energy_eV=probe_energy_eV, + probe_photon_count=0.0, # not included in file + exposure_time_s=0.0, # not included in file + mass_attenuation_m2_kg=0.0, # not included in file + tomography_angle_deg=0.0, # not included in file + ) + + dx_spec = p_struct['dx_spec'] + pixel_width_m = dx_spec[0] + pixel_height_m = dx_spec[1] + pixel_geometry = PixelGeometry(width_m=pixel_width_m, height_m=pixel_height_m) + + outputs_struct = mat_dict['outputs'] + probe_positions = outputs_struct['probe_positions'] + + for idx, pos_px in enumerate(probe_positions): + point = ScanPoint( + idx, + pos_px[0] * pixel_width_m, + pos_px[1] * pixel_height_m, + ) + point_list.append(point) + + probe_array = mat_dict['probe'] + + if probe_array.ndim == 3: + # probe_array[height, width, num_shared_modes] + probe_array = probe_array.transpose(2, 0, 1) + elif probe_array.ndim == 4: + # probe_array[height, width, num_shared_modes, num_varying_modes] + probe_array = probe_array.transpose(3, 2, 0, 1) + + probe = ProbeSequence( + array=probe_array, + opr_weights=None, # TODO OPR, if available + pixel_geometry=pixel_geometry, + ) + + object_array = mat_dict['object'] + + if object_array.ndim == 3: + # object_array[height, width, num_layers] + object_array = object_array.transpose(2, 0, 1) + + layer_spacing_m: Sequence[float] = list() + + try: + multi_slice_param = p_struct['multi_slice_param'] + z_distance = multi_slice_param['z_distance'] + except KeyError: + pass + else: + num_spaces = object_array.shape[-3] - 1 + layer_spacing_m = numpy.squeeze(z_distance)[:num_spaces] + + object_ = Object( + array=object_array, + pixel_geometry=pixel_geometry, + center=None, + layer_spacing_m=layer_spacing_m, + ) + costs = outputs_struct['fourier_error_out'] + + return Product( + metadata=metadata, + positions=PositionSequence(point_list), + probes=probe, + object_=object_, + costs=costs, + ) + + +def register_plugins(registry: PluginRegistry) -> None: + registry.register_product_file_reader_with_adapters( + PtychoShelvesProductFileReader(), + simple_name=PtychoShelvesProductFileReader.SIMPLE_NAME, + display_name=PtychoShelvesProductFileReader.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/slac_npz_file.py b/src/ptychodus/plugins/slac_npz_file.py new file mode 100644 index 00000000..d2b6647e --- /dev/null +++ b/src/ptychodus/plugins/slac_npz_file.py @@ -0,0 +1,103 @@ +from pathlib import Path +from typing import Final, Sequence +import logging + +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.object import Object +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + SimpleDiffractionDataset, + SimpleDiffractionPatternArray, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import ProbeSequence +from ptychodus.api.product import Product, ProductFileReader, ProductMetadata +from ptychodus.api.scan import PositionSequence, ScanPoint +from ptychodus.api.tree import SimpleTreeNode + +logger = logging.getLogger(__name__) + + +class SLACDiffractionFileReader(DiffractionFileReader): + def read(self, file_path: Path) -> DiffractionDataset: + with numpy.load(file_path) as npz_file: + patterns = numpy.transpose(npz_file['diffraction'], [2, 0, 1]) + + num_patterns, detector_height, detector_width = patterns.shape + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns, + num_patterns_total=num_patterns, + pattern_dtype=patterns.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + file_path=file_path, + ) + + contents_tree = SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + contents_tree.create_child( + [file_path.stem, type(patterns).__name__, f'{patterns.dtype}{patterns.shape}'] + ) + + array = SimpleDiffractionPatternArray( + label=file_path.stem, + indexes=numpy.arange(num_patterns), + data=patterns, + ) + + return SimpleDiffractionDataset(metadata, contents_tree, [array]) + + +class SLACProductFileReader(ProductFileReader): + def read(self, file_path: Path) -> Product: + with numpy.load(file_path) as npz_file: + scan_x_m = npz_file['xcoords_start'] + scan_y_m = npz_file['ycoords_start'] + probe_array = npz_file['probeGuess'] + object_array = npz_file['objectGuess'] + + metadata = ProductMetadata( + name=file_path.stem, + comments='', + detector_distance_m=0.0, # not included in file + probe_energy_eV=0.0, # not included in file + probe_photon_count=0.0, # not included in file + exposure_time_s=0.0, # not included in file + mass_attenuation_m2_kg=0.0, # not included in file + tomography_angle_deg=0.0, # not included in file + ) + + point_list: list[ScanPoint] = list() + + for idx, (x_m, y_m) in enumerate(zip(scan_x_m, scan_y_m)): + point = ScanPoint(idx, x_m, y_m) + point_list.append(point) + + costs: Sequence[float] = list() # not included in file + + return Product( + metadata=metadata, + positions=PositionSequence(point_list), + probes=ProbeSequence(array=probe_array, opr_weights=None, pixel_geometry=None), + object_=Object(array=object_array, pixel_geometry=None, center=None), + costs=costs, + ) + + +def register_plugins(registry: PluginRegistry) -> None: + SIMPLE_NAME: Final[str] = 'SLAC' # noqa: N806 + DISPLAY_NAME: Final[str] = 'SLAC NumPy Zipped Archive (*.npz)' # noqa: N806 + + registry.diffraction_file_readers.register_plugin( + SLACDiffractionFileReader(), + simple_name=SIMPLE_NAME, + display_name=DISPLAY_NAME, + ) + registry.register_product_file_reader_with_adapters( + SLACProductFileReader(), + simple_name=SIMPLE_NAME, + display_name=DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/tiffDiffractionFile.py b/src/ptychodus/plugins/tiffDiffractionFile.py deleted file mode 100644 index 21a135c2..00000000 --- a/src/ptychodus/plugins/tiffDiffractionFile.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections.abc import Mapping -from pathlib import Path -import logging -import re -import sys - -from tifffile import TiffFile -import numpy - -from ptychodus.api.geometry import ImageExtent -from ptychodus.api.patterns import ( - DiffractionDataset, - DiffractionFileReader, - DiffractionMetadata, - DiffractionPatternArray, - DiffractionPatternArrayType, - DiffractionPatternState, - SimpleDiffractionDataset, -) -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.tree import SimpleTreeNode - -logger = logging.getLogger(__name__) - - -class TiffDiffractionPatternArray(DiffractionPatternArray): - def __init__(self, filePath: Path, index: int) -> None: - super().__init__() - self._filePath = filePath - self._index = index - self._state = DiffractionPatternState.UNKNOWN - - def getLabel(self) -> str: - return self._filePath.stem - - def getIndex(self) -> int: - return self._index - - def getState(self) -> DiffractionPatternState: - return self._state - - def getData(self) -> DiffractionPatternArrayType: - self._state = DiffractionPatternState.MISSING - - with TiffFile(self._filePath) as tiff: - try: - data = tiff.asarray() - except: - raise - else: - self._state = DiffractionPatternState.FOUND - - if data.ndim == 2: - data = data[numpy.newaxis, :, :] - - return data - - -class TiffDiffractionFileReader(DiffractionFileReader): - def _getFileSeries(self, filePath: Path) -> tuple[Mapping[int, Path], str]: - filePathDict: dict[int, Path] = dict() - - digits = re.findall(r'\d+', filePath.stem) - longest_digits = max(digits, key=len) - filePattern = filePath.name.replace(longest_digits, f'(\\d{{{len(longest_digits)}}})') - - for fp in filePath.parent.iterdir(): - z = re.match(filePattern, fp.name) - - if z: - index = int(z.group(1).lstrip('0')) - filePathDict[index] = fp - - return filePathDict, filePattern - - def read(self, filePath: Path) -> DiffractionDataset: - dataset = SimpleDiffractionDataset.createNullInstance(filePath) - - filePathMapping, filePattern = self._getFileSeries(filePath) - contentsTree = SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) - arrayList: list[DiffractionPatternArray] = list() - - for idx, (_, fp) in enumerate(sorted(filePathMapping.items())): # TODO use keys - array = TiffDiffractionPatternArray(fp, idx) - contentsTree.createChild([array.getLabel(), 'TIFF', str(idx)]) - arrayList.append(array) - - try: - with TiffFile(filePath) as tiff: - data = tiff.asarray() - except OSError: - logger.warning(f'Unable to read file "{filePath}".') - else: - if data.ndim == 2: - data = data[numpy.newaxis, :, :] - - numberOfPatternsPerArray, detectorHeight, detectorWidth = data.shape - - metadata = DiffractionMetadata( - numberOfPatternsPerArray=numberOfPatternsPerArray, - numberOfPatternsTotal=numberOfPatternsPerArray * len(arrayList), - patternDataType=data.dtype, - detectorExtent=ImageExtent(detectorWidth, detectorHeight), - filePath=filePath.parent / filePattern, - ) - - dataset = SimpleDiffractionDataset(metadata, contentsTree, arrayList) - - return dataset - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.diffractionFileReaders.registerPlugin( - TiffDiffractionFileReader(), - simpleName='TIFF', - displayName='Tagged Image File Format Files (*.tif *.tiff)', - ) - - -if __name__ == '__main__': - filePath = Path(sys.argv[1]) - reader = TiffDiffractionFileReader() - tiffFile = reader.read(filePath) - print(tiffFile) diff --git a/src/ptychodus/plugins/tiff_diffraction_file.py b/src/ptychodus/plugins/tiff_diffraction_file.py new file mode 100644 index 00000000..fb1a1a01 --- /dev/null +++ b/src/ptychodus/plugins/tiff_diffraction_file.py @@ -0,0 +1,113 @@ +from collections.abc import Mapping +from pathlib import Path +import logging +import re +import sys + +from tifffile import TiffFile +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + DiffractionPatternArray, + PatternDataType, + PatternIndexesType, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.tree import SimpleTreeNode + +logger = logging.getLogger(__name__) + + +class TiffDiffractionPatternArray(DiffractionPatternArray): + def __init__(self, file_path: Path, index: int) -> None: + super().__init__() + self._file_path = file_path + self._indexes = numpy.array([index]) + + def get_label(self) -> str: + return self._file_path.stem + + def get_indexes(self) -> PatternIndexesType: + return self._indexes + + def get_data(self) -> PatternDataType: + with TiffFile(self._file_path) as tiff: + data = tiff.asarray() + + if data.ndim == 2: + data = data[numpy.newaxis, :, :] + + return data + + +class TiffDiffractionFileReader(DiffractionFileReader): + def _get_file_series(self, file_path: Path) -> tuple[Mapping[int, Path], str]: + file_path_dict: dict[int, Path] = dict() + + digits = re.findall(r'\d+', file_path.stem) + longest_digits = max(digits, key=len) + file_pattern = file_path.name.replace(longest_digits, f'(\\d{{{len(longest_digits)}}})') + + for fp in file_path.parent.iterdir(): + z = re.match(file_pattern, fp.name) + + if z: + index = int(z.group(1).lstrip('0')) + file_path_dict[index] = fp + + return file_path_dict, file_pattern + + def read(self, file_path: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.create_null(file_path) + + file_path_mapping, file_pattern = self._get_file_series(file_path) + contents_tree = SimpleTreeNode.create_root(['Name', 'Type', 'Details']) + array_list: list[DiffractionPatternArray] = list() + + for idx, (_, fp) in enumerate(sorted(file_path_mapping.items())): # TODO use keys + array = TiffDiffractionPatternArray(fp, idx) + contents_tree.create_child([array.get_label(), 'TIFF', str(idx)]) + array_list.append(array) + + try: + with TiffFile(file_path) as tiff: + data = tiff.asarray() + except OSError: + logger.warning(f'Unable to read file "{file_path}".') + else: + if data.ndim == 2: + data = data[numpy.newaxis, :, :] + + num_patterns_per_array, detector_height, detector_width = data.shape + + metadata = DiffractionMetadata( + num_patterns_per_array=num_patterns_per_array, + num_patterns_total=num_patterns_per_array * len(array_list), + pattern_dtype=data.dtype, + detector_extent=ImageExtent(detector_width, detector_height), + file_path=file_path.parent / file_pattern, + ) + + dataset = SimpleDiffractionDataset(metadata, contents_tree, array_list) + + return dataset + + +def register_plugins(registry: PluginRegistry) -> None: + registry.diffraction_file_readers.register_plugin( + TiffDiffractionFileReader(), + simple_name='TIFF', + display_name='Tagged Image File Format Files (*.tif *.tiff)', + ) + + +if __name__ == '__main__': + file_path = Path(sys.argv[1]) + reader = TiffDiffractionFileReader() + tiff_file = reader.read(file_path) + print(tiff_file) diff --git a/src/ptychodus/plugins/upscaling.py b/src/ptychodus/plugins/upscaling.py index 2a2cc9c7..0594f67d 100644 --- a/src/ptychodus/plugins/upscaling.py +++ b/src/ptychodus/plugins/upscaling.py @@ -16,17 +16,17 @@ def __init__(self, method: str) -> None: self._method = method def __call__(self, emap: ElementMap, product: Product) -> ElementMap: - objectGeometry = product.object_.getGeometry() - scanCoordinatesInPixels: list[float] = list() + object_geometry = product.object_.get_geometry() + scan_coords_px: list[float] = list() - for scanPoint in product.scan: - objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) - scanCoordinatesInPixels.append(objectPoint.positionYInPixels) - scanCoordinatesInPixels.append(objectPoint.positionXInPixels) + for scan_point in product.positions: + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + scan_coords_px.append(object_point.position_y_px) + scan_coords_px.append(object_point.position_x_px) - points = numpy.reshape(scanCoordinatesInPixels, (-1, 2)) + points = numpy.reshape(scan_coords_px, (-1, 2)) values = emap.counts_per_second.flat - YY, XX = numpy.mgrid[: objectGeometry.heightInPixels, : objectGeometry.widthInPixels] + YY, XX = numpy.mgrid[: object_geometry.height_px, : object_geometry.width_px] # noqa: N806 query_points = numpy.transpose((YY.flat, XX.flat)) cps = griddata(points, values, query_points, method=self._method, fill_value=0.0).reshape( @@ -51,77 +51,77 @@ def __init__( self._degree = degree def __call__(self, emap: ElementMap, product: Product) -> ElementMap: - objectGeometry = product.object_.getGeometry() - scanCoordinatesInPixels: list[float] = list() + object_geometry = product.object_.get_geometry() + scan_coords_px: list[float] = list() - for scanPoint in product.scan: - objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) - scanCoordinatesInPixels.append(objectPoint.positionYInPixels) - scanCoordinatesInPixels.append(objectPoint.positionXInPixels) + for scan_point in product.positions: + object_point = object_geometry.map_scan_point_to_object_point(scan_point) + scan_coords_px.append(object_point.position_y_px) + scan_coords_px.append(object_point.position_x_px) interpolator = RBFInterpolator( - numpy.reshape(scanCoordinatesInPixels, (-1, 2)), + numpy.reshape(scan_coords_px, (-1, 2)), emap.counts_per_second.flat, kernel=self._kernel, neighbors=self._neighbors, epsilon=self._epsilon, degree=self._degree, ) - YY, XX = numpy.mgrid[: objectGeometry.heightInPixels, : objectGeometry.widthInPixels] + YY, XX = numpy.mgrid[: object_geometry.height_px, : object_geometry.width_px] # noqa: N806 cps = interpolator(numpy.transpose((YY.flat, XX.flat))) return ElementMap(emap.name, cps.astype(emap.counts_per_second.dtype).reshape(XX.shape)) -def registerPlugins(registry: PluginRegistry) -> None: +def register_plugins(registry: PluginRegistry) -> None: # TODO natural neighbor # TODO kriging # TODO inverse distance weighting - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( IdentityUpscaling(), - displayName='Identity', + display_name='Identity', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( GridDataUpscaling('nearest'), - displayName='Nearest Neighbor', + display_name='Nearest Neighbor', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( GridDataUpscaling('linear'), - displayName='Linear', + display_name='Linear', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( GridDataUpscaling('cubic'), - displayName='Cubic', + display_name='Cubic', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('linear'), - displayName='Linear RBF', + display_name='Linear RBF', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('thin_plate_spline'), - displayName='Thin Plate Spline RBF', + display_name='Thin Plate Spline RBF', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('cubic'), - displayName='Cubic RBF', + display_name='Cubic RBF', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('quintic'), - displayName='Quintic RBF', + display_name='Quintic RBF', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('multiquadric'), - displayName='Multiquadric RBF', + display_name='Multiquadric RBF', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('inverse_multiquadric'), - displayName='Inverse Multiquadric RBF', + display_name='Inverse Multiquadric RBF', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('inverse_quadratic'), - displayName='Inverse Quadratic RBF', + display_name='Inverse Quadratic RBF', ) - registry.upscalingStrategies.registerPlugin( + registry.upscaling_strategies.register_plugin( RadialBasisFunctionUpscaling('gaussian'), - displayName='Gaussian RBF', + display_name='Gaussian RBF', ) diff --git a/src/ptychodus/plugins/workflow.py b/src/ptychodus/plugins/workflow.py index 6e20e0d8..7619023b 100644 --- a/src/ptychodus/plugins/workflow.py +++ b/src/ptychodus/plugins/workflow.py @@ -12,63 +12,63 @@ class PtychodusAutoloadProductFileBasedWorkflow(FileBasedWorkflow): @property - def isWatchRecursive(self) -> bool: + def is_watch_recursive(self) -> bool: return True - def getWatchFilePattern(self) -> str: + def get_watch_file_pattern(self) -> str: return 'product-out.npz' - def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: - workflowAPI.openProduct(filePath, fileType='NPZ') + def execute(self, api: WorkflowAPI, file_path: Path) -> None: + api.open_product(file_path, file_type='NPZ') class APS2IDFileBasedWorkflow(FileBasedWorkflow): @property - def isWatchRecursive(self) -> bool: + def is_watch_recursive(self) -> bool: return False - def getWatchFilePattern(self) -> str: + def get_watch_file_pattern(self) -> str: return '*.csv' - def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: - scanName = filePath.stem - scanID = int(re.findall(r'\d+', scanName)[-1]) + def execute(self, api: WorkflowAPI, file_path: Path) -> None: + scan_name = file_path.stem + scan_id = int(re.findall(r'\d+', scan_name)[-1]) - diffractionFilePath = filePath.parents[1] / 'raw_data' / f'scan{scanID}_master.h5' - workflowAPI.openPatterns(diffractionFilePath, fileType='NeXus') - productAPI = workflowAPI.createProduct(f'scan{scanID}') - productAPI.openScan(filePath, fileType='CSV') - productAPI.buildProbe() - productAPI.buildObject() - productAPI.reconstructRemote() + diffraction_file_path = file_path.parents[1] / 'raw_data' / f'scan{scan_id}_master.h5' + api.open_patterns(diffraction_file_path, file_type='NeXus') + product_api = api.create_product(f'scan{scan_id}') + product_api.open_scan(file_path, file_type='CSV') + product_api.build_probe() + product_api.build_object() + product_api.reconstruct_remote() class APS26IDFileBasedWorkflow(FileBasedWorkflow): @property - def isWatchRecursive(self) -> bool: + def is_watch_recursive(self) -> bool: return False - def getWatchFilePattern(self) -> str: + def get_watch_file_pattern(self) -> str: return '*.mda' - def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: - scanName = filePath.stem - scanID = int(re.findall(r'\d+', scanName)[-1]) + def execute(self, api: WorkflowAPI, file_path: Path) -> None: + scan_name = file_path.stem + scan_id = int(re.findall(r'\d+', scan_name)[-1]) - diffractionDirPath = filePath.parents[1] / 'h5' + diffraction_dir_path = file_path.parents[1] / 'h5' - for diffractionFilePath in diffractionDirPath.glob(f'scan_{scanID}_*.h5'): - digits = int(re.findall(r'\d+', diffractionFilePath.stem)[-1]) + for diffraction_file_path in diffraction_dir_path.glob(f'scan_{scan_id}_*.h5'): + digits = int(re.findall(r'\d+', diffraction_file_path.stem)[-1]) if digits != 0: break - workflowAPI.openPatterns(diffractionFilePath, fileType='HDF5') - productAPI = workflowAPI.createProduct(f'scan_{scanID}') - productAPI.openScan(filePath, fileType='MDA') - productAPI.buildProbe() - productAPI.buildObject() - productAPI.reconstructRemote() + api.open_patterns(diffraction_file_path, file_type='HDF5') + product_api = api.create_product(f'scan_{scan_id}') + product_api.open_scan(file_path, file_type='MDA') + product_api.build_probe() + product_api.build_object() + product_api.reconstruct_remote() @dataclass(frozen=True) @@ -94,23 +94,23 @@ def __str__(self) -> str: class APS31IDEFileBasedWorkflow(FileBasedWorkflow): @property - def isWatchRecursive(self) -> bool: + def is_watch_recursive(self) -> bool: return True - def getWatchFilePattern(self) -> str: + def get_watch_file_pattern(self) -> str: return '*.h5' - def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: - experimentDir = filePath.parents[3] - scan_no = int(re.findall(r'\d+', filePath.stem)[0]) - scanFile = experimentDir / 'scan_positions' / f'scan_{scan_no:05d}.dat' - scanNumbersFile = experimentDir / 'dat-files' / 'tomography_scannumbers.txt' + def execute(self, api: WorkflowAPI, file_path: Path) -> None: + experiment_dir = file_path.parents[3] + scan_num = int(re.findall(r'\d+', file_path.stem)[0]) + scan_file = experiment_dir / 'scan_positions' / f'scan_{scan_num:05d}.dat' + scan_numbers_file = experiment_dir / 'dat-files' / 'tomography_scannumbers.txt' metadata: APS31IDEMetadata | None = None - with scanNumbersFile.open(newline='') as csvFile: - csvReader = csv.reader(csvFile, delimiter=' ') + with scan_numbers_file.open(newline='') as csv_file: + csv_reader = csv.reader(csv_file, delimiter=' ') - for row in csvReader: + for row in csv_reader: if row[0].startswith('#'): continue @@ -126,9 +126,9 @@ def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: logger.debug(row[0]) continue - if row_no == scan_no: + if row_no == scan_num: metadata = APS31IDEMetadata( - scan_no=scan_no, + scan_no=scan_num, golden_angle=str(row[1]), encoder_angle=str(row[2]), measurement_id=str(row[3]), @@ -141,37 +141,37 @@ def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: if metadata is None: logger.warning(f'Failed to locate label for {row_no}!') else: - productName = f'scan{scan_no:05d}_' + metadata.label - workflowAPI.openPatterns(filePath, fileType='LYNX') - inputProductAPI = workflowAPI.createProduct(productName, comments=str(metadata)) - inputProductAPI.openScan(scanFile, fileType='LYNXOrchestra') - inputProductAPI.buildProbe() - inputProductAPI.buildObject() + product_name = f'scan{scan_num:05d}_' + metadata.label + api.open_patterns(file_path, file_type='LYNX') + input_product_api = api.create_product(product_name, comments=str(metadata)) + input_product_api.open_scan(scan_file, file_type='LYNXOrchestra') + input_product_api.build_probe() + input_product_api.build_object() # TODO would prefer to write instructions and submit to queue - outputProductAPI = inputProductAPI.reconstructLocal(f'{productName}_out') - outputProductAPI.saveProduct( - experimentDir / 'ptychodus' / f'{productName}.h5', fileType='HDF5' + output_product_api = input_product_api.reconstruct_local() + output_product_api.save_product( + experiment_dir / 'ptychodus' / f'{product_name}.h5', file_type='HDF5' ) -def registerPlugins(registry: PluginRegistry) -> None: - registry.fileBasedWorkflows.registerPlugin( +def register_plugins(registry: PluginRegistry) -> None: + registry.file_based_workflows.register_plugin( PtychodusAutoloadProductFileBasedWorkflow(), - simpleName='Autoload_Product', - displayName='Autoload Product', + simple_name='Autoload_Product', + display_name='Autoload Product', ) - registry.fileBasedWorkflows.registerPlugin( + registry.file_based_workflows.register_plugin( APS2IDFileBasedWorkflow(), - simpleName='APS_2ID', - displayName='APS 2-ID', + simple_name='APS_2ID', + display_name='APS 2-ID', ) - registry.fileBasedWorkflows.registerPlugin( + registry.file_based_workflows.register_plugin( APS26IDFileBasedWorkflow(), - simpleName='APS_26ID', - displayName='APS 26-ID', + simple_name='APS_26ID', + display_name='APS 26-ID', ) - registry.fileBasedWorkflows.registerPlugin( + registry.file_based_workflows.register_plugin( APS31IDEFileBasedWorkflow(), - simpleName='APS_31IDE', - displayName='APS 31-ID-E', + simple_name='APS_31IDE', + display_name='APS 31-ID-E', ) diff --git a/src/ptychodus/plugins/xrfMapsFile.py b/src/ptychodus/plugins/xrf_maps_file.py similarity index 64% rename from src/ptychodus/plugins/xrfMapsFile.py rename to src/ptychodus/plugins/xrf_maps_file.py index 8234cd00..f78ce7f0 100644 --- a/src/ptychodus/plugins/xrfMapsFile.py +++ b/src/ptychodus/plugins/xrf_maps_file.py @@ -23,25 +23,25 @@ def _split_path(data_path: str) -> tuple[str, str]: parts = data_path.split('/') return '/'.join(parts[:-1]), parts[-1] - def read(cls, filePath: Path) -> FluorescenceDataset: + def read(self, file_path: Path) -> FluorescenceDataset: element_maps: list[ElementMap] = list() counts_per_second_path = str() channel_names_path = str() - with h5py.File(filePath, 'r') as h5file: + with h5py.File(file_path, 'r') as h5_file: # try to see if v10 layout, Non Negative Lease squares fitting tech was used - h5_counts_per_second = h5file['/MAPS/XRF_Analyzed/NNLS/Counts_Per_Sec'] - h5_channel_names = h5file['/MAPS/XRF_Analyzed/NNLS/Channel_Names'] + h5_counts_per_second = h5_file['/MAPS/XRF_Analyzed/NNLS/Counts_Per_Sec'] + h5_channel_names = h5_file['/MAPS/XRF_Analyzed/NNLS/Channel_Names'] if h5_counts_per_second is None: # try to see if v10 layout, iterative matrix fitting tech was used - h5_counts_per_second = h5file['/MAPS/XRF_Analyzed/Fitted/Counts_Per_Sec'] - h5_channel_names = h5file['/MAPS/XRF_Analyzed/Fitted/Channel_Names'] + h5_counts_per_second = h5_file['/MAPS/XRF_Analyzed/Fitted/Counts_Per_Sec'] + h5_channel_names = h5_file['/MAPS/XRF_Analyzed/Fitted/Channel_Names'] if h5_counts_per_second is None: # try to see if was saved in v9 layout - h5_counts_per_second = h5file['/MAPS/XRF_fits'] - h5_channel_names = h5file['/MAPS/channel_names'] + h5_counts_per_second = h5_file['/MAPS/XRF_fits'] + h5_channel_names = h5_file['/MAPS/channel_names'] if h5_counts_per_second is not None: # Counts_Per_Sec is an N x H x W @@ -64,7 +64,7 @@ def read(cls, filePath: Path) -> FluorescenceDataset: channel_names_path=channel_names_path, ) - def write(self, filePath: Path, dataset: FluorescenceDataset) -> None: + def write(self, file_path: Path, dataset: FluorescenceDataset) -> None: channel_names: list[str] = list() counts_per_sec: list[RealArrayType] = list() @@ -75,34 +75,34 @@ def write(self, filePath: Path, dataset: FluorescenceDataset) -> None: cps_group_path, cps_dataset_name = self._split_path(dataset.counts_per_second_path) ch_group_path, ch_dataset_name = self._split_path(dataset.channel_names_path) - with h5py.File(filePath, 'w') as h5file: - cps_group = h5file.require_group(cps_group_path) + with h5py.File(file_path, 'w') as h5_file: + cps_group = h5_file.require_group(cps_group_path) cps_group.create_dataset(cps_dataset_name, data=numpy.stack(counts_per_sec)) - ch_group = h5file.require_group(ch_group_path) + ch_group = h5_file.require_group(ch_group_path) ch_group.create_dataset(ch_dataset_name, data=channel_names, dtype='S256') class NPZFluorescenceFileWriter(FluorescenceFileWriter): - def write(self, filePath: Path, dataset: FluorescenceDataset) -> None: + def write(self, file_path: Path, dataset: FluorescenceDataset) -> None: element_maps = {emap.name: emap.counts_per_second for emap in dataset.element_maps} - numpy.savez(filePath, **element_maps) + numpy.savez(file_path, **element_maps) -def registerPlugins(registry: PluginRegistry) -> None: - xrfMapsFileIO = XRFMapsFileIO() +def register_plugins(registry: PluginRegistry) -> None: + xrf_maps_file_io = XRFMapsFileIO() - registry.fluorescenceFileReaders.registerPlugin( - xrfMapsFileIO, - simpleName=XRFMapsFileIO.SIMPLE_NAME, - displayName=XRFMapsFileIO.DISPLAY_NAME, + registry.fluorescence_file_readers.register_plugin( + xrf_maps_file_io, + simple_name=XRFMapsFileIO.SIMPLE_NAME, + display_name=XRFMapsFileIO.DISPLAY_NAME, ) - registry.fluorescenceFileWriters.registerPlugin( - xrfMapsFileIO, - simpleName=XRFMapsFileIO.SIMPLE_NAME, - displayName=XRFMapsFileIO.DISPLAY_NAME, + registry.fluorescence_file_writers.register_plugin( + xrf_maps_file_io, + simple_name=XRFMapsFileIO.SIMPLE_NAME, + display_name=XRFMapsFileIO.DISPLAY_NAME, ) - registry.fluorescenceFileWriters.registerPlugin( + registry.fluorescence_file_writers.register_plugin( NPZFluorescenceFileWriter(), - simpleName='NPZ', - displayName='NumPy Zipped Archive (*.npz)', + simple_name='NPZ', + display_name='NumPy Zipped Archive (*.npz)', ) diff --git a/src/ptychodus/ptychodusAdImageProcessor.py b/src/ptychodus/ptychodusAdImageProcessor.py deleted file mode 100644 index a9b9c012..00000000 --- a/src/ptychodus/ptychodusAdImageProcessor.py +++ /dev/null @@ -1,188 +0,0 @@ -from pathlib import Path -from typing import Any -import logging -import threading -import time - -import numpy - -from pvapy.hpc.adImageProcessor import AdImageProcessor -from pvapy.utility.floatWithUnits import FloatWithUnits -from pvapy.utility.timeUtility import TimeUtility -import pvaccess -import pvapy - -from ptychodus.model import ModelCore -import ptychodus - - -class ReconstructionThread(threading.Thread): - def __init__( - self, - ptychodus: ModelCore, - inputProductPath: Path, - outputProductPath: Path, - reconstructPV: str, - ) -> None: - super().__init__() - self._ptychodus = ptychodus - self._inputProductPath = inputProductPath - self._outputProductPath = outputProductPath - self._channel = pvapy.Channel(reconstructPV, pvapy.CA) - self._reconstructEvent = threading.Event() - self._stopEvent = threading.Event() - - self._channel.subscribe('reconstructor', self._monitor) - self._channel.startMonitor() - - def run(self) -> None: - while not self._stopEvent.is_set(): - if self._reconstructEvent.wait(timeout=1.0): - logging.debug('ReconstructionThread: Begin assembling scan positions') - self._ptychodus.finalizeStreamingWorkflow() - logging.debug('ReconstructionThread: End assembling scan positions') - self._ptychodus.batchModeExecute( - 'reconstruct', self._inputProductPath, self._outputProductPath - ) - self._reconstructEvent.clear() - # reconstruction done; indicate that results are ready - self._channel.put(0) - - def _monitor(self, pvObject: pvaccess.PvObject) -> None: - # NOTE caput bdpgp:gp:bit3 1 - logging.debug(f'ReconstructionThread::monitor {pvObject}') - - if pvObject['value']['index'] == 1: - logging.debug('ReconstructionThread: Reconstruct PV triggered!') - # start reconstructing - self._reconstructEvent.set() - else: - logging.debug('ReconstructionThread: Reconstruct PV not triggered!') - - def stop(self) -> None: - self._stopEvent.set() - - -class PtychodusAdImageProcessor(AdImageProcessor): - def __init__(self, configDict: dict[str, Any] = {}) -> None: - super().__init__(configDict) - - self.logger.debug(f'{ptychodus.__name__.title()} ({ptychodus.__version__})') - - settingsFile = configDict.get('settingsFile') - self._ptychodus = ModelCore(settingsFile) - self._reconstructionThread = ReconstructionThread( - self._ptychodus, - Path(configDict.get('inputProductPath', 'input.npz')), - Path(configDict.get('outputProductPath', 'output.npz')), - configDict.get('reconstructPV', 'bdpgp:gp:bit3'), - ) - self._posXPV = configDict.get('posXPV', 'bluesky:pos_x') - self._posYPV = configDict.get('posYPV', 'bluesky:pos_y') - self._nFramesProcessed = 0 - self._processingTime = 0.0 - - def start(self) -> None: - """Called at startup""" - self._ptychodus.__enter__() - self._reconstructionThread.start() - - def stop(self) -> None: - """Called at shutdown""" - self._reconstructionThread.stop() - self._reconstructionThread.join() - self._ptychodus.__exit__(None, None, None) - - def configure(self, configDict: dict[str, Any]) -> None: - """Configures user processor""" - numberOfPatternsTotal = configDict['nPatternsTotal'] - numberOfPatternsPerArray = configDict.get('nPatternsPerArray', 1) - patternDataType = configDict.get('PatternDataType', 'uint16') - - metadata = ptychodus.api.patterns.DiffractionMetadata( - numberOfPatternsPerArray=int(numberOfPatternsPerArray), - numberOfPatternsTotal=int(numberOfPatternsTotal), - patternDataType=numpy.dtype(patternDataType), - ) - self._ptychodus.initializeStreamingWorkflow(metadata) - - def process(self, pvObject: pvaccess.PvObject) -> pvaccess.PvObject: - """Processes monitor update""" - processingBeginTime = time.time() - - (frameId, image, nx, ny, nz, colorMode, fieldKey) = self.reshapeNtNdArray(pvObject) - frameTimeStamp = TimeUtility.getTimeStampAsFloat(pvObject['timeStamp']) - - if nx is None: - self.logger.debug(f'Frame id {frameId} contains an empty image.') - else: - self.logger.debug(f'Frame id {frameId} time stamp {frameTimeStamp}') - image3d = image[numpy.newaxis, :, :].copy() - array = ptychodus.api.patterns.SimpleDiffractionPatternArray( - label=f'Frame{frameId}', - index=frameId, - data=image3d, - state=ptychodus.api.patterns.DiffractionPatternState.LOADED, - ) - self._ptychodus.assembleDiffractionPattern(array, frameTimeStamp) - - posXQueue = self.metadataQueueMap[self._posXPV] - - while True: - try: - posX = posXQueue.get(0) - except pvaccess.QueueEmpty: - break - else: - self._ptychodus.assembleScanPositionsX( - posX['values'], - [TimeUtility.getTimeStampAsFloat(ts) for ts in posX['t']], - ) - - posYQueue = self.metadataQueueMap[self._posYPV] - - while True: - try: - posY = posYQueue.get(0) - except pvaccess.QueueEmpty: - break - else: - self._ptychodus.assembleScanPositionsY( - posY['values'], - [TimeUtility.getTimeStampAsFloat(ts) for ts in posY['t']], - ) - - processingEndTime = time.time() - self.processingTime += processingEndTime - processingBeginTime - self.nFramesProcessed += 1 - - return pvObject - - def resetStats(self) -> None: - """Resets statistics for user processor""" - self.nFramesProcessed = 0 - self.processingTime = 0.0 - - def getStats(self) -> dict[str, Any]: - """Retrieves statistics for user processor""" - nFramesQueued = self._ptychodus.getDiffractionPatternAssemblyQueueSize() - processedFrameRate = 0.0 - - if self.processingTime > 0.0: - processedFrameRate = self.nFramesProcessed / self.processingTime - - return { - 'nFramesProcessed': self.nFramesProcessed, - 'nFramesQueued': nFramesQueued, - 'processingTime': FloatWithUnits(self.processingTime, 's'), - 'processedFrameRate': FloatWithUnits(processedFrameRate, 'fps'), - } - - def getStatsPvaTypes(self) -> dict[str, pvaccess.ScalarType]: - """Defines PVA types for different stats variables""" - return { - 'nFramesProcessed': pvaccess.UINT, - 'nFramesQueued': pvaccess.UINT, - 'processingTime': pvaccess.DOUBLE, - 'processingFrameRate': pvaccess.DOUBLE, - } diff --git a/src/ptychodus/ptychodusFuncXPolarisConfig.py b/src/ptychodus/ptychodusFuncXPolarisConfig.py deleted file mode 100644 index 6e307a2d..00000000 --- a/src/ptychodus/ptychodusFuncXPolarisConfig.py +++ /dev/null @@ -1,42 +0,0 @@ -from parsl.addresses import address_by_interface -from parsl.launchers import SingleNodeLauncher -from parsl.providers import PBSProProvider - -from funcx_endpoint.endpoint.utils.config import Config -from funcx_endpoint.executors import HighThroughputExecutor -from funcx_endpoint.strategies import SimpleStrategy - -# PLEASE UPDATE user_opts BEFORE USE -user_opts = { - 'polaris': { - # Node setup: activate necessary conda environment and such. - 'worker_init': 'source ~/miniconda3/etc/profile.d/conda.sh; conda activate ptychodus', - 'scheduler_options': '#PBS -l filesystems=home:grand:eagle\n#PBS -k doe', - # ALCF allocation to use - 'account': 'APSDataAnalysis', - } -} - -config = Config( - executors=[ - HighThroughputExecutor( - max_workers_per_node=1, - strategy=SimpleStrategy(max_idletime=300), - address=address_by_interface('bond0'), - provider=PBSProProvider( - launcher=SingleNodeLauncher(), - account=user_opts['polaris']['account'], - queue='preemptable', - cpus_per_node=32, - select_options='ngpus=4', - worker_init=user_opts['polaris']['worker_init'], - scheduler_options=user_opts['polaris']['scheduler_options'], - walltime='01:00:00', - nodes_per_block=1, - init_blocks=0, - min_blocks=0, - max_blocks=2, - ), - ) - ], -) diff --git a/src/ptychodus/ptychodus_bdp.py b/src/ptychodus/ptychodus_bdp.py index b08eb0d5..79bee302 100755 --- a/src/ptychodus/ptychodus_bdp.py +++ b/src/ptychodus/ptychodus_bdp.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -def versionString() -> str: +def version_string() -> str: return f'{ptychodus.__name__.title()} ({ptychodus.__version__})' @@ -32,9 +32,9 @@ def __call__(self, string: str) -> Path: def main() -> int: - changePathPrefix: PathPrefixChange | None = None - cropCenter: CropCenter | None = None - cropExtent: ImageExtent | None = None + change_path_prefix: PathPrefixChange | None = None + crop_center: CropCenter | None = None + crop_extent: ImageExtent | None = None prog = Path(__file__).stem.lower() parser = argparse.ArgumentParser( @@ -125,9 +125,9 @@ def main() -> int: type=float, ) parser.add_argument( - '--probe_photon_flux_Hz', - metavar='FLUX', - help='probe number of photons per second', + '--probe_photon_count', + metavar='NUMBER', + help='probe number of photons', type=float, ) parser.add_argument( @@ -164,31 +164,31 @@ def main() -> int: '-v', '--version', action='version', - version=versionString(), + version=version_string(), ) args = parser.parse_args() if args.local_path_prefix is not None and args.remote_path_prefix is not None: - changePathPrefix = PathPrefixChange( - findPathPrefix=args.local_path_prefix, - replacementPathPrefix=args.remote_path_prefix, + change_path_prefix = PathPrefixChange( + find_path_prefix=args.local_path_prefix, + replacement_path_prefix=args.remote_path_prefix, ) elif bool(args.local_path_prefix) ^ bool(args.remote_path_prefix): - parser.error('--local_path_prefix and --remote_path_prefix' 'must be given together.') + parser.error('--local_path_prefix and --remote_path_prefix must be given together.') if args.crop_center_x_px is not None and args.crop_center_y_px is not None: - cropCenter = CropCenter( - positionXInPixels=args.crop_center_x_px, - positionYInPixels=args.crop_center_y_px, + crop_center = CropCenter( + position_x_px=args.crop_center_x_px, + position_y_px=args.crop_center_y_px, ) elif bool(args.crop_center_x_px) ^ bool(args.crop_center_y_px): parser.error('--crop_center_x_px and --crop_center_y_px must be given together.') if args.crop_width_px is not None and args.crop_height_px is not None: - cropExtent = ImageExtent( - widthInPixels=args.crop_width_px, - heightInPixels=args.crop_height_px, + crop_extent = ImageExtent( + width_px=args.crop_width_px, + height_px=args.crop_height_px, ) elif bool(args.crop_width_px) ^ bool(args.crop_height_px): parser.error('--crop_width_px and --crop_height_px must be given together.') @@ -199,30 +199,30 @@ def main() -> int: if args.number_of_gpus is not None: logger.warning('Number of GPUs is not implemented yet!') # TODO - with ModelCore(Path(args.settings.name), isDeveloperModeEnabled=args.dev) as model: - model.workflowAPI.openPatterns( + with ModelCore(Path(args.settings.name), is_developer_mode_enabled=args.dev) as model: + model.workflow_api.open_patterns( Path(args.patterns_file_path.name), - cropCenter=cropCenter, - cropExtent=cropExtent, + crop_center=crop_center, + crop_extent=crop_extent, ) - workflowProductAPI = model.workflowAPI.createProduct( + workflow_product_api = model.workflow_api.create_product( name=args.name, comments=args.comment, - detectorDistanceInMeters=args.detector_distance_m, - probeEnergyInElectronVolts=args.probe_energy_eV, - probePhotonsPerSecond=args.probe_photon_flux_Hz, - exposureTimeInSeconds=args.exposure_time_s, + detector_distance_m=args.detector_distance_m, + probe_energy_eV=args.probe_energy_eV, + probe_photon_count=args.probe_photon_count, + exposure_time_s=args.exposure_time_s, ) - workflowProductAPI.openScan(Path(args.scan_file_path.name)) - workflowProductAPI.buildProbe() - workflowProductAPI.buildObject() - - stagingDir = args.output_directory - stagingDir.mkdir(parents=True, exist_ok=True) - model.workflowAPI.saveSettings(stagingDir / 'settings.ini', changePathPrefix) - model.workflowAPI.exportProcessedPatterns(stagingDir / 'patterns.npz') - workflowProductAPI.saveProduct(stagingDir / 'product-in.npz', fileType='NPZ') + workflow_product_api.open_scan(Path(args.scan_file_path.name)) + workflow_product_api.build_probe() + workflow_product_api.build_object() + + staging_dir = args.output_directory + staging_dir.mkdir(parents=True, exist_ok=True) + model.workflow_api.save_settings(staging_dir / 'settings.ini', change_path_prefix) + model.workflow_api.export_assembled_patterns(staging_dir / 'patterns.npz') + workflow_product_api.save_product(staging_dir / 'product-in.npz', file_type='NPZ') return 0 diff --git a/src/ptychodus/ptychodus_stream_processor.py b/src/ptychodus/ptychodus_stream_processor.py new file mode 100644 index 00000000..76a02adc --- /dev/null +++ b/src/ptychodus/ptychodus_stream_processor.py @@ -0,0 +1,193 @@ +from pathlib import Path +from typing import Any +import logging +import threading +import time + +import numpy + +from pvapy.hpc.adImageProcessor import AdImageProcessor +from pvapy.utility.floatWithUnits import FloatWithUnits +from pvapy.utility.timeUtility import TimeUtility +import pvaccess +import pvapy + +from ptychodus.model import ModelCore +import ptychodus + + +class ReconstructionThread(threading.Thread): + def __init__( + self, + ptychodus: ModelCore, + input_product_path: Path, + output_product_path: Path, + reconstruct_pv: str, + ) -> None: + super().__init__() + self._ptychodus = ptychodus + self._input_product_path = input_product_path + self._output_product_path = output_product_path + self._channel = pvapy.Channel(reconstruct_pv, pvapy.CA) + self._reconstruct_event = threading.Event() + self._stop_event = threading.Event() + + self._ptychodus_streaming_context = None + + self._channel.subscribe('reconstructor', self._monitor) + self._channel.startMonitor() + + def run(self) -> None: + while not self._stop_event.is_set(): + if self._reconstruct_event.wait(timeout=1.0): + logging.debug('ReconstructionThread: Begin assembling scan positions') + + if self._ptychodus_streaming_context is not None: + self._ptychodus_streaming_context.stop() + + logging.debug('ReconstructionThread: End assembling scan positions') + self._ptychodus.batch_mode_execute( + 'reconstruct', self._input_product_path, self._output_product_path + ) + self._reconstruct_event.clear() + # reconstruction done; indicate that results are ready + self._channel.put(0) + + def _monitor(self, pv_object: pvaccess.PvObject) -> None: + # NOTE caput bdpgp:gp:bit3 1 + logging.debug(f'ReconstructionThread::monitor {pv_object}') + + if pv_object['value']['index'] == 1: + logging.debug('ReconstructionThread: Reconstruct PV triggered!') + # start reconstructing + self._reconstruct_event.set() + else: + logging.debug('ReconstructionThread: Reconstruct PV not triggered!') + + def stop(self) -> None: + self._stop_event.set() + + +class PtychodusAdImageProcessor(AdImageProcessor): + def __init__(self, config_dict: dict[str, Any] = {}) -> None: + super().__init__(config_dict) + + self.logger.debug(f'{ptychodus.__name__.title()} ({ptychodus.__version__})') + + settings_file = config_dict.get('settingsFile') + self._ptychodus = ModelCore(settings_file) + self._reconstruction_thread = ReconstructionThread( + self._ptychodus, + Path(config_dict.get('inputProductPath', 'input.npz')), + Path(config_dict.get('outputProductPath', 'output.npz')), + config_dict.get('reconstructPV', 'bdpgp:gp:bit3'), + ) + self._pos_x_pv = config_dict.get('pos_x_pv', 'bluesky:pos_x') + self._pos_y_pv = config_dict.get('pos_y_pv', 'bluesky:pos_y') + self._num_frames_processed = 0 + self._processing_time = 0.0 + + def start(self) -> None: + """Called at startup""" + self._ptychodus.__enter__() + self._reconstruction_thread.start() + + def stop(self) -> None: + """Called at shutdown""" + self._reconstruction_thread.stop() + self._reconstruction_thread.join() + self._ptychodus.__exit__(None, None, None) + + def configure(self, config_dict: dict[str, Any]) -> None: + """Configures user processor""" + num_patterns_total = config_dict['nPatternsTotal'] + num_patterns_per_array = config_dict.get('nPatternsPerArray', 1) + pattern_dtype = config_dict.get('PatternDataType', 'uint16') + + metadata = ptychodus.api.patterns.DiffractionMetadata( + num_patterns_per_array=int(num_patterns_per_array), + num_patterns_total=int(num_patterns_total), + pattern_dtype=numpy.dtype(pattern_dtype), + ) + self._ptychodus_streaming_context = self._ptychodus.create_streaming_context(metadata) + self._ptychodus_streaming_context.start() # TODO clean up + + def process(self, pv_object: pvaccess.PvObject) -> pvaccess.PvObject: + """Processes monitor update""" + processing_begin_time = time.time() + + (frame_id, image, nx, ny, nz, color_mode, field_key) = self.reshapeNtNdArray(pv_object) + frame_time_stamp = TimeUtility.getTimeStampAsFloat(pv_object['timeStamp']) + + if nx is None: + self.logger.debug(f'Frame id {frame_id} contains an empty image.') + else: + self.logger.debug(f'Frame id {frame_id} time stamp {frame_time_stamp}') + image3d = image[numpy.newaxis, :, :].copy() + array = ptychodus.api.patterns.SimpleDiffractionPatternArray( + label=f'Frame{frame_id}', + indexes=numpy.array([frame_id]), + data=image3d, + ) + self._ptychodus_streaming_context.append_array(array) + + pos_x_queue = self.metadataQueueMap[self._pos_x_pv] + + while True: + try: + pos_x = pos_x_queue.get(0) + except pvaccess.QueueEmpty: + break + else: + self._ptychodus_streaming_context.append_positions_x( + pos_x['values'], + [TimeUtility.getTimeStampAsFloat(ts) for ts in pos_x['t']], + ) + + pos_y_queue = self.metadataQueueMap[self._pos_y_pv] + + while True: + try: + pos_y = pos_y_queue.get(0) + except pvaccess.QueueEmpty: + break + else: + self._ptychodus_streaming_context.append_positions_y( + pos_y['values'], + [TimeUtility.getTimeStampAsFloat(ts) for ts in pos_y['t']], + ) + + processing_end_time = time.time() + self._processing_time += processing_end_time - processing_begin_time + self._num_frames_processed += 1 + + return pv_object + + def resetStats(self) -> None: # noqa: N802 + """Resets statistics for user processor""" + self._num_frames_processed = 0 + self._processing_time = 0.0 + + def getStats(self) -> dict[str, Any]: # noqa: N802 + """Retrieves statistics for user processor""" + num_frames_queued = self._ptychodus_streaming_context.get_queue_size() + processed_frame_rate = 0.0 + + if self._processing_time > 0.0: + processed_frame_rate = self._num_frames_processed / self._processing_time + + return { + 'num_frames_processed': self._num_frames_processed, + 'num_frames_queued': num_frames_queued, + 'processing_time': FloatWithUnits(self._processing_time, 's'), + 'processed_frame_rate': FloatWithUnits(processed_frame_rate, 'fps'), + } + + def getStatsPvaTypes(self) -> dict[str, pvaccess.ScalarType]: # noqa: N802 + """Defines PVA types for different stats variables""" + return { + 'num_frames_processed': pvaccess.UINT, + 'num_frames_queued': pvaccess.UINT, + 'processing_time': pvaccess.DOUBLE, + 'processing_frame_rate': pvaccess.DOUBLE, + } diff --git a/src/ptychodus/view/agent.py b/src/ptychodus/view/agent.py new file mode 100644 index 00000000..3ec0ee54 --- /dev/null +++ b/src/ptychodus/view/agent.py @@ -0,0 +1,49 @@ +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QIcon +from PyQt5.QtWidgets import ( + QFrame, + QHBoxLayout, + QListView, + QPushButton, + QScrollArea, + QSizePolicy, + QSplitter, + QPlainTextEdit, + QWidget, +) + + +class AgentView(QWidget): + pass + + +class AgentInputView(QFrame): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.text_edit = QPlainTextEdit() + + send_button_size_policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) + self.send_button = QPushButton(QIcon(':/icons/send'), 'Send') + self.send_button.setSizePolicy(send_button_size_policy) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.text_edit) + layout.addWidget(self.send_button) + self.setLayout(layout) + + +class AgentChatView(QSplitter): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(Qt.Orientation.Vertical, parent) + self.message_list_view = QListView() + self.scroll_area = QScrollArea() + self.scroll_area.setWidgetResizable(True) + self.scroll_area.setWidget(self.message_list_view) + self.input_view = AgentInputView() + + self.addWidget(self.scroll_area) + self.addWidget(self.input_view) + + self.setStretchFactor(0, 2) + self.setStretchFactor(1, 0) diff --git a/src/ptychodus/view/automation.py b/src/ptychodus/view/automation.py index 32c78c65..6ba93683 100644 --- a/src/ptychodus/view/automation.py +++ b/src/ptychodus/view/automation.py @@ -20,26 +20,26 @@ class AutomationProcessingView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Processing', parent) - self.strategyLabel = QLabel('Strategy:') - self.strategyComboBox = QComboBox() - self.directoryLabel = QLabel('Directory:') - self.directoryLineEdit = QLineEdit() - self.directoryBrowseButton = QPushButton('Browse') - self.intervalLabel = QLabel('Interval [sec]:') - self.intervalSpinBox = QSpinBox() + self.strategy_label = QLabel('Strategy:') + self.strategy_combo_box = QComboBox() + self.directory_label = QLabel('Directory:') + self.directory_line_edit = QLineEdit() + self.directory_browse_button = QPushButton('Browse') + self.interval_label = QLabel('Interval [sec]:') + self.interval_spin_box = QSpinBox() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> AutomationProcessingView: + def create_instance(cls, parent: QWidget | None = None) -> AutomationProcessingView: view = cls(parent) layout = QGridLayout() - layout.addWidget(view.strategyLabel, 0, 0) - layout.addWidget(view.strategyComboBox, 0, 1, 1, 2) - layout.addWidget(view.directoryLabel, 1, 0) - layout.addWidget(view.directoryLineEdit, 1, 1) - layout.addWidget(view.directoryBrowseButton, 1, 2) - layout.addWidget(view.intervalLabel, 2, 0) - layout.addWidget(view.intervalSpinBox, 2, 1, 1, 2) + layout.addWidget(view.strategy_label, 0, 0) + layout.addWidget(view.strategy_combo_box, 0, 1, 1, 2) + layout.addWidget(view.directory_label, 1, 0) + layout.addWidget(view.directory_line_edit, 1, 1) + layout.addWidget(view.directory_browse_button, 1, 2) + layout.addWidget(view.interval_label, 2, 0) + layout.addWidget(view.interval_spin_box, 2, 1, 1, 2) layout.setColumnStretch(1, 1) view.setLayout(layout) @@ -49,16 +49,16 @@ def createInstance(cls, parent: QWidget | None = None) -> AutomationProcessingVi class AutomationWatchdogView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Watchdog', parent) - self.delaySpinBox = QSpinBox() - self.usePollingObserverCheckBox = QCheckBox('Use Polling Observer') + self.delay_spin_box = QSpinBox() + self.use_polling_observer_check_box = QCheckBox('Use Polling Observer') @classmethod - def createInstance(cls, parent: QWidget | None = None) -> AutomationWatchdogView: + def create_instance(cls, parent: QWidget | None = None) -> AutomationWatchdogView: view = cls(parent) layout = QFormLayout() - layout.addRow('Delay [sec]:', view.delaySpinBox) - layout.addRow(view.usePollingObserverCheckBox) + layout.addRow('Delay [sec]:', view.delay_spin_box) + layout.addRow(view.use_polling_observer_check_box) view.setLayout(layout) return view @@ -67,29 +67,29 @@ def createInstance(cls, parent: QWidget | None = None) -> AutomationWatchdogView class AutomationView(QWidget): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) - self.processingView = AutomationProcessingView.createInstance() - self.watchdogView = AutomationWatchdogView.createInstance() - self.processingListView = QListView() - self.loadButton = QPushButton('Load') - self.watchButton = QPushButton('Watch') - self.processButton = QPushButton('Process') - self.clearButton = QPushButton('Clear') + self.processing_view = AutomationProcessingView.create_instance() + self.watchdog_view = AutomationWatchdogView.create_instance() + self.processing_list_view = QListView() + self.load_button = QPushButton('Load') + self.watch_button = QPushButton('Watch') + self.process_button = QPushButton('Process') + self.clear_button = QPushButton('Clear') @classmethod - def createInstance(cls, parent: QWidget | None = None) -> AutomationView: + def create_instance(cls, parent: QWidget | None = None) -> AutomationView: view = cls(parent) - buttonLayout = QHBoxLayout() - buttonLayout.addWidget(view.loadButton) - buttonLayout.addWidget(view.watchButton) - buttonLayout.addWidget(view.processButton) - buttonLayout.addWidget(view.clearButton) + button_layout = QHBoxLayout() + button_layout.addWidget(view.load_button) + button_layout.addWidget(view.watch_button) + button_layout.addWidget(view.process_button) + button_layout.addWidget(view.clear_button) layout = QVBoxLayout() - layout.addWidget(view.processingView) - layout.addWidget(view.watchdogView) - layout.addWidget(view.processingListView) - layout.addLayout(buttonLayout) + layout.addWidget(view.processing_view) + layout.addWidget(view.watchdog_view) + layout.addWidget(view.processing_list_view) + layout.addLayout(button_layout) view.setLayout(layout) return view diff --git a/src/ptychodus/view/core.py b/src/ptychodus/view/core.py index 3a57fb33..8208760a 100644 --- a/src/ptychodus/view/core.py +++ b/src/ptychodus/view/core.py @@ -6,8 +6,8 @@ from PyQt5.QtWidgets import ( QActionGroup, QApplication, + QLCDNumber, QMainWindow, - QProgressBar, QSizePolicy, QSplitter, QStackedWidget, @@ -17,11 +17,12 @@ ) from . import resources # noqa +from .agent import AgentView, AgentChatView from .automation import AutomationView from .image import ImageView from .patterns import PatternsView from .product import ProductView -from .reconstructor import ReconstructorParametersView, ReconstructorPlotView +from .reconstructor import ReconstructorView, ReconstructorPlotView from .repository import RepositoryTableView, RepositoryTreeView from .scan import ScanPlotView from .settings import SettingsView @@ -31,125 +32,122 @@ class ViewCore(QMainWindow): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.navigationToolBar = QToolBar() - self.navigationActionGroup = QActionGroup(self.navigationToolBar) + logger.info(f'PyQt {PYQT_VERSION_STR}') + logger.info(f'Qt {QT_VERSION_STR}') + + self.navigation_tool_bar = QToolBar() + self.navigation_action_group = QActionGroup(self.navigation_tool_bar) self.splitter = QSplitter(Qt.Orientation.Horizontal) - self.parametersWidget = QStackedWidget() - self.contentsWidget = QStackedWidget() - self.memoryProgressBar = QProgressBar() + self.left_panel = QStackedWidget() + self.right_panel = QStackedWidget() + self.memory_widget = QLCDNumber() - self.settingsAction = self.navigationToolBar.addAction( + self.settings_action = self.navigation_tool_bar.addAction( QIcon(':/icons/settings'), 'Settings' ) - self.settingsView = SettingsView.createInstance() - self.settingsTableView = QTableView() + self.settings_view = SettingsView.create_instance() + self.settings_table_view = QTableView() - self.patternsAction = self.navigationToolBar.addAction( + self.patterns_action = self.navigation_tool_bar.addAction( QIcon(':/icons/patterns'), 'Patterns' ) - self.patternsView = PatternsView.createInstance() - self.patternsImageView = ImageView.createInstance() + self.patterns_view = PatternsView() + self.patterns_image_view = ImageView.create_instance() - self.productAction = self.navigationToolBar.addAction(QIcon(':/icons/products'), 'Products') - self.productView = ProductView() - self.productDiagramView = QWidget() + self.product_action = self.navigation_tool_bar.addAction( + QIcon(':/icons/products'), 'Products' + ) + self.product_view = ProductView() + self.product_diagram_view = QWidget() - self.scanAction = self.navigationToolBar.addAction(QIcon(':/icons/scan'), 'Scan') - self.scanView = RepositoryTableView() - self.scanPlotView = ScanPlotView.createInstance() + self.scan_action = self.navigation_tool_bar.addAction(QIcon(':/icons/scan'), 'Positions') + self.scan_view = RepositoryTableView() + self.scan_plot_view = ScanPlotView.create_instance() - self.probeAction = self.navigationToolBar.addAction(QIcon(':/icons/probe'), 'Probe') - self.probeView = RepositoryTreeView() - self.probeImageView = ImageView.createInstance() + self.probe_action = self.navigation_tool_bar.addAction(QIcon(':/icons/probe'), 'Probe') + self.probe_view = RepositoryTreeView() + self.probe_image_view = ImageView.create_instance() - self.objectAction = self.navigationToolBar.addAction(QIcon(':/icons/object'), 'Object') - self.objectView = RepositoryTreeView() - self.objectImageView = ImageView.createInstance() + self.object_action = self.navigation_tool_bar.addAction(QIcon(':/icons/object'), 'Object') + self.object_view = RepositoryTreeView() + self.object_image_view = ImageView.create_instance() - self.reconstructorAction = self.navigationToolBar.addAction( + self.reconstructor_action = self.navigation_tool_bar.addAction( QIcon(':/icons/reconstructor'), 'Reconstructor' ) - self.reconstructorParametersView = ReconstructorParametersView.createInstance() - self.reconstructorPlotView = ReconstructorPlotView.createInstance() + self.reconstructor_view = ReconstructorView() + self.reconstructor_plot_view = ReconstructorPlotView() - self.workflowAction = self.navigationToolBar.addAction( + self.workflow_action = self.navigation_tool_bar.addAction( QIcon(':/icons/workflow'), 'Workflow' ) - self.workflowParametersView = WorkflowParametersView.createInstance() - self.workflowTableView = QTableView() + self.workflow_parameters_view = WorkflowParametersView.create_instance() + self.workflow_table_view = QTableView() - self.automationAction = self.navigationToolBar.addAction( + self.automation_action = self.navigation_tool_bar.addAction( QIcon(':/icons/automate'), 'Automation' ) - self.automationView = AutomationView.createInstance() - self.automationWidget = QWidget() + self.automation_view = AutomationView.create_instance() + self.automation_widget = QWidget() - @classmethod - def createInstance( - cls, isDeveloperModeEnabled: bool, parent: QWidget | None = None - ) -> ViewCore: - logger.info(f'PyQt {PYQT_VERSION_STR}') - logger.info(f'Qt {QT_VERSION_STR}') + self.agent_action = self.navigation_tool_bar.addAction( + QIcon(':/icons/sparkles'), + 'Agent', + ) + self.agent_view = AgentView() + self.agent_chat_view = AgentChatView() + + ##### - view = cls(parent) - view.setWindowIcon(QIcon(':/icons/ptychodus')) + self.setWindowIcon(QIcon(':/icons/ptychodus')) - view.navigationToolBar.setContextMenuPolicy(Qt.ContextMenuPolicy.PreventContextMenu) - view.addToolBar(Qt.ToolBarArea.LeftToolBarArea, view.navigationToolBar) - view.navigationToolBar.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextUnderIcon) - view.navigationToolBar.setIconSize(QSize(32, 32)) + self.navigation_tool_bar.setContextMenuPolicy(Qt.ContextMenuPolicy.PreventContextMenu) + self.navigation_tool_bar.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextUnderIcon) + self.navigation_tool_bar.setIconSize(QSize(32, 32)) + self.addToolBar(Qt.ToolBarArea.LeftToolBarArea, self.navigation_tool_bar) - for index, action in enumerate(view.navigationToolBar.actions()): + for index, action in enumerate(self.navigation_tool_bar.actions()): action.setCheckable(True) action.setData(index) - view.navigationActionGroup.addAction(action) + self.navigation_action_group.addAction(action) # maintain same order as navigationToolBar buttons - view.parametersWidget.addWidget(view.settingsView) - view.parametersWidget.addWidget(view.patternsView) - view.parametersWidget.addWidget(view.productView) - view.parametersWidget.addWidget(view.scanView) - view.parametersWidget.addWidget(view.probeView) - view.parametersWidget.addWidget(view.objectView) - view.parametersWidget.addWidget(view.reconstructorParametersView) - view.parametersWidget.addWidget(view.workflowParametersView) - view.parametersWidget.addWidget(view.automationView) - view.parametersWidget.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum) - view.splitter.addWidget(view.parametersWidget) + self.left_panel.addWidget(self.settings_view) + self.left_panel.addWidget(self.patterns_view) + self.left_panel.addWidget(self.product_view) + self.left_panel.addWidget(self.scan_view) + self.left_panel.addWidget(self.probe_view) + self.left_panel.addWidget(self.object_view) + self.left_panel.addWidget(self.reconstructor_view) + self.left_panel.addWidget(self.workflow_parameters_view) + self.left_panel.addWidget(self.automation_view) + self.left_panel.addWidget(self.agent_view) + self.left_panel.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum) + self.splitter.addWidget(self.left_panel) # maintain same order as navigationToolBar buttons - view.contentsWidget.addWidget(view.settingsTableView) - view.contentsWidget.addWidget(view.patternsImageView) - view.contentsWidget.addWidget(view.productDiagramView) - view.contentsWidget.addWidget(view.scanPlotView) - view.contentsWidget.addWidget(view.probeImageView) - view.contentsWidget.addWidget(view.objectImageView) - view.contentsWidget.addWidget(view.reconstructorPlotView) - view.contentsWidget.addWidget(view.workflowTableView) - view.contentsWidget.addWidget(view.automationWidget) - view.contentsWidget.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum) - view.splitter.addWidget(view.contentsWidget) - - view.setCentralWidget(view.splitter) - - # TODO make visible when complete - view.scanView.buttonBox.analyzeButton.setVisible(isDeveloperModeEnabled) - view.probeView.buttonBox.analyzeButton.setVisible(isDeveloperModeEnabled) - view.objectView.buttonBox.analyzeButton.setVisible(isDeveloperModeEnabled) - - desktopSize = QApplication.desktop().availableGeometry().size() - preferredHeight = desktopSize.height() * 2 // 3 - preferredWidth = min(desktopSize.width() * 2 // 3, 2 * preferredHeight) - view.resize(preferredWidth, preferredHeight) - - view.memoryProgressBar.setSizePolicy( - QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred - ) - view.statusBar().addPermanentWidget(view.memoryProgressBar) - view.statusBar().showMessage('Ready') - - return view + self.right_panel.addWidget(self.settings_table_view) + self.right_panel.addWidget(self.patterns_image_view) + self.right_panel.addWidget(self.product_diagram_view) + self.right_panel.addWidget(self.scan_plot_view) + self.right_panel.addWidget(self.probe_image_view) + self.right_panel.addWidget(self.object_image_view) + self.right_panel.addWidget(self.reconstructor_plot_view) + self.right_panel.addWidget(self.workflow_table_view) + self.right_panel.addWidget(self.automation_widget) + self.right_panel.addWidget(self.agent_chat_view) + self.right_panel.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum) + self.splitter.addWidget(self.right_panel) + + self.setCentralWidget(self.splitter) + + desktop_size = QApplication.desktop().availableGeometry().size() + preferred_height = desktop_size.height() * 2 // 3 + preferred_width = min(desktop_size.width() * 2 // 3, 2 * preferred_height) + self.resize(preferred_width, preferred_height) + + self.statusBar().addPermanentWidget(self.memory_widget) diff --git a/src/ptychodus/view/image.py b/src/ptychodus/view/image.py index c0f0c0ac..572a4bff 100644 --- a/src/ptychodus/view/image.py +++ b/src/ptychodus/view/image.py @@ -19,7 +19,7 @@ QWidget, ) -from ptychodus.api.visualization import RealArrayType +from ptychodus.api.typing import RealArrayType from .visualization import VisualizationView from .widgets import BottomTitledGroupBox, DecimalLineEdit, DecimalSlider @@ -28,23 +28,23 @@ class ImageDisplayRangeDialog(QDialog): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) - self.buttonBox = QDialogButtonBox() - self.minValueLineEdit = DecimalLineEdit.createInstance() - self.maxValueLineEdit = DecimalLineEdit.createInstance() + self.button_box = QDialogButtonBox() + self.min_value_line_edit = DecimalLineEdit.create_instance() + self.max_value_line_edit = DecimalLineEdit.create_instance() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ImageDisplayRangeDialog: + def create_instance(cls, parent: QWidget | None = None) -> ImageDisplayRangeDialog: dialog = cls(parent) dialog.setWindowTitle('Set Display Range') - dialog.buttonBox.addButton(QDialogButtonBox.StandardButton.Ok) - dialog.buttonBox.accepted.connect(dialog.accept) - dialog.buttonBox.addButton(QDialogButtonBox.StandardButton.Cancel) - dialog.buttonBox.rejected.connect(dialog.reject) + dialog.button_box.addButton(QDialogButtonBox.StandardButton.Ok) + dialog.button_box.accepted.connect(dialog.accept) + dialog.button_box.addButton(QDialogButtonBox.StandardButton.Cancel) + dialog.button_box.rejected.connect(dialog.reject) layout = QFormLayout() - layout.addRow('Minimum Displayed Value:', dialog.minValueLineEdit) - layout.addRow('Maximum Displayed Value:', dialog.maxValueLineEdit) - layout.addRow(dialog.buttonBox) + layout.addRow('Minimum Displayed Value:', dialog.min_value_line_edit) + layout.addRow('Maximum Displayed Value:', dialog.max_value_line_edit) + layout.addRow(dialog.button_box) dialog.setLayout(layout) return dialog @@ -53,48 +53,48 @@ def createInstance(cls, parent: QWidget | None = None) -> ImageDisplayRangeDialo class ImageToolsGroupBox(BottomTitledGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Tools', parent) - self.homeButton = QToolButton() - self.saveButton = QToolButton() - self.moveButton = QToolButton() - self.rulerButton = QToolButton() - self.rectangleButton = QToolButton() - self.lineCutButton = QToolButton() + self.home_button = QToolButton() + self.save_button = QToolButton() + self.move_button = QToolButton() + self.ruler_button = QToolButton() + self.rectangle_button = QToolButton() + self.line_cut_button = QToolButton() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ImageToolsGroupBox: + def create_instance(cls, parent: QWidget | None = None) -> ImageToolsGroupBox: view = cls(parent) - view.homeButton.setIcon(QIcon(':/icons/home')) - view.homeButton.setIconSize(QSize(32, 32)) - view.homeButton.setToolTip('Home') + view.home_button.setIcon(QIcon(':/icons/home')) + view.home_button.setIconSize(QSize(32, 32)) + view.home_button.setToolTip('Home') - view.saveButton.setIcon(QIcon(':/icons/save')) - view.saveButton.setIconSize(QSize(32, 32)) - view.saveButton.setToolTip('Save Image') + view.save_button.setIcon(QIcon(':/icons/save')) + view.save_button.setIconSize(QSize(32, 32)) + view.save_button.setToolTip('Save Image') - view.moveButton.setIcon(QIcon(':/icons/move')) - view.moveButton.setIconSize(QSize(32, 32)) - view.moveButton.setToolTip('Move') + view.move_button.setIcon(QIcon(':/icons/move')) + view.move_button.setIconSize(QSize(32, 32)) + view.move_button.setToolTip('Move') - view.rulerButton.setIcon(QIcon(':/icons/ruler')) - view.rulerButton.setIconSize(QSize(32, 32)) - view.rulerButton.setToolTip('Ruler') + view.ruler_button.setIcon(QIcon(':/icons/ruler')) + view.ruler_button.setIconSize(QSize(32, 32)) + view.ruler_button.setToolTip('Ruler') - view.rectangleButton.setIcon(QIcon(':/icons/rectangle')) - view.rectangleButton.setIconSize(QSize(32, 32)) - view.rectangleButton.setToolTip('Rectangle') + view.rectangle_button.setIcon(QIcon(':/icons/rectangle')) + view.rectangle_button.setIconSize(QSize(32, 32)) + view.rectangle_button.setToolTip('Rectangle') - view.lineCutButton.setIcon(QIcon(':/icons/line-cut')) - view.lineCutButton.setIconSize(QSize(32, 32)) - view.lineCutButton.setToolTip('Line-Cut Profile') + view.line_cut_button.setIcon(QIcon(':/icons/line-cut')) + view.line_cut_button.setIconSize(QSize(32, 32)) + view.line_cut_button.setToolTip('Line-Cut Profile') layout = QGridLayout() - layout.addWidget(view.homeButton, 0, 0) - layout.addWidget(view.saveButton, 0, 1) - layout.addWidget(view.moveButton, 0, 2) - layout.addWidget(view.rulerButton, 1, 0) - layout.addWidget(view.rectangleButton, 1, 1) - layout.addWidget(view.lineCutButton, 1, 2) + layout.addWidget(view.home_button, 0, 0) + layout.addWidget(view.save_button, 0, 1) + layout.addWidget(view.move_button, 0, 2) + layout.addWidget(view.ruler_button, 1, 0) + layout.addWidget(view.rectangle_button, 1, 1) + layout.addWidget(view.line_cut_button, 1, 2) view.setLayout(layout) view.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred) @@ -105,23 +105,23 @@ def createInstance(cls, parent: QWidget | None = None) -> ImageToolsGroupBox: class ImageRendererGroupBox(BottomTitledGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Colorize', parent) - self.rendererComboBox = QComboBox() - self.transformationComboBox = QComboBox() - self.variantComboBox = QComboBox() + self.renderer_combo_box = QComboBox() + self.transformation_combo_box = QComboBox() + self.variant_combo_box = QComboBox() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ImageRendererGroupBox: + def create_instance(cls, parent: QWidget | None = None) -> ImageRendererGroupBox: view = cls(parent) - view.rendererComboBox.setToolTip('Array Component') - view.transformationComboBox.setToolTip('Transformation') - view.variantComboBox.setToolTip('Variant') + view.renderer_combo_box.setToolTip('Array Component') + view.transformation_combo_box.setToolTip('Transformation') + view.variant_combo_box.setToolTip('Variant') layout = QVBoxLayout() layout.setContentsMargins(10, 10, 10, 35) - layout.addWidget(view.rendererComboBox) - layout.addWidget(view.transformationComboBox) - layout.addWidget(view.variantComboBox) + layout.addWidget(view.renderer_combo_box) + layout.addWidget(view.transformation_combo_box) + layout.addWidget(view.variant_combo_box) view.setLayout(layout) view.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred) @@ -132,33 +132,33 @@ def createInstance(cls, parent: QWidget | None = None) -> ImageRendererGroupBox: class ImageDataRangeGroupBox(BottomTitledGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Data Range', parent) - self.minDisplayValueSlider = DecimalSlider.createInstance(Qt.Orientation.Horizontal) - self.maxDisplayValueSlider = DecimalSlider.createInstance(Qt.Orientation.Horizontal) - self.autoButton = QPushButton('Auto') - self.editButton = QPushButton('Edit') - self.colorLegendButton = QPushButton('Color Legend') + self.min_display_value_slider = DecimalSlider.create_instance(Qt.Orientation.Horizontal) + self.max_display_value_slider = DecimalSlider.create_instance(Qt.Orientation.Horizontal) + self.auto_button = QPushButton('Auto') + self.edit_button = QPushButton('Edit') + self.color_legend_button = QPushButton('Color Legend') @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ImageDataRangeGroupBox: + def create_instance(cls, parent: QWidget | None = None) -> ImageDataRangeGroupBox: view = cls(parent) - view.minDisplayValueSlider.setToolTip('Minimum Display Value') - view.maxDisplayValueSlider.setToolTip('Maximum Display Value') - view.autoButton.setToolTip('Rescale to Data Range') - view.editButton.setToolTip('Rescale to Custom Range') - view.colorLegendButton.setToolTip('Toggle Color Legend Visibility') + view.min_display_value_slider.setToolTip('Minimum Display Value') + view.max_display_value_slider.setToolTip('Maximum Display Value') + view.auto_button.setToolTip('Rescale to Data Range') + view.edit_button.setToolTip('Rescale to Custom Range') + view.color_legend_button.setToolTip('Toggle Color Legend Visibility') - buttonLayout = QHBoxLayout() - buttonLayout.setContentsMargins(0, 0, 0, 0) - buttonLayout.addWidget(view.autoButton) - buttonLayout.addWidget(view.editButton) - buttonLayout.addWidget(view.colorLegendButton) + button_layout = QHBoxLayout() + button_layout.setContentsMargins(0, 0, 0, 0) + button_layout.addWidget(view.auto_button) + button_layout.addWidget(view.edit_button) + button_layout.addWidget(view.color_legend_button) layout = QFormLayout() layout.setContentsMargins(10, 10, 10, 35) - layout.addRow('Min:', view.minDisplayValueSlider) - layout.addRow('Max:', view.maxDisplayValueSlider) - layout.addRow(buttonLayout) + layout.addRow('Min:', view.min_display_value_slider) + layout.addRow('Max:', view.max_display_value_slider) + layout.addRow(button_layout) view.setLayout(layout) view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) @@ -169,19 +169,19 @@ def createInstance(cls, parent: QWidget | None = None) -> ImageDataRangeGroupBox class ImageRibbon(QWidget): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) - self.imageToolsGroupBox = ImageToolsGroupBox.createInstance() - self.colormapGroupBox = ImageRendererGroupBox.createInstance() - self.dataRangeGroupBox = ImageDataRangeGroupBox.createInstance() + self.image_tools_group_box = ImageToolsGroupBox.create_instance() + self.colormap_group_box = ImageRendererGroupBox.create_instance() + self.data_range_group_box = ImageDataRangeGroupBox.create_instance() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ImageRibbon: + def create_instance(cls, parent: QWidget | None = None) -> ImageRibbon: view = cls(parent) layout = QHBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(view.imageToolsGroupBox) - layout.addWidget(view.colormapGroupBox) - layout.addWidget(view.dataRangeGroupBox) + layout.addWidget(view.image_tools_group_box) + layout.addWidget(view.colormap_group_box) + layout.addWidget(view.data_range_group_box) view.setLayout(layout) view.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Fixed) @@ -192,118 +192,118 @@ def createInstance(cls, parent: QWidget | None = None) -> ImageRibbon: class ImageWidget(VisualizationView): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self._colorLegendMinValue = 0.0 - self._colorLegendMaxValue = 1.0 - self._colorLegendStopPoints: list[tuple[float, QColor]] = [ + self._color_legend_min_value = 0.0 + self._color_legend_max_value = 1.0 + self._color_legend_stop_points: list[tuple[float, QColor]] = [ (0.0, QColor(Qt.GlobalColor.green)), (0.5, QColor(Qt.GlobalColor.yellow)), (1.0, QColor(Qt.GlobalColor.red)), ] - self._colorLegendNumberOfTicks = 5 # TODO - self._isColorLegendVisible = False - self._isColorLegendCyclic = False + self._color_legend_num_ticks = 5 # TODO + self._is_color_legend_visible = False + self._is_color_legend_cyclic = False - def setColorLegendColors( - self, values: RealArrayType, rgbaArray: RealArrayType, isCyclic: bool + def set_color_legend_colors( + self, values: RealArrayType, rgba_array: RealArrayType, is_cyclic: bool ) -> None: - colorLegendStopPoints: list[tuple[float, QColor]] = list() - self._colorLegendMinValue = values.min() - self._colorLegendMaxValue = values.max() - - valueRange = self._colorLegendMaxValue - self._colorLegendMinValue - normalizedValues = ( - (values - self._colorLegendMinValue) / valueRange - if valueRange > 0 + color_legend_stop_points: list[tuple[float, QColor]] = list() + self._color_legend_min_value = values.min() + self._color_legend_max_value = values.max() + + value_range = self._color_legend_max_value - self._color_legend_min_value + normalized_values = ( + (values - self._color_legend_min_value) / value_range + if value_range > 0 else numpy.full_like(values, 0.5) ) - for x, rgba in zip(normalizedValues.clip(0, 1), rgbaArray): + for x, rgba in zip(normalized_values.clip(0, 1), rgba_array): color = QColor() color.setRgbF(rgba[0], rgba[1], rgba[2], rgba[3]) - colorLegendStopPoints.append((x, color)) + color_legend_stop_points.append((x, color)) - self._colorLegendStopPoints = colorLegendStopPoints - self._isColorLegendCyclic = isCyclic + self._color_legend_stop_points = color_legend_stop_points + self._is_color_legend_cyclic = is_cyclic self.scene().update() - def setColorLegendVisible(self, visible: bool) -> None: - self._isColorLegendVisible = visible + def set_color_legend_visible(self, visible: bool) -> None: + self._is_color_legend_visible = visible self.scene().update() @property - def _colorLegendTicks(self) -> Iterator[float]: - for tick in range(self._colorLegendNumberOfTicks): - a = tick / (self._colorLegendNumberOfTicks - 1) - yield (1.0 - a) * self._colorLegendMinValue + a * self._colorLegendMaxValue + def _color_legend_ticks(self) -> Iterator[float]: + for tick in range(self._color_legend_num_ticks): + a = tick / (self._color_legend_num_ticks - 1) + yield (1.0 - a) * self._color_legend_min_value + a * self._color_legend_max_value - def drawForeground(self, painter: QPainter, rect: QRectF) -> None: - if not self._isColorLegendVisible: + def drawForeground(self, painter: QPainter, rect: QRectF) -> None: # noqa: N802 + if not self._is_color_legend_visible: return - fgPainter = QPainter(self.viewport()) + fg_painter = QPainter(self.viewport()) pen = QPen() pen.setWidth(3) - fgPainter.setPen(pen) + fg_painter.setPen(pen) - fontMetrics = fgPainter.fontMetrics() - dx = fontMetrics.horizontalAdvance('m') - dy = fontMetrics.lineSpacing() + font_metrics = fg_painter.fontMetrics() + dx = font_metrics.horizontalAdvance('m') + dy = font_metrics.lineSpacing() - widgetRect = self.viewport().rect() + widget_rect = self.viewport().rect() - if self._isColorLegendCyclic: - legendDiameter = 6 * dx - legendMargin = 2 * dx + if self._is_color_legend_cyclic: + legend_diameter = 6 * dx + legend_margin = 2 * dx - legendRect = QRect(0, 0, legendDiameter, legendDiameter) - legendRect.moveRight(widgetRect.right() - legendMargin) - legendRect.moveBottom(widgetRect.height() - legendMargin) + legend_rect = QRect(0, 0, legend_diameter, legend_diameter) + legend_rect.moveRight(widget_rect.right() - legend_margin) + legend_rect.moveBottom(widget_rect.height() - legend_margin) - cgradient = QConicalGradient(legendRect.center(), 90.0) - cgradient.setStops(self._colorLegendStopPoints) - fgPainter.setBrush(cgradient) - fgPainter.drawEllipse(legendRect) + cgradient = QConicalGradient(legend_rect.center(), 90.0) + cgradient.setStops(self._color_legend_stop_points) + fg_painter.setBrush(cgradient) + fg_painter.drawEllipse(legend_rect) else: - tickLabels = [f'{tick:5g}' for tick in self._colorLegendTicks] - tickLabelWidth = max(fontMetrics.width(label) for label in tickLabels) + tick_labels = [f'{tick:5g}' for tick in self._color_legend_ticks] + tick_label_width = max(font_metrics.width(label) for label in tick_labels) - legendWidth = 2 * dx - legendHeight = (2 * len(tickLabels) - 1) * dy - legendMargin = tickLabelWidth + 2 * dx + legend_width = 2 * dx + legend_height = (2 * len(tick_labels) - 1) * dy + legend_margin = tick_label_width + 2 * dx - legendRect = QRect(0, 0, legendWidth, legendHeight) - legendRect.moveRight(widgetRect.right() - legendMargin) - legendRect.moveTop((widgetRect.height() - legendHeight) // 2) + legend_rect = QRect(0, 0, legend_width, legend_height) + legend_rect.moveRight(widget_rect.right() - legend_margin) + legend_rect.moveTop((widget_rect.height() - legend_height) // 2) - lgradient = QLinearGradient(legendRect.bottomLeft(), legendRect.topLeft()) - lgradient.setStops(self._colorLegendStopPoints) - fgPainter.setBrush(lgradient) - fgPainter.drawRect(legendRect) + lgradient = QLinearGradient(legend_rect.bottomLeft(), legend_rect.topLeft()) + lgradient.setStops(self._color_legend_stop_points) + fg_painter.setBrush(lgradient) + fg_painter.drawRect(legend_rect) - tickX0 = legendRect.right() + dx - tickY0 = legendRect.bottom() + fontMetrics.strikeOutPos() + tick_x0 = legend_rect.right() + dx + tick_y0 = legend_rect.bottom() + font_metrics.strikeOutPos() - for tickIndex, tickLabel in enumerate(tickLabels): - tickDY = (tickIndex * legendRect.height()) // (len(tickLabels) - 1) - viewportPoint = QPoint(tickX0, tickY0 - tickDY) - fgPainter.drawText(viewportPoint, tickLabel) + for tick_index, tick_label in enumerate(tick_labels): + tick_dy = (tick_index * legend_rect.height()) // (len(tick_labels) - 1) + viewport_point = QPoint(tick_x0, tick_y0 - tick_dy) + fg_painter.drawText(viewport_point, tick_label) class ImageView(QWidget): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) - self.imageRibbon = ImageRibbon.createInstance() - self.imageWidget = ImageWidget() + self.image_ribbon = ImageRibbon.create_instance() + self.image_widget = ImageWidget() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ImageView: + def create_instance(cls, parent: QWidget | None = None) -> ImageView: view = cls(parent) layout = QVBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.setMenuBar(view.imageRibbon) - layout.addWidget(view.imageWidget) + layout.setMenuBar(view.image_ribbon) + layout.addWidget(view.image_widget) view.setLayout(layout) view.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum) diff --git a/src/ptychodus/view/make_qrc.sh b/src/ptychodus/view/make_qrc.sh index ea09461f..37cf1f87 100755 --- a/src/ptychodus/view/make_qrc.sh +++ b/src/ptychodus/view/make_qrc.sh @@ -1,6 +1,6 @@ #!/bin/sh -wget https://github.com/FortAwesome/Font-Awesome/archive/6.5.2.tar.gz -O font-awesome.tar.gz +wget https://github.com/FortAwesome/Font-Awesome/archive/6.7.2.tar.gz -O font-awesome.tar.gz tar xf font-awesome.tar.gz pyrcc5 resources.qrc -o resources.py rm -i font-awesome.tar.gz diff --git a/src/ptychodus/view/object.py b/src/ptychodus/view/object.py index 92e3861d..ba1d95ee 100644 --- a/src/ptychodus/view/object.py +++ b/src/ptychodus/view/object.py @@ -21,27 +21,27 @@ class FourierRingCorrelationDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.product1Label = QLabel('Product 1:') - self.product1ComboBox = QComboBox() - self.product2Label = QLabel('Product 2:') - self.product2ComboBox = QComboBox() + self.product1_label = QLabel('Product 1:') + self.product1_combo_box = QComboBox() + self.product2_label = QLabel('Product 2:') + self.product2_combo_box = QComboBox() self.figure = Figure() - self.figureCanvas = FigureCanvasQTAgg(self.figure) - self.navigationToolbar = NavigationToolbar(self.figureCanvas, self) + self.figure_canvas = FigureCanvasQTAgg(self.figure) + self.navigation_toolbar = NavigationToolbar(self.figure_canvas, self) self.axes = self.figure.add_subplot(111) - parametersLayout = QGridLayout() - parametersLayout.addWidget(self.product1Label, 0, 0) - parametersLayout.addWidget(self.product1ComboBox, 0, 1) - parametersLayout.addWidget(self.product2Label, 0, 2) - parametersLayout.addWidget(self.product2ComboBox, 0, 3) - parametersLayout.setColumnStretch(1, 1) - parametersLayout.setColumnStretch(3, 1) + parameters_layout = QGridLayout() + parameters_layout.addWidget(self.product1_label, 0, 0) + parameters_layout.addWidget(self.product1_combo_box, 0, 1) + parameters_layout.addWidget(self.product2_label, 0, 2) + parameters_layout.addWidget(self.product2_combo_box, 0, 3) + parameters_layout.setColumnStretch(1, 1) + parameters_layout.setColumnStretch(3, 1) layout = QVBoxLayout() - layout.addWidget(self.navigationToolbar) - layout.addWidget(self.figureCanvas) - layout.addLayout(parametersLayout) + layout.addWidget(self.navigation_toolbar) + layout.addWidget(self.figure_canvas) + layout.addLayout(parameters_layout) self.setLayout(layout) @@ -49,21 +49,21 @@ class XMCDParametersView(QGroupBox): def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Parameters', parent) - self.polarizationGroupBox = QGroupBox('Polarization') - self.lcircComboBox = QComboBox() - self.rcircComboBox = QComboBox() - self.saveButton = QPushButton('Save') - self.visualizationParametersView = VisualizationParametersView.createInstance() + self.polarization_group_box = QGroupBox('Polarization') + self.lcirc_combo_box = QComboBox() + self.rcirc_combo_box = QComboBox() + self.save_button = QPushButton('Save') + self.visualization_parameters_view = VisualizationParametersView.create_instance() - polarizationLayout = QFormLayout() - polarizationLayout.addRow('Left Circular:', self.lcircComboBox) - polarizationLayout.addRow('Right Circular:', self.rcircComboBox) - polarizationLayout.addRow(self.saveButton) - self.polarizationGroupBox.setLayout(polarizationLayout) + polarization_layout = QFormLayout() + polarization_layout.addRow('Left Circular:', self.lcirc_combo_box) + polarization_layout.addRow('Right Circular:', self.rcirc_combo_box) + polarization_layout.addRow(self.save_button) + self.polarization_group_box.setLayout(polarization_layout) layout = QVBoxLayout() - layout.addWidget(self.polarizationGroupBox) - layout.addWidget(self.visualizationParametersView) + layout.addWidget(self.polarization_group_box) + layout.addWidget(self.visualization_parameters_view) layout.addStretch() self.setLayout(layout) @@ -71,19 +71,19 @@ def __init__(self, parent: QWidget | None = None) -> None: class XMCDDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.differenceWidget = VisualizationWidget.createInstance('Difference') - self.ratioWidget = VisualizationWidget.createInstance('Ratio') - self.sumWidget = VisualizationWidget.createInstance('Sum') - self.parametersView = XMCDParametersView() - self.statusBar = QStatusBar() + self.difference_widget = VisualizationWidget.create_instance('Difference') + self.ratio_widget = VisualizationWidget.create_instance('Ratio') + self.sum_widget = VisualizationWidget.create_instance('Sum') + self.parameters_view = XMCDParametersView() + self.status_bar = QStatusBar() - contentsLayout = QGridLayout() - contentsLayout.addWidget(self.differenceWidget, 0, 0) - contentsLayout.addWidget(self.ratioWidget, 0, 1) - contentsLayout.addWidget(self.sumWidget, 1, 0) - contentsLayout.addWidget(self.parametersView, 1, 1) + contents_layout = QGridLayout() + contents_layout.addWidget(self.difference_widget, 0, 0) + contents_layout.addWidget(self.ratio_widget, 0, 1) + contents_layout.addWidget(self.sum_widget, 1, 0) + contents_layout.addWidget(self.parameters_view, 1, 1) layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) + layout.addLayout(contents_layout) + layout.addWidget(self.status_bar) self.setLayout(layout) diff --git a/src/ptychodus/view/patterns.py b/src/ptychodus/view/patterns.py index 218b9452..1c7a3ad0 100644 --- a/src/ptychodus/view/patterns.py +++ b/src/ptychodus/view/patterns.py @@ -4,22 +4,16 @@ from PyQt5.QtWidgets import ( QAbstractButton, QCheckBox, - QComboBox, QDialog, QDialogButtonBox, - QFormLayout, - QGridLayout, QGroupBox, QHBoxLayout, + QHeaderView, QLabel, - QLineEdit, QPushButton, - QSpinBox, - QTableView, QTreeView, QVBoxLayout, QWidget, - QWizard, QWizardPage, ) @@ -32,272 +26,100 @@ def __init__(self, parent: QWidget | None = None) -> None: class PatternsButtonBox(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.openButton = QPushButton('Open') - self.saveButton = QPushButton('Save') - self.infoButton = QPushButton('Info') - self.closeButton = QPushButton('Close') + self.open_button = QPushButton('Open') + self.save_button = QPushButton('Save') + self.info_button = QPushButton('Info') + self.close_button = QPushButton('Close') layout = QHBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self.openButton) - layout.addWidget(self.saveButton) - layout.addWidget(self.infoButton) - layout.addWidget(self.closeButton) + layout.addWidget(self.open_button) + layout.addWidget(self.save_button) + layout.addWidget(self.info_button) + layout.addWidget(self.close_button) self.setLayout(layout) class OpenDatasetWizardPage(QWizardPage): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self._isComplete = False + self._is_complete = False - def isComplete(self) -> bool: - return self._isComplete + def isComplete(self) -> bool: # noqa: N802 + """Overrides QWizardPage.isComplete()""" + return self._is_complete - def _setComplete(self, complete: bool) -> None: - if self._isComplete != complete: - self._isComplete = complete + def _set_complete(self, complete: bool) -> None: + if self._is_complete != complete: + self._is_complete = complete self.completeChanged.emit() -class OpenDatasetWizardFilesPage(OpenDatasetWizardPage): - def __init__(self, parent: QWidget | None) -> None: - super().__init__(parent) - self.directoryComboBox = QComboBox() - self.fileSystemTableView = QTableView() - self.fileTypeLabel = QLabel('Choose File Type:') - self.fileTypeComboBox = QComboBox() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardFilesPage: - view = cls(parent) - view.setTitle('Choose Dataset File(s)') - - layout = QVBoxLayout() - layout.addWidget(view.directoryComboBox) - layout.addWidget(view.fileSystemTableView) - layout.addWidget(view.fileTypeLabel) - layout.addWidget(view.fileTypeComboBox) - view.setLayout(layout) - - return view - - class OpenDatasetWizardMetadataPage(OpenDatasetWizardPage): - def __init__(self, parent: QWidget | None) -> None: - super().__init__(parent) - self.detectorPixelCountCheckBox = QCheckBox('Detector Pixel Count') - self.detectorPixelSizeCheckBox = QCheckBox('Detector Pixel Size') - self.detectorBitDepthCheckBox = QCheckBox('Detector Bit Depth') - self.detectorDistanceCheckBox = QCheckBox('Detector Distance') - self.patternCropCenterCheckBox = QCheckBox('Pattern Crop Center') - self.patternCropExtentCheckBox = QCheckBox('Pattern Crop Extent') - self.probeEnergyCheckBox = QCheckBox('Probe Energy') - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardMetadataPage: - view = cls(parent) - view.setTitle('Import Metadata') - - layout = QVBoxLayout() - layout.addWidget(view.detectorPixelCountCheckBox) - layout.addWidget(view.detectorPixelSizeCheckBox) - layout.addWidget(view.detectorBitDepthCheckBox) - layout.addWidget(view.detectorDistanceCheckBox) - layout.addWidget(view.patternCropCenterCheckBox) - layout.addWidget(view.patternCropExtentCheckBox) - layout.addWidget(view.probeEnergyCheckBox) - layout.addStretch() - view.setLayout(layout) - - return view - - -class OpenDatasetWizardPatternLoadView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: - super().__init__('Load', parent) - self.numberOfThreadsSpinBox = QSpinBox() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardPatternLoadView: - view = cls(parent) - - layout = QFormLayout() - layout.addRow('Number of Data Threads:', view.numberOfThreadsSpinBox) - view.setLayout(layout) - - return view - - -class OpenDatasetWizardPatternCropView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: - super().__init__('Crop', parent) - self.centerLabel = QLabel('Center [px]:') - self.centerXSpinBox = QSpinBox() - self.centerYSpinBox = QSpinBox() - self.extentLabel = QLabel('Extent [px]:') - self.extentXSpinBox = QSpinBox() - self.extentYSpinBox = QSpinBox() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardPatternCropView: - view = cls(parent) - - layout = QGridLayout() - layout.addWidget(view.centerLabel, 0, 0) - layout.addWidget(view.centerXSpinBox, 0, 1) - layout.addWidget(view.centerYSpinBox, 0, 2) - layout.addWidget(view.extentLabel, 1, 0) - layout.addWidget(view.extentXSpinBox, 1, 1) - layout.addWidget(view.extentYSpinBox, 1, 2) - layout.setColumnStretch(1, 1) - layout.setColumnStretch(2, 1) - view.setLayout(layout) - - return view - - -class OpenDatasetWizardPatternTransformView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: - super().__init__('Transform', parent) - self.valueLowerBoundCheckBox = QCheckBox('Value Lower Bound:') - self.valueLowerBoundSpinBox = QSpinBox() - self.valueUpperBoundCheckBox = QCheckBox('Value Upper Bound:') - self.valueUpperBoundSpinBox = QSpinBox() - self.axesLabel = QLabel('Axes:') - self.flipXCheckBox = QCheckBox('Flip X') - self.flipYCheckBox = QCheckBox('Flip Y') - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardPatternTransformView: - view = cls(parent) - - layout = QGridLayout() - layout.addWidget(view.valueLowerBoundCheckBox, 0, 0) - layout.addWidget(view.valueLowerBoundSpinBox, 0, 1, 1, 2) - layout.addWidget(view.valueUpperBoundCheckBox, 1, 0) - layout.addWidget(view.valueUpperBoundSpinBox, 1, 1, 1, 2) - layout.addWidget(view.axesLabel, 2, 0) - layout.addWidget(view.flipXCheckBox, 2, 1, Qt.AlignmentFlag.AlignHCenter) - layout.addWidget(view.flipYCheckBox, 2, 2, Qt.AlignmentFlag.AlignHCenter) - layout.setColumnStretch(2, 1) - layout.setColumnStretch(3, 1) - view.setLayout(layout) - - return view - - -class OpenDatasetWizardPatternMemoryMapView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: - super().__init__('Memory Map Diffraction Data', parent) - self.scratchDirectoryLabel = QLabel('Scratch Directory:') - self.scratchDirectoryLineEdit = QLineEdit() - self.scratchDirectoryBrowseButton = QPushButton('Browse') - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardPatternMemoryMapView: - view = cls(parent) - - layout = QGridLayout() - layout.addWidget(view.scratchDirectoryLabel, 1, 0) - layout.addWidget(view.scratchDirectoryLineEdit, 1, 1) - layout.addWidget(view.scratchDirectoryBrowseButton, 1, 2) - layout.setColumnStretch(1, 1) - view.setLayout(layout) - - return view - - -class OpenDatasetWizardPatternsPage(OpenDatasetWizardPage): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.loadView = OpenDatasetWizardPatternLoadView.createInstance() - self.memoryMapView = OpenDatasetWizardPatternMemoryMapView.createInstance() - self.cropView = OpenDatasetWizardPatternCropView.createInstance() - self.transformView = OpenDatasetWizardPatternTransformView.createInstance() + self.detector_extent_check_box = QCheckBox('Detector Extent') + self.detector_pixel_size_check_box = QCheckBox('Detector Pixel Size') + self.detector_bit_depth_check_box = QCheckBox('Detector Bit Depth') + self.detector_distance_check_box = QCheckBox('Detector Distance') + self.pattern_crop_center_check_box = QCheckBox('Pattern Crop Center') + self.pattern_crop_extent_check_box = QCheckBox('Pattern Crop Extent') + self.probe_photon_count_check_box = QCheckBox('Probe Photon Count') + self.probe_energy_check_box = QCheckBox('Probe Energy') - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardPatternsPage: - view = cls(parent) - view.setTitle('Pattern Processing') + self.setTitle('Import Metadata') layout = QVBoxLayout() - layout.addWidget(view.loadView) - layout.addWidget(view.memoryMapView) - layout.addWidget(view.cropView) - layout.addWidget(view.transformView) + layout.addWidget(self.detector_extent_check_box) + layout.addWidget(self.detector_pixel_size_check_box) + layout.addWidget(self.detector_bit_depth_check_box) + layout.addWidget(self.detector_distance_check_box) + layout.addWidget(self.pattern_crop_center_check_box) + layout.addWidget(self.pattern_crop_extent_check_box) + layout.addWidget(self.probe_photon_count_check_box) + layout.addWidget(self.probe_energy_check_box) layout.addStretch() - view.setLayout(layout) - - return view - - -class OpenDatasetWizard(QWizard): - def __init__(self, parent: QWidget | None) -> None: - super().__init__(parent) - self.filesPage = OpenDatasetWizardFilesPage.createInstance() - self.metadataPage = OpenDatasetWizardMetadataPage.createInstance() - self.patternsPage = OpenDatasetWizardPatternsPage.createInstance() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizard: - view = cls(parent) - - view.setWindowTitle('Open Dataset') - view.addPage(view.filesPage) - view.addPage(view.metadataPage) - view.addPage(view.patternsPage) - - return view + self.setLayout(layout) class PatternsInfoDialog(QDialog): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.treeView = QTreeView() - self.buttonBox = QDialogButtonBox() + self.tree_view = QTreeView() + self.button_box = QDialogButtonBox() - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> PatternsInfoDialog: - view = cls(parent) - view.treeView.header().setDefaultAlignment(Qt.AlignmentFlag.AlignCenter) + tree_header = self.tree_view.header() + tree_header.setDefaultAlignment(Qt.AlignmentFlag.AlignCenter) + tree_header.setSectionResizeMode(QHeaderView.ResizeToContents) - view.buttonBox.addButton(QDialogButtonBox.StandardButton.Ok) - view.buttonBox.clicked.connect(view._handleButtonBoxClicked) + self.button_box.addButton(QDialogButtonBox.StandardButton.Ok) + self.button_box.clicked.connect(self._handle_button_box_clicked) layout = QVBoxLayout() - layout.addWidget(view.treeView) - layout.addWidget(view.buttonBox) - view.setLayout(layout) - - return view + layout.addWidget(self.tree_view) + layout.addWidget(self.button_box) + self.setLayout(layout) - def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: - if self.buttonBox.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: + def _handle_button_box_clicked(self, button: QAbstractButton) -> None: + if self.button_box.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: self.accept() else: self.reject() class PatternsView(QWidget): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.detectorView = DetectorView() - self.treeView = QTreeView() - self.infoLabel = QLabel() - self.buttonBox = PatternsButtonBox() - self.openDatasetWizard = OpenDatasetWizard.createInstance(self) + self.detector_view = DetectorView() + self.tree_view = QTreeView() + self.info_label = QLabel() + self.button_box = PatternsButtonBox() - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> PatternsView: - view = cls(parent) - view.treeView.header().setDefaultAlignment(Qt.AlignmentFlag.AlignCenter) + self.tree_view.header().setDefaultAlignment(Qt.AlignmentFlag.AlignCenter) layout = QVBoxLayout() - layout.addWidget(view.detectorView) - layout.addWidget(view.treeView) - layout.addWidget(view.infoLabel) - layout.addWidget(view.buttonBox) - view.setLayout(layout) - - return view + layout.addWidget(self.detector_view) + layout.addWidget(self.tree_view) + layout.addWidget(self.info_label) + layout.addWidget(self.button_box) + self.setLayout(layout) diff --git a/src/ptychodus/view/probe.py b/src/ptychodus/view/probe.py index ff7cc83d..50aadb25 100644 --- a/src/ptychodus/view/probe.py +++ b/src/ptychodus/view/probe.py @@ -26,22 +26,22 @@ class ProbePropagationParametersView(QGroupBox): def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Parameters', parent) - self.beginCoordinateWidget = LengthWidget.createInstance(isSigned=True) - self.endCoordinateWidget = LengthWidget.createInstance(isSigned=True) - self.numberOfStepsSpinBox = QSpinBox() - self.visualizationParametersView = VisualizationParametersView.createInstance() + self.begin_coordinate_widget = LengthWidget.create_instance(is_signed=True) + self.end_coordinate_widget = LengthWidget.create_instance(is_signed=True) + self.num_steps_spin_box = QSpinBox() + self.visualization_parameters_view = VisualizationParametersView.create_instance() - propagationLayout = QFormLayout() - propagationLayout.addRow('Begin Coordinate:', self.beginCoordinateWidget) - propagationLayout.addRow('End Coordinate:', self.endCoordinateWidget) - propagationLayout.addRow('Number of Steps:', self.numberOfStepsSpinBox) + propagation_layout = QFormLayout() + propagation_layout.addRow('Begin Coordinate:', self.begin_coordinate_widget) + propagation_layout.addRow('End Coordinate:', self.end_coordinate_widget) + propagation_layout.addRow('Number of Steps:', self.num_steps_spin_box) - propagationGroupBox = QGroupBox('Propagation') - propagationGroupBox.setLayout(propagationLayout) + propagation_group_box = QGroupBox('Propagation') + propagation_group_box.setLayout(propagation_layout) layout = QVBoxLayout() - layout.addWidget(propagationGroupBox) - layout.addWidget(self.visualizationParametersView) + layout.addWidget(propagation_group_box) + layout.addWidget(self.visualization_parameters_view) layout.addStretch() self.setLayout(layout) @@ -49,197 +49,199 @@ def __init__(self, parent: QWidget | None = None) -> None: class ProbePropagationDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.xyView = VisualizationWidget.createInstance('XY Plane') - self.zxView = VisualizationWidget.createInstance('ZX Plane') - self.parametersView = ProbePropagationParametersView() - self.zyView = VisualizationWidget.createInstance('ZY Plane') - self.propagateButton = QPushButton('Propagate') - self.saveButton = QPushButton('Save') - self.coordinateSlider = QSlider(Qt.Orientation.Horizontal) - self.coordinateLabel = QLabel() - self.statusBar = QStatusBar() - - actionLayout = QHBoxLayout() - actionLayout.addWidget(self.propagateButton) - actionLayout.addWidget(self.saveButton) - - coordinateLayout = QHBoxLayout() - coordinateLayout.setContentsMargins(0, 0, 0, 0) - coordinateLayout.addWidget(self.coordinateSlider) - coordinateLayout.addWidget(self.coordinateLabel) - - contentsLayout = QGridLayout() - contentsLayout.addWidget(self.xyView, 0, 0) - contentsLayout.addWidget(self.zxView, 0, 1) - contentsLayout.addWidget(self.parametersView, 1, 0) - contentsLayout.addWidget(self.zyView, 1, 1) - contentsLayout.addLayout(actionLayout, 2, 0) - contentsLayout.addLayout(coordinateLayout, 2, 1) - contentsLayout.setColumnStretch(0, 1) - contentsLayout.setColumnStretch(1, 2) + self.xy_view = VisualizationWidget.create_instance('XY Plane') + self.zx_view = VisualizationWidget.create_instance('ZX Plane') + self.parameters_view = ProbePropagationParametersView() + self.zy_view = VisualizationWidget.create_instance('ZY Plane') + self.propagate_button = QPushButton('Propagate') + self.save_button = QPushButton('Save') + self.coordinate_slider = QSlider(Qt.Orientation.Horizontal) + self.coordinate_label = QLabel() + self.status_bar = QStatusBar() + + action_layout = QHBoxLayout() + action_layout.addWidget(self.propagate_button) + action_layout.addWidget(self.save_button) + + coordinate_layout = QHBoxLayout() + coordinate_layout.setContentsMargins(0, 0, 0, 0) + coordinate_layout.addWidget(self.coordinate_slider) + coordinate_layout.addWidget(self.coordinate_label) + + contents_layout = QGridLayout() + contents_layout.addWidget(self.xy_view, 0, 0) + contents_layout.addWidget(self.zx_view, 0, 1) + contents_layout.addWidget(self.parameters_view, 1, 0) + contents_layout.addWidget(self.zy_view, 1, 1) + contents_layout.addLayout(action_layout, 2, 0) + contents_layout.addLayout(coordinate_layout, 2, 1) + contents_layout.setColumnStretch(0, 1) + contents_layout.setColumnStretch(1, 2) layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) + layout.addLayout(contents_layout) + layout.addWidget(self.status_bar) self.setLayout(layout) class STXMDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.visualizationWidget = VisualizationWidget.createInstance('Transmission') - self.visualizationParametersView = VisualizationParametersView.createInstance() - self.saveButton = QPushButton('Save') - self.statusBar = QStatusBar() + self.visualization_widget = VisualizationWidget.create_instance('Transmission') + self.visualization_parameters_view = VisualizationParametersView.create_instance() + self.save_button = QPushButton('Save') + self.status_bar = QStatusBar() - parameterLayout = QVBoxLayout() - parameterLayout.addWidget(self.visualizationParametersView) - parameterLayout.addStretch() - parameterLayout.addWidget(self.saveButton) + parameter_layout = QVBoxLayout() + parameter_layout.addWidget(self.visualization_parameters_view) + parameter_layout.addStretch() + parameter_layout.addWidget(self.save_button) - contentsLayout = QHBoxLayout() - contentsLayout.addWidget(self.visualizationWidget, 1) - contentsLayout.addLayout(parameterLayout) + contents_layout = QHBoxLayout() + contents_layout.addWidget(self.visualization_widget, 1) + contents_layout.addLayout(parameter_layout) layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) + layout.addLayout(contents_layout) + layout.addWidget(self.status_bar) self.setLayout(layout) -class ExposureParametersView(QGroupBox): +class IlluminationParametersView(QGroupBox): def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Parameters', parent) - self.quantitativeProbeCheckBox = QCheckBox('Quantitative Probe') - self.photonFluxLineEdit = DecimalLineEdit.createInstance() - self.exposureTimeLineEdit = DecimalLineEdit.createInstance() - self.massAttenuationLabel = QLabel('Mass Attenuation [m\u00b2/kg]:') - self.massAttenuationLineEdit = DecimalLineEdit.createInstance() + self.quantitative_probe_check_box = QCheckBox('Quantitative Probe') + self.photon_flux_line_edit = DecimalLineEdit.create_instance() + self.exposure_time_line_edit = DecimalLineEdit.create_instance() + self.mass_attenuation_label = QLabel('Mass Attenuation [m\u00b2/kg]:') + self.mass_attenuation_line_edit = DecimalLineEdit.create_instance() layout = QFormLayout() - layout.addRow(self.quantitativeProbeCheckBox) - layout.addRow('Photon Flux [ph/s]:', self.photonFluxLineEdit) - layout.addRow('Exposure Time [s]:', self.exposureTimeLineEdit) - layout.addRow(self.massAttenuationLabel) - layout.addRow(self.massAttenuationLineEdit) + layout.addRow(self.quantitative_probe_check_box) + layout.addRow('Photon Flux [ph/s]:', self.photon_flux_line_edit) + layout.addRow('Exposure Time [s]:', self.exposure_time_line_edit) + layout.addRow(self.mass_attenuation_label) + layout.addRow(self.mass_attenuation_line_edit) self.setLayout(layout) -class ExposureQuantityView(QGroupBox): +class IlluminationQuantityView(QGroupBox): def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Quantity', parent) - self.photonCountButton = QRadioButton('Photon Count') - self.photonFluxButton = QRadioButton('Photon Flux [Hz]') - self.exposureButton = QRadioButton('Exposure [J/m\u00b2]') - self.irradianceButton = QRadioButton('Irradiance [W/m\u00b2]') - self.doseButton = QRadioButton('Dose [Gy]') - self.doseRateButton = QRadioButton('Dose Rate [Gy/s]') + self.photon_number_button = QRadioButton('Photon Number') + self.photon_fluence_button = QRadioButton('Photon Fluence [1/m\u00b2]') + self.photon_fluence_rate_button = QRadioButton('Photon Fluence Rate [Hz/m\u00b2]') + self.energy_fluence_button = QRadioButton('Energy Fluence [J/m\u00b2]') + self.energy_fluence_rate_button = QRadioButton('Energy Fluence Rate [W/m\u00b2]') + self.dose_button = QRadioButton('Dose [Gy]') + self.dose_rate_button = QRadioButton('Dose Rate [Gy/s]') layout = QVBoxLayout() - layout.addWidget(self.photonCountButton) - layout.addWidget(self.photonFluxButton) - layout.addWidget(self.exposureButton) - layout.addWidget(self.irradianceButton) - layout.addWidget(self.doseButton) - layout.addWidget(self.doseRateButton) + layout.addWidget(self.photon_number_button) + layout.addWidget(self.photon_fluence_button) + layout.addWidget(self.photon_fluence_rate_button) + layout.addWidget(self.energy_fluence_button) + layout.addWidget(self.energy_fluence_rate_button) + layout.addWidget(self.dose_button) + layout.addWidget(self.dose_rate_button) self.setLayout(layout) -class ExposureDialog(QDialog): +class IlluminationDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.visualizationWidget = VisualizationWidget.createInstance('Visualization') - self.exposureParametersView = ExposureParametersView() - self.exposureQuantityView = ExposureQuantityView() - self.visualizationParametersView = VisualizationParametersView.createInstance() - self.saveButton = QPushButton('Save') - self.statusBar = QStatusBar() - - parameterLayout = QVBoxLayout() - parameterLayout.addWidget(self.exposureParametersView) - parameterLayout.addWidget(self.exposureQuantityView) - parameterLayout.addWidget(self.visualizationParametersView) - parameterLayout.addWidget(self.saveButton) - parameterLayout.addStretch() - - contentsLayout = QHBoxLayout() - contentsLayout.addWidget(self.visualizationWidget, 1) - contentsLayout.addLayout(parameterLayout) + self.visualization_widget = VisualizationWidget.create_instance('Visualization') + self.exposure_parameters_view = IlluminationParametersView() + self.exposure_quantity_view = IlluminationQuantityView() + self.visualization_parameters_view = VisualizationParametersView.create_instance() + self.save_button = QPushButton('Save') + self.status_bar = QStatusBar() + + parameter_layout = QVBoxLayout() + parameter_layout.addWidget(self.exposure_parameters_view) + parameter_layout.addWidget(self.exposure_quantity_view) + parameter_layout.addWidget(self.visualization_parameters_view) + parameter_layout.addWidget(self.save_button) + parameter_layout.addStretch() + + contents_layout = QHBoxLayout() + contents_layout.addWidget(self.visualization_widget, 1) + contents_layout.addLayout(parameter_layout) layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) + layout.addLayout(contents_layout) + layout.addWidget(self.status_bar) self.setLayout(layout) class FluorescenceVSPIParametersView(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.dampingFactorLineEdit = DecimalLineEdit.createInstance() - self.maxIterationsSpinBox = QSpinBox() + self.damping_factor_line_edit = DecimalLineEdit.create_instance() + self.max_iterations_spin_box = QSpinBox() layout = QFormLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addRow('Damping Factor:', self.dampingFactorLineEdit) - layout.addRow('Max Iterations:', self.maxIterationsSpinBox) + layout.addRow('Damping Factor:', self.damping_factor_line_edit) + layout.addRow('Max Iterations:', self.max_iterations_spin_box) self.setLayout(layout) class FluorescenceTwoStepParametersView(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.upscalingStrategyComboBox = QComboBox() - self.deconvolutionStrategyComboBox = QComboBox() + self.upscaling_strategy_combo_box = QComboBox() + self.deconvolution_strategy_combo_box = QComboBox() layout = QFormLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addRow('Upscaling Strategy:', self.upscalingStrategyComboBox) - layout.addRow('Deconvolution Strategy:', self.deconvolutionStrategyComboBox) + layout.addRow('Upscaling Strategy:', self.upscaling_strategy_combo_box) + layout.addRow('Deconvolution Strategy:', self.deconvolution_strategy_combo_box) self.setLayout(layout) class FluorescenceParametersView(QGroupBox): def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Enhancement Strategy', parent) - self.openButton = QPushButton('Open Measured Dataset') - self.algorithmComboBox = QComboBox() - self.stackedWidget = QStackedWidget() - self.enhanceButton = QPushButton('Enhance') - self.saveButton = QPushButton('Save Enhanced Dataset') + self.open_button = QPushButton('Open Measured Dataset') + self.algorithm_combo_box = QComboBox() + self.stacked_widget = QStackedWidget() + self.enhance_button = QPushButton('Enhance') + self.save_button = QPushButton('Save Enhanced Dataset') - self.stackedWidget.layout().setContentsMargins(0, 0, 0, 0) + self.stacked_widget.layout().setContentsMargins(0, 0, 0, 0) layout = QFormLayout() - layout.addRow(self.openButton) - layout.addRow('Algorithm:', self.algorithmComboBox) - layout.addRow(self.stackedWidget) - layout.addRow(self.enhanceButton) - layout.addRow(self.saveButton) + layout.addRow(self.open_button) + layout.addRow('Algorithm:', self.algorithm_combo_box) + layout.addRow(self.stacked_widget) + layout.addRow(self.enhance_button) + layout.addRow(self.save_button) self.setLayout(layout) class FluorescenceDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.measuredWidget = VisualizationWidget.createInstance('Measured') - self.enhancedWidget = VisualizationWidget.createInstance('Enhanced') - self.fluorescenceParametersView = FluorescenceParametersView() - self.fluorescenceChannelListView = QListView() - self.visualizationParametersView = VisualizationParametersView.createInstance() - self.statusBar = QStatusBar() - - parameterLayout = QVBoxLayout() - parameterLayout.addWidget(self.fluorescenceParametersView) - parameterLayout.addWidget(self.fluorescenceChannelListView, 1) - parameterLayout.addWidget(self.visualizationParametersView) - parameterLayout.addStretch() - - contentsLayout = QHBoxLayout() - contentsLayout.addWidget(self.measuredWidget, 1) - contentsLayout.addWidget(self.enhancedWidget, 1) - contentsLayout.addLayout(parameterLayout) + self.measured_widget = VisualizationWidget.create_instance('Measured') + self.enhanced_widget = VisualizationWidget.create_instance('Enhanced') + self.fluorescence_parameters_view = FluorescenceParametersView() + self.fluorescence_channel_list_view = QListView() + self.visualization_parameters_view = VisualizationParametersView.create_instance() + self.status_bar = QStatusBar() + + parameter_layout = QVBoxLayout() + parameter_layout.addWidget(self.fluorescence_parameters_view) + parameter_layout.addWidget(self.fluorescence_channel_list_view, 1) + parameter_layout.addWidget(self.visualization_parameters_view) + parameter_layout.addStretch() + + contents_layout = QHBoxLayout() + contents_layout.addWidget(self.measured_widget, 1) + contents_layout.addWidget(self.enhanced_widget, 1) + contents_layout.addLayout(parameter_layout) layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) + layout.addLayout(contents_layout) + layout.addWidget(self.status_bar) self.setLayout(layout) diff --git a/src/ptychodus/view/product.py b/src/ptychodus/view/product.py index 3825cdd4..3fc9f77c 100644 --- a/src/ptychodus/view/product.py +++ b/src/ptychodus/view/product.py @@ -16,31 +16,68 @@ ) +class ProductEditorPropertiesView(QGroupBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Properties') + self.table_view = QTableView() + + layout = QVBoxLayout() + layout.addWidget(self.table_view) + self.setLayout(layout) + + +class ProductEditorActionsView(QGroupBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Actions') + self.estimate_probe_photon_count_button = QPushButton('Estimate Probe Photon Count') + + layout = QVBoxLayout() + layout.addWidget(self.estimate_probe_photon_count_button) + layout.addStretch() + self.setLayout(layout) + + +class ProductEditorCommentsView(QGroupBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Comments') + self.text_edit = QPlainTextEdit() + + layout = QVBoxLayout() + layout.addWidget(self.text_edit) + self.setLayout(layout) + + class ProductEditorDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.tableView = QTableView() - self.textEdit = QPlainTextEdit() - self.buttonBox = QDialogButtonBox() - - commentsLayout = QVBoxLayout() - commentsLayout.setContentsMargins(0, 0, 0, 0) - commentsLayout.addWidget(self.textEdit) + self.properties_view = ProductEditorPropertiesView() + self.actions_view = ProductEditorActionsView() + self.comments_view = ProductEditorCommentsView() + self.button_box = QDialogButtonBox() - commentsBox = QGroupBox('Comments') - commentsBox.setLayout(commentsLayout) + self.button_box.addButton(QDialogButtonBox.StandardButton.Ok) + self.button_box.clicked.connect(self._handle_button_box_clicked) - self.buttonBox.addButton(QDialogButtonBox.StandardButton.Ok) - self.buttonBox.clicked.connect(self._handleButtonBoxClicked) + top_layout = QHBoxLayout() + top_layout.addWidget(self.properties_view) + top_layout.addWidget(self.actions_view) layout = QVBoxLayout() - layout.addWidget(self.tableView) - layout.addWidget(commentsBox) - layout.addWidget(self.buttonBox) + layout.addLayout(top_layout) + layout.addWidget(self.comments_view) + layout.addWidget(self.button_box) self.setLayout(layout) - def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: - if self.buttonBox.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: + @property + def table_view(self) -> QTableView: + return self.properties_view.table_view + + @property + def text_edit(self) -> QPlainTextEdit: + return self.comments_view.text_edit + + def _handle_button_box_clicked(self, button: QAbstractButton) -> None: + if self.button_box.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: self.accept() else: self.reject() @@ -49,34 +86,34 @@ def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: class ProductButtonBox(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.insertMenu = QMenu() - self.insertButton = QPushButton('Insert') - self.saveMenu = QMenu() - self.saveButton = QPushButton('Save') - self.editButton = QPushButton('Edit') - self.removeButton = QPushButton('Remove') + self.insert_menu = QMenu() + self.insert_button = QPushButton('Insert') + self.save_menu = QMenu() + self.save_button = QPushButton('Save') + self.edit_button = QPushButton('Edit') + self.remove_button = QPushButton('Remove') - self.insertButton.setMenu(self.insertMenu) - self.saveButton.setMenu(self.saveMenu) + self.insert_button.setMenu(self.insert_menu) + self.save_button.setMenu(self.save_menu) layout = QHBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self.insertButton) - layout.addWidget(self.saveButton) - layout.addWidget(self.editButton) - layout.addWidget(self.removeButton) + layout.addWidget(self.insert_button) + layout.addWidget(self.save_button) + layout.addWidget(self.edit_button) + layout.addWidget(self.remove_button) self.setLayout(layout) class ProductView(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.tableView = QTableView() - self.infoLabel = QLabel() - self.buttonBox = ProductButtonBox() + self.table_view = QTableView() + self.info_label = QLabel() + self.button_box = ProductButtonBox() layout = QVBoxLayout() - layout.addWidget(self.tableView) - layout.addWidget(self.infoLabel) - layout.addWidget(self.buttonBox) + layout.addWidget(self.table_view) + layout.addWidget(self.info_label) + layout.addWidget(self.button_box) self.setLayout(layout) diff --git a/src/ptychodus/view/ptychonn.py b/src/ptychodus/view/ptychonn.py deleted file mode 100644 index dc937584..00000000 --- a/src/ptychodus/view/ptychonn.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import ( - QCheckBox, - QFormLayout, - QGroupBox, - QSpinBox, - QVBoxLayout, - QWidget, -) - -from .widgets import DecimalLineEdit, DecimalSlider - - -class PtychoNNModelParametersView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: - super().__init__('Model Parameters', parent) - self.numberOfConvolutionKernelsSpinBox = QSpinBox() - self.batchSizeSpinBox = QSpinBox() - self.useBatchNormalizationCheckBox = QCheckBox('Use Batch Normalization') - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> PtychoNNModelParametersView: - view = cls(parent) - - layout = QFormLayout() - layout.addRow('Convolution Kernels:', view.numberOfConvolutionKernelsSpinBox) - layout.addRow('Batch Size:', view.batchSizeSpinBox) - layout.addRow(view.useBatchNormalizationCheckBox) - view.setLayout(layout) - - return view - - -class PtychoNNTrainingParametersView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: - super().__init__('Training Parameters', parent) - self.validationSetFractionalSizeSlider = DecimalSlider.createInstance( - Qt.Orientation.Horizontal, numberOfTicks=20 - ) - self.maximumLearningRateLineEdit = DecimalLineEdit.createInstance() - self.minimumLearningRateLineEdit = DecimalLineEdit.createInstance() - self.trainingEpochsSpinBox = QSpinBox() - self.statusIntervalSpinBox = QSpinBox() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> PtychoNNTrainingParametersView: - view = cls(parent) - - layout = QFormLayout() - layout.addRow('Validation Set Fractional Size:', view.validationSetFractionalSizeSlider) - layout.addRow('Maximum Learning Rate:', view.maximumLearningRateLineEdit) - layout.addRow('Minimum Learning Rate:', view.minimumLearningRateLineEdit) - layout.addRow('Training Epochs:', view.trainingEpochsSpinBox) - layout.addRow('Status Interval:', view.statusIntervalSpinBox) - view.setLayout(layout) - - return view - - -class PtychoNNParametersView(QWidget): - def __init__(self, parent: QWidget | None) -> None: - super().__init__(parent) - self.modelParametersView = PtychoNNModelParametersView.createInstance() - self.trainingParametersView = PtychoNNTrainingParametersView.createInstance() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> PtychoNNParametersView: - view = cls(parent) - - layout = QVBoxLayout() - layout.addWidget(view.modelParametersView) - layout.addWidget(view.trainingParametersView) - layout.addStretch() - view.setLayout(layout) - - return view diff --git a/src/ptychodus/view/reconstructor.py b/src/ptychodus/view/reconstructor.py index 1f482ab0..26627c67 100644 --- a/src/ptychodus/view/reconstructor.py +++ b/src/ptychodus/view/reconstructor.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from PyQt5.QtWidgets import ( QComboBox, QDialog, @@ -22,106 +20,81 @@ from matplotlib.figure import Figure -class ReconstructorView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: +class ReconstructorParametersView(QGroupBox): + def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Parameters', parent) - self.algorithmComboBox = QComboBox() - self.productComboBox = QComboBox() - self.modelButton = QPushButton('Model') - self.modelMenu = QMenu() - self.trainerButton = QPushButton('Trainer') - self.trainerMenu = QMenu() - self.reconstructorButton = QPushButton('Reconstructor') - self.reconstructorMenu = QMenu() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorView: - view = cls(parent) - - view.modelButton.setMenu(view.modelMenu) - view.trainerButton.setMenu(view.trainerMenu) - view.reconstructorButton.setMenu(view.reconstructorMenu) - - actionLayout = QHBoxLayout() - actionLayout.setContentsMargins(0, 0, 0, 0) - actionLayout.addWidget(view.modelButton) - actionLayout.addWidget(view.trainerButton) - actionLayout.addWidget(view.reconstructorButton) + self.algorithm_combo_box = QComboBox() + self.product_combo_box = QComboBox() - layout = QFormLayout() - layout.addRow('Algorithm:', view.algorithmComboBox) - layout.addRow('Product:', view.productComboBox) - layout.addRow('Action:', actionLayout) - view.setLayout(layout) + self.reconstructor_menu = QMenu() + self.reconstructor_button = QPushButton('Reconstructor') + self.reconstructor_button.setMenu(self.reconstructor_menu) + + self.trainer_menu = QMenu() + self.trainer_button = QPushButton('Trainer') + self.trainer_button.setMenu(self.trainer_menu) - return view + action_layout = QHBoxLayout() + action_layout.setContentsMargins(0, 0, 0, 0) + action_layout.addWidget(self.reconstructor_button) + action_layout.addWidget(self.trainer_button) + + layout = QFormLayout() + layout.addRow('Algorithm:', self.algorithm_combo_box) + layout.addRow('Product:', self.product_combo_box) + layout.addRow('Action:', action_layout) + self.setLayout(layout) class ReconstructorProgressDialog(QDialog): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.textEdit = QPlainTextEdit() - self.progressBar = QProgressBar() - self.buttonBox = QDialogButtonBox() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorProgressDialog: - dialog = cls(parent) - dialog.setWindowTitle('Reconstruction Progress') - dialog.buttonBox.addButton(QDialogButtonBox.Ok) - dialog.buttonBox.accepted.connect(dialog.accept) - dialog.buttonBox.addButton(QDialogButtonBox.Cancel) - dialog.buttonBox.rejected.connect(dialog.reject) + self.text_edit = QPlainTextEdit() + self.progress_bar = QProgressBar() + self.button_box = QDialogButtonBox() - layout = QVBoxLayout() - layout.addWidget(dialog.textEdit) - layout.addWidget(dialog.progressBar) - layout.addWidget(dialog.buttonBox) - dialog.setLayout(layout) + self.setWindowTitle('Reconstruction Progress') + self.button_box.addButton(QDialogButtonBox.Ok) + self.button_box.accepted.connect(self.accept) + self.button_box.addButton(QDialogButtonBox.Cancel) + self.button_box.rejected.connect(self.reject) - return dialog + layout = QVBoxLayout() + layout.addWidget(self.text_edit) + layout.addWidget(self.progress_bar) + layout.addWidget(self.button_box) + self.setLayout(layout) -class ReconstructorParametersView(QWidget): - def __init__(self, parent: QWidget | None) -> None: +class ReconstructorView(QWidget): + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.reconstructorView = ReconstructorView.createInstance() - self.stackedWidget = QStackedWidget() - self.scrollArea = QScrollArea() - self.progressDialog = ReconstructorProgressDialog.createInstance() # TODO use this - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorParametersView: - view = cls(parent) + self.parameters_view = ReconstructorParametersView() - view.scrollArea.setWidgetResizable(True) - view.scrollArea.setWidget(view.stackedWidget) + self.stacked_widget = QStackedWidget() + self.stacked_widget.layout().setContentsMargins(0, 0, 0, 0) - view.stackedWidget.layout().setContentsMargins(0, 0, 0, 0) + self.scroll_area = QScrollArea() + self.scroll_area.setWidgetResizable(True) + self.scroll_area.setWidget(self.stacked_widget) layout = QVBoxLayout() - layout.addWidget(view.reconstructorView) - layout.addWidget(view.scrollArea) - view.setLayout(layout) + layout.addWidget(self.parameters_view) + layout.addWidget(self.scroll_area) + self.setLayout(layout) - return view + self.progress_dialog = ReconstructorProgressDialog() class ReconstructorPlotView(QWidget): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) self.figure = Figure() - self.figureCanvas = FigureCanvasQTAgg(self.figure) - self.navigationToolbar = NavigationToolbar(self.figureCanvas, self) + self.figure_canvas = FigureCanvasQTAgg(self.figure) + self.navigation_toolbar = NavigationToolbar(self.figure_canvas, self) self.axes = self.figure.add_subplot(111) - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorPlotView: - view = cls(parent) - layout = QVBoxLayout() - layout.addWidget(view.navigationToolbar) - layout.addWidget(view.figureCanvas) - view.setLayout(layout) - - return view + layout.addWidget(self.navigation_toolbar) + layout.addWidget(self.figure_canvas) + self.setLayout(layout) diff --git a/src/ptychodus/view/repository.py b/src/ptychodus/view/repository.py index 849a16c9..f4acf843 100644 --- a/src/ptychodus/view/repository.py +++ b/src/ptychodus/view/repository.py @@ -20,46 +20,46 @@ class RepositoryButtonBox(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.loadButton = QPushButton('Load') - self.loadMenu = QMenu() - self.saveButton = QPushButton('Save') - self.saveMenu = QMenu() - self.editButton = QPushButton('Edit') - self.analyzeButton = QPushButton('Analyze') - self.analyzeMenu = QMenu() - - self.loadButton.setMenu(self.loadMenu) - self.saveButton.setMenu(self.saveMenu) - self.analyzeButton.setMenu(self.analyzeMenu) + self.load_button = QPushButton('Load') + self.load_menu = QMenu() + self.save_button = QPushButton('Save') + self.save_menu = QMenu() + self.edit_button = QPushButton('Edit') + self.analyze_button = QPushButton('Analyze') + self.analyze_menu = QMenu() + + self.load_button.setMenu(self.load_menu) + self.save_button.setMenu(self.save_menu) + self.analyze_button.setMenu(self.analyze_menu) layout = QHBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self.loadButton) - layout.addWidget(self.saveButton) - layout.addWidget(self.editButton) - layout.addWidget(self.analyzeButton) + layout.addWidget(self.load_button) + layout.addWidget(self.save_button) + layout.addWidget(self.edit_button) + layout.addWidget(self.analyze_button) self.setLayout(layout) class RepositoryItemCopierDialog(QDialog): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.sourceComboBox = QComboBox() - self.destinationComboBox = QComboBox() - self.buttonBox = QDialogButtonBox() + self.source_combo_box = QComboBox() + self.destination_combo_box = QComboBox() + self.button_box = QDialogButtonBox() - self.buttonBox.addButton(QDialogButtonBox.StandardButton.Ok) - self.buttonBox.addButton(QDialogButtonBox.StandardButton.Cancel) - self.buttonBox.clicked.connect(self._handleButtonBoxClicked) + self.button_box.addButton(QDialogButtonBox.StandardButton.Ok) + self.button_box.addButton(QDialogButtonBox.StandardButton.Cancel) + self.button_box.clicked.connect(self._handle_button_box_clicked) layout = QFormLayout() - layout.addRow('From:', self.sourceComboBox) - layout.addRow('To:', self.destinationComboBox) - layout.addRow(self.buttonBox) + layout.addRow('From:', self.source_combo_box) + layout.addRow('To:', self.destination_combo_box) + layout.addRow(self.button_box) self.setLayout(layout) - def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: - if self.buttonBox.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: + def _handle_button_box_clicked(self, button: QAbstractButton) -> None: + if self.button_box.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: self.accept() else: self.reject() @@ -68,26 +68,26 @@ def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: class RepositoryTableView(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.tableView = QTableView() - self.buttonBox = RepositoryButtonBox() - self.copierDialog = RepositoryItemCopierDialog() + self.table_view = QTableView() + self.button_box = RepositoryButtonBox() + self.copier_dialog = RepositoryItemCopierDialog() layout = QVBoxLayout() - layout.addWidget(self.tableView) - layout.addWidget(self.buttonBox) + layout.addWidget(self.table_view) + layout.addWidget(self.button_box) self.setLayout(layout) class RepositoryTreeView(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.treeView = QTreeView() - self.buttonBox = RepositoryButtonBox() - self.copierDialog = RepositoryItemCopierDialog() + self.tree_view = QTreeView() + self.button_box = RepositoryButtonBox() + self.copier_dialog = RepositoryItemCopierDialog() - self.treeView.header().setDefaultAlignment(Qt.AlignmentFlag.AlignCenter) + self.tree_view.header().setDefaultAlignment(Qt.AlignmentFlag.AlignCenter) layout = QVBoxLayout() - layout.addWidget(self.treeView) - layout.addWidget(self.buttonBox) + layout.addWidget(self.tree_view) + layout.addWidget(self.button_box) self.setLayout(layout) diff --git a/src/ptychodus/view/resources.py b/src/ptychodus/view/resources.py index 373aa76b..e2fea1e3 100644 --- a/src/ptychodus/view/resources.py +++ b/src/ptychodus/view/resources.py @@ -2,273 +2,13 @@ # Resource object code # -# Created by: The Resource Compiler for PyQt5 (Qt v5.15.2) +# Created by: The Resource Compiler for PyQt5 (Qt v5.15.8) # # WARNING! All changes made in this file will be lost! from PyQt5 import QtCore qt_resource_data = b'\ -\x00\x00\x03\x1e\ -\x3c\ -\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ -\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ -\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x37\x36\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ -\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ -\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ -\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ -\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ -\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ -\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ -\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ -\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ -\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ -\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ -\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ -\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x35\x34\x33\x2e\x38\x20\ -\x32\x38\x37\x2e\x36\x63\x31\x37\x20\x30\x20\x33\x32\x2d\x31\x34\ -\x20\x33\x32\x2d\x33\x32\x2e\x31\x63\x31\x2d\x39\x2d\x33\x2d\x31\ -\x37\x2d\x31\x31\x2d\x32\x34\x4c\x35\x31\x32\x20\x31\x38\x35\x56\ -\x36\x34\x63\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\ -\x32\x2d\x33\x32\x2d\x33\x32\x48\x34\x34\x38\x63\x2d\x31\x37\x2e\ -\x37\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\ -\x32\x76\x33\x36\x2e\x37\x4c\x33\x30\x39\x2e\x35\x20\x37\x63\x2d\ -\x36\x2d\x35\x2d\x31\x34\x2d\x37\x2d\x32\x31\x2d\x37\x73\x2d\x31\ -\x35\x20\x31\x2d\x32\x32\x20\x38\x4c\x31\x30\x20\x32\x33\x31\x2e\ -\x35\x63\x2d\x37\x20\x37\x2d\x31\x30\x20\x31\x35\x2d\x31\x30\x20\ -\x32\x34\x63\x30\x20\x31\x38\x20\x31\x34\x20\x33\x32\x2e\x31\x20\ -\x33\x32\x20\x33\x32\x2e\x31\x68\x33\x32\x76\x36\x39\x2e\x37\x63\ -\x2d\x2e\x31\x20\x2e\x39\x2d\x2e\x31\x20\x31\x2e\x38\x2d\x2e\x31\ -\x20\x32\x2e\x38\x56\x34\x37\x32\x63\x30\x20\x32\x32\x2e\x31\x20\ -\x31\x37\x2e\x39\x20\x34\x30\x20\x34\x30\x20\x34\x30\x68\x31\x36\ -\x63\x31\x2e\x32\x20\x30\x20\x32\x2e\x34\x2d\x2e\x31\x20\x33\x2e\ -\x36\x2d\x2e\x32\x63\x31\x2e\x35\x20\x2e\x31\x20\x33\x20\x2e\x32\ -\x20\x34\x2e\x35\x20\x2e\x32\x48\x31\x36\x30\x68\x32\x34\x63\x32\ -\x32\x2e\x31\x20\x30\x20\x34\x30\x2d\x31\x37\x2e\x39\x20\x34\x30\ -\x2d\x34\x30\x56\x34\x34\x38\x20\x33\x38\x34\x63\x30\x2d\x31\x37\ -\x2e\x37\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x2d\x33\x32\ -\x68\x36\x34\x63\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x20\x31\x34\ -\x2e\x33\x20\x33\x32\x20\x33\x32\x76\x36\x34\x20\x32\x34\x63\x30\ -\x20\x32\x32\x2e\x31\x20\x31\x37\x2e\x39\x20\x34\x30\x20\x34\x30\ -\x20\x34\x30\x68\x32\x34\x20\x33\x32\x2e\x35\x63\x31\x2e\x34\x20\ -\x30\x20\x32\x2e\x38\x20\x30\x20\x34\x2e\x32\x2d\x2e\x31\x63\x31\ -\x2e\x31\x20\x2e\x31\x20\x32\x2e\x32\x20\x2e\x31\x20\x33\x2e\x33\ -\x20\x2e\x31\x68\x31\x36\x63\x32\x32\x2e\x31\x20\x30\x20\x34\x30\ -\x2d\x31\x37\x2e\x39\x20\x34\x30\x2d\x34\x30\x56\x34\x35\x35\x2e\ -\x38\x63\x2e\x33\x2d\x32\x2e\x36\x20\x2e\x35\x2d\x35\x2e\x33\x20\ -\x2e\x35\x2d\x38\x2e\x31\x6c\x2d\x2e\x37\x2d\x31\x36\x30\x2e\x32\ -\x68\x33\x32\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ -\x00\x00\x03\xa4\ -\x3c\ -\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ -\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ -\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ -\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ -\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ -\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ -\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ -\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ -\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ -\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ -\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ -\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ -\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ -\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ -\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x32\x37\x38\x2e\x36\x20\ -\x39\x2e\x34\x63\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x33\ -\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x6c\ -\x2d\x36\x34\x20\x36\x34\x63\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\ -\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\x20\x34\x35\ -\x2e\x33\x73\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\x35\x2e\ -\x33\x20\x30\x6c\x39\x2e\x34\x2d\x39\x2e\x34\x56\x32\x32\x34\x48\ -\x31\x30\x39\x2e\x33\x6c\x39\x2e\x34\x2d\x39\x2e\x34\x63\x31\x32\ -\x2e\x35\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\ -\x38\x20\x30\x2d\x34\x35\x2e\x33\x73\x2d\x33\x32\x2e\x38\x2d\x31\ -\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x6c\x2d\x36\x34\x20\x36\ -\x34\x63\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x31\x32\x2e\ -\x35\x20\x33\x32\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x6c\x36\x34\ -\x20\x36\x34\x63\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\ -\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\x30\x73\x31\ -\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x6c\ -\x2d\x39\x2e\x34\x2d\x39\x2e\x34\x48\x32\x32\x34\x56\x34\x30\x32\ -\x2e\x37\x6c\x2d\x39\x2e\x34\x2d\x39\x2e\x34\x63\x2d\x31\x32\x2e\ -\x35\x2d\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x2d\x31\x32\x2e\x35\ -\x2d\x34\x35\x2e\x33\x20\x30\x73\x2d\x31\x32\x2e\x35\x20\x33\x32\ -\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x6c\x36\x34\x20\x36\x34\x63\ -\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x31\ -\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\x30\x6c\x36\x34\x2d\x36\x34\ -\x63\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\ -\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x73\x2d\x33\x32\x2e\ -\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x6c\x2d\x39\ -\x2e\x34\x20\x39\x2e\x34\x56\x32\x38\x38\x48\x34\x30\x32\x2e\x37\ -\x6c\x2d\x39\x2e\x34\x20\x39\x2e\x34\x63\x2d\x31\x32\x2e\x35\x20\ -\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\ -\x20\x34\x35\x2e\x33\x73\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\ -\x34\x35\x2e\x33\x20\x30\x6c\x36\x34\x2d\x36\x34\x63\x31\x32\x2e\ -\x35\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\ -\x20\x30\x2d\x34\x35\x2e\x33\x6c\x2d\x36\x34\x2d\x36\x34\x63\x2d\ -\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x2d\x31\ -\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x73\x2d\x31\x32\x2e\x35\ -\x20\x33\x32\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x6c\x39\x2e\x34\ -\x20\x39\x2e\x34\x48\x32\x38\x38\x56\x31\x30\x39\x2e\x33\x6c\x39\ -\x2e\x34\x20\x39\x2e\x34\x63\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\ -\x20\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\ -\x30\x73\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\ -\x2e\x33\x6c\x2d\x36\x34\x2d\x36\x34\x7a\x22\x2f\x3e\x3c\x2f\x73\ -\x76\x67\x3e\ -\x00\x00\x02\xc2\ -\x3c\ -\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ -\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ -\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x34\x34\x38\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ -\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ -\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ -\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ -\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ -\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ -\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ -\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ -\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ -\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ -\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ -\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ -\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x34\x38\x20\x39\x36\x56\ -\x34\x31\x36\x63\x30\x20\x38\x2e\x38\x20\x37\x2e\x32\x20\x31\x36\ -\x20\x31\x36\x20\x31\x36\x48\x33\x38\x34\x63\x38\x2e\x38\x20\x30\ -\x20\x31\x36\x2d\x37\x2e\x32\x20\x31\x36\x2d\x31\x36\x56\x31\x37\ -\x30\x2e\x35\x63\x30\x2d\x34\x2e\x32\x2d\x31\x2e\x37\x2d\x38\x2e\ -\x33\x2d\x34\x2e\x37\x2d\x31\x31\x2e\x33\x6c\x33\x33\x2e\x39\x2d\ -\x33\x33\x2e\x39\x63\x31\x32\x20\x31\x32\x20\x31\x38\x2e\x37\x20\ -\x32\x38\x2e\x33\x20\x31\x38\x2e\x37\x20\x34\x35\x2e\x33\x56\x34\ -\x31\x36\x63\x30\x20\x33\x35\x2e\x33\x2d\x32\x38\x2e\x37\x20\x36\ -\x34\x2d\x36\x34\x20\x36\x34\x48\x36\x34\x63\x2d\x33\x35\x2e\x33\ -\x20\x30\x2d\x36\x34\x2d\x32\x38\x2e\x37\x2d\x36\x34\x2d\x36\x34\ -\x56\x39\x36\x43\x30\x20\x36\x30\x2e\x37\x20\x32\x38\x2e\x37\x20\ -\x33\x32\x20\x36\x34\x20\x33\x32\x48\x33\x30\x39\x2e\x35\x63\x31\ -\x37\x20\x30\x20\x33\x33\x2e\x33\x20\x36\x2e\x37\x20\x34\x35\x2e\ -\x33\x20\x31\x38\x2e\x37\x6c\x37\x34\x2e\x35\x20\x37\x34\x2e\x35\ -\x2d\x33\x33\x2e\x39\x20\x33\x33\x2e\x39\x4c\x33\x32\x30\x2e\x38\ -\x20\x38\x34\x2e\x37\x63\x2d\x2e\x33\x2d\x2e\x33\x2d\x2e\x35\x2d\ -\x2e\x35\x2d\x2e\x38\x2d\x2e\x38\x56\x31\x38\x34\x63\x30\x20\x31\ -\x33\x2e\x33\x2d\x31\x30\x2e\x37\x20\x32\x34\x2d\x32\x34\x20\x32\ -\x34\x48\x31\x30\x34\x63\x2d\x31\x33\x2e\x33\x20\x30\x2d\x32\x34\ -\x2d\x31\x30\x2e\x37\x2d\x32\x34\x2d\x32\x34\x56\x38\x30\x48\x36\ -\x34\x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x20\x37\x2e\x32\x2d\ -\x31\x36\x20\x31\x36\x7a\x6d\x38\x30\x2d\x31\x36\x76\x38\x30\x48\ -\x32\x37\x32\x56\x38\x30\x48\x31\x32\x38\x7a\x6d\x33\x32\x20\x32\ -\x34\x30\x61\x36\x34\x20\x36\x34\x20\x30\x20\x31\x20\x31\x20\x31\ -\x32\x38\x20\x30\x20\x36\x34\x20\x36\x34\x20\x30\x20\x31\x20\x31\ -\x20\x2d\x31\x32\x38\x20\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\ -\x3e\ -\x00\x00\x03\x78\ -\x3c\ -\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ -\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ -\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ -\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ -\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ -\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ -\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ -\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ -\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ -\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ -\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ -\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ -\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ -\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ -\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x35\x31\x32\x20\x39\x36\ -\x63\x30\x20\x35\x30\x2e\x32\x2d\x35\x39\x2e\x31\x20\x31\x32\x35\ -\x2e\x31\x2d\x38\x34\x2e\x36\x20\x31\x35\x35\x63\x2d\x33\x2e\x38\ -\x20\x34\x2e\x34\x2d\x39\x2e\x34\x20\x36\x2e\x31\x2d\x31\x34\x2e\ -\x35\x20\x35\x48\x33\x32\x30\x63\x2d\x31\x37\x2e\x37\x20\x30\x2d\ -\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x73\x31\x34\ -\x2e\x33\x20\x33\x32\x20\x33\x32\x20\x33\x32\x68\x39\x36\x63\x35\ -\x33\x20\x30\x20\x39\x36\x20\x34\x33\x20\x39\x36\x20\x39\x36\x73\ -\x2d\x34\x33\x20\x39\x36\x2d\x39\x36\x20\x39\x36\x48\x31\x33\x39\ -\x2e\x36\x63\x38\x2e\x37\x2d\x39\x2e\x39\x20\x31\x39\x2e\x33\x2d\ -\x32\x32\x2e\x36\x20\x33\x30\x2d\x33\x36\x2e\x38\x63\x36\x2e\x33\ -\x2d\x38\x2e\x34\x20\x31\x32\x2e\x38\x2d\x31\x37\x2e\x36\x20\x31\ -\x39\x2d\x32\x37\x2e\x32\x48\x34\x31\x36\x63\x31\x37\x2e\x37\x20\ -\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x73\ -\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x48\x33\ -\x32\x30\x63\x2d\x35\x33\x20\x30\x2d\x39\x36\x2d\x34\x33\x2d\x39\ -\x36\x2d\x39\x36\x73\x34\x33\x2d\x39\x36\x20\x39\x36\x2d\x39\x36\ -\x68\x33\x39\x2e\x38\x63\x2d\x32\x31\x2d\x33\x31\x2e\x35\x2d\x33\ -\x39\x2e\x38\x2d\x36\x37\x2e\x37\x2d\x33\x39\x2e\x38\x2d\x39\x36\ -\x63\x30\x2d\x35\x33\x20\x34\x33\x2d\x39\x36\x20\x39\x36\x2d\x39\ -\x36\x73\x39\x36\x20\x34\x33\x20\x39\x36\x20\x39\x36\x7a\x4d\x31\ -\x31\x37\x2e\x31\x20\x34\x38\x39\x2e\x31\x63\x2d\x33\x2e\x38\x20\ -\x34\x2e\x33\x2d\x37\x2e\x32\x20\x38\x2e\x31\x2d\x31\x30\x2e\x31\ -\x20\x31\x31\x2e\x33\x6c\x2d\x31\x2e\x38\x20\x32\x2d\x2e\x32\x2d\ -\x2e\x32\x63\x2d\x36\x20\x34\x2e\x36\x2d\x31\x34\x2e\x36\x20\x34\ -\x2d\x32\x30\x2d\x31\x2e\x38\x43\x35\x39\x2e\x38\x20\x34\x37\x33\ -\x20\x30\x20\x34\x30\x32\x2e\x35\x20\x30\x20\x33\x35\x32\x63\x30\ -\x2d\x35\x33\x20\x34\x33\x2d\x39\x36\x20\x39\x36\x2d\x39\x36\x73\ -\x39\x36\x20\x34\x33\x20\x39\x36\x20\x39\x36\x63\x30\x20\x33\x30\ -\x2d\x32\x31\x2e\x31\x20\x36\x37\x2d\x34\x33\x2e\x35\x20\x39\x37\ -\x2e\x39\x63\x2d\x31\x30\x2e\x37\x20\x31\x34\x2e\x37\x2d\x32\x31\ -\x2e\x37\x20\x32\x38\x2d\x33\x30\x2e\x38\x20\x33\x38\x2e\x35\x6c\ -\x2d\x2e\x36\x20\x2e\x37\x7a\x4d\x31\x32\x38\x20\x33\x35\x32\x61\ -\x33\x32\x20\x33\x32\x20\x30\x20\x31\x20\x30\x20\x2d\x36\x34\x20\ -\x30\x20\x33\x32\x20\x33\x32\x20\x30\x20\x31\x20\x30\x20\x36\x34\ -\x20\x30\x7a\x4d\x34\x31\x36\x20\x31\x32\x38\x61\x33\x32\x20\x33\ -\x32\x20\x30\x20\x31\x20\x30\x20\x30\x2d\x36\x34\x20\x33\x32\x20\ -\x33\x32\x20\x30\x20\x31\x20\x30\x20\x30\x20\x36\x34\x7a\x22\x2f\ -\x3e\x3c\x2f\x73\x76\x67\x3e\ -\x00\x00\x02\x7d\ -\x3c\ -\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ -\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ -\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ -\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ -\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ -\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ -\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ -\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ -\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ -\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ -\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ -\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ -\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ -\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ -\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x36\x34\x20\x36\x34\x63\ -\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\ -\x32\x2d\x33\x32\x53\x30\x20\x34\x36\x2e\x33\x20\x30\x20\x36\x34\ -\x56\x34\x30\x30\x63\x30\x20\x34\x34\x2e\x32\x20\x33\x35\x2e\x38\ -\x20\x38\x30\x20\x38\x30\x20\x38\x30\x48\x34\x38\x30\x63\x31\x37\ -\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\ -\x33\x32\x73\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\ -\x32\x48\x38\x30\x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x2d\x37\ -\x2e\x32\x2d\x31\x36\x2d\x31\x36\x56\x36\x34\x7a\x6d\x34\x30\x36\ -\x2e\x36\x20\x38\x36\x2e\x36\x63\x31\x32\x2e\x35\x2d\x31\x32\x2e\ -\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\ -\x2e\x33\x73\x2d\x33\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\ -\x2e\x33\x20\x30\x4c\x33\x32\x30\x20\x32\x31\x30\x2e\x37\x6c\x2d\ -\x35\x37\x2e\x34\x2d\x35\x37\x2e\x34\x63\x2d\x31\x32\x2e\x35\x2d\ -\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\ -\x35\x2e\x33\x20\x30\x6c\x2d\x31\x31\x32\x20\x31\x31\x32\x63\x2d\ -\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x33\ -\x32\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x73\x33\x32\x2e\x38\x20\ -\x31\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\x30\x4c\x32\x34\x30\x20\ -\x32\x32\x31\x2e\x33\x6c\x35\x37\x2e\x34\x20\x35\x37\x2e\x34\x63\ -\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x31\ -\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\x30\x6c\x31\x32\x38\x2d\x31\ -\x32\x38\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ \x00\x00\x03\x5b\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ @@ -276,7 +16,7 @@ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ \x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -325,14 +65,14 @@ \x32\x20\x33\x32\x20\x30\x20\x31\x20\x30\x20\x30\x2d\x36\x34\x20\ \x33\x32\x20\x33\x32\x20\x30\x20\x31\x20\x30\x20\x30\x20\x36\x34\ \x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ -\x00\x00\x02\xc5\ +\x00\x00\x02\xea\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x30\x20\x30\x20\x34\x34\x38\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -344,34 +84,36 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x31\x37\x37\x2e\x39\x20\ -\x34\x39\x34\x2e\x31\x63\x2d\x31\x38\x2e\x37\x20\x31\x38\x2e\x37\ -\x2d\x34\x39\x2e\x31\x20\x31\x38\x2e\x37\x2d\x36\x37\x2e\x39\x20\ -\x30\x4c\x31\x37\x2e\x39\x20\x34\x30\x31\x2e\x39\x63\x2d\x31\x38\ -\x2e\x37\x2d\x31\x38\x2e\x37\x2d\x31\x38\x2e\x37\x2d\x34\x39\x2e\ -\x31\x20\x30\x2d\x36\x37\x2e\x39\x6c\x35\x30\x2e\x37\x2d\x35\x30\ -\x2e\x37\x20\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\ -\x20\x31\x36\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\ -\x73\x36\x2e\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\ -\x6c\x2d\x34\x38\x2d\x34\x38\x20\x34\x31\x2e\x34\x2d\x34\x31\x2e\ -\x34\x20\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\x20\ -\x31\x36\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\x73\ -\x36\x2e\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\x6c\ -\x2d\x34\x38\x2d\x34\x38\x20\x34\x31\x2e\x34\x2d\x34\x31\x2e\x34\ -\x20\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\x20\x31\ -\x36\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\x73\x36\ -\x2e\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\x6c\x2d\ -\x34\x38\x2d\x34\x38\x20\x34\x31\x2e\x34\x2d\x34\x31\x2e\x34\x20\ -\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\x20\x31\x36\ -\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\x73\x36\x2e\ -\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\x6c\x2d\x34\ -\x38\x2d\x34\x38\x20\x35\x30\x2e\x37\x2d\x35\x30\x2e\x37\x63\x31\ -\x38\x2e\x37\x2d\x31\x38\x2e\x37\x20\x34\x39\x2e\x31\x2d\x31\x38\ -\x2e\x37\x20\x36\x37\x2e\x39\x20\x30\x6c\x39\x32\x2e\x31\x20\x39\ -\x32\x2e\x31\x63\x31\x38\x2e\x37\x20\x31\x38\x2e\x37\x20\x31\x38\ -\x2e\x37\x20\x34\x39\x2e\x31\x20\x30\x20\x36\x37\x2e\x39\x4c\x31\ -\x37\x37\x2e\x39\x20\x34\x39\x34\x2e\x31\x7a\x22\x2f\x3e\x3c\x2f\ -\x73\x76\x67\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x34\x38\x20\x39\x36\x6c\ +\x30\x20\x33\x32\x30\x63\x30\x20\x38\x2e\x38\x20\x37\x2e\x32\x20\ +\x31\x36\x20\x31\x36\x20\x31\x36\x6c\x33\x32\x30\x20\x30\x63\x38\ +\x2e\x38\x20\x30\x20\x31\x36\x2d\x37\x2e\x32\x20\x31\x36\x2d\x31\ +\x36\x6c\x30\x2d\x32\x34\x35\x2e\x35\x63\x30\x2d\x34\x2e\x32\x2d\ +\x31\x2e\x37\x2d\x38\x2e\x33\x2d\x34\x2e\x37\x2d\x31\x31\x2e\x33\ +\x6c\x33\x33\x2e\x39\x2d\x33\x33\x2e\x39\x63\x31\x32\x20\x31\x32\ +\x20\x31\x38\x2e\x37\x20\x32\x38\x2e\x33\x20\x31\x38\x2e\x37\x20\ +\x34\x35\x2e\x33\x4c\x34\x34\x38\x20\x34\x31\x36\x63\x30\x20\x33\ +\x35\x2e\x33\x2d\x32\x38\x2e\x37\x20\x36\x34\x2d\x36\x34\x20\x36\ +\x34\x4c\x36\x34\x20\x34\x38\x30\x63\x2d\x33\x35\x2e\x33\x20\x30\ +\x2d\x36\x34\x2d\x32\x38\x2e\x37\x2d\x36\x34\x2d\x36\x34\x4c\x30\ +\x20\x39\x36\x43\x30\x20\x36\x30\x2e\x37\x20\x32\x38\x2e\x37\x20\ +\x33\x32\x20\x36\x34\x20\x33\x32\x6c\x32\x34\x35\x2e\x35\x20\x30\ +\x63\x31\x37\x20\x30\x20\x33\x33\x2e\x33\x20\x36\x2e\x37\x20\x34\ +\x35\x2e\x33\x20\x31\x38\x2e\x37\x6c\x37\x34\x2e\x35\x20\x37\x34\ +\x2e\x35\x2d\x33\x33\x2e\x39\x20\x33\x33\x2e\x39\x4c\x33\x32\x30\ +\x2e\x38\x20\x38\x34\x2e\x37\x63\x2d\x2e\x33\x2d\x2e\x33\x2d\x2e\ +\x35\x2d\x2e\x35\x2d\x2e\x38\x2d\x2e\x38\x4c\x33\x32\x30\x20\x31\ +\x38\x34\x63\x30\x20\x31\x33\x2e\x33\x2d\x31\x30\x2e\x37\x20\x32\ +\x34\x2d\x32\x34\x20\x32\x34\x6c\x2d\x31\x39\x32\x20\x30\x63\x2d\ +\x31\x33\x2e\x33\x20\x30\x2d\x32\x34\x2d\x31\x30\x2e\x37\x2d\x32\ +\x34\x2d\x32\x34\x4c\x38\x30\x20\x38\x30\x20\x36\x34\x20\x38\x30\ +\x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x20\x37\x2e\x32\x2d\x31\ +\x36\x20\x31\x36\x7a\x6d\x38\x30\x2d\x31\x36\x6c\x30\x20\x38\x30\ +\x20\x31\x34\x34\x20\x30\x20\x30\x2d\x38\x30\x4c\x31\x32\x38\x20\ +\x38\x30\x7a\x6d\x33\x32\x20\x32\x34\x30\x61\x36\x34\x20\x36\x34\ +\x20\x30\x20\x31\x20\x31\x20\x31\x32\x38\x20\x30\x20\x36\x34\x20\ +\x36\x34\x20\x30\x20\x31\x20\x31\x20\x2d\x31\x32\x38\x20\x30\x7a\ +\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ \x00\x00\x03\x42\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ @@ -379,7 +121,7 @@ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ \x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -427,14 +169,14 @@ \x30\x63\x2d\x33\x35\x2e\x33\x20\x30\x2d\x36\x34\x2d\x32\x38\x2e\ \x37\x2d\x36\x34\x2d\x36\x34\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\ \x3e\ -\x00\x00\x03\x98\ +\x00\x00\x04\x30\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x37\x36\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -446,55 +188,64 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x32\x36\x34\x2e\x35\x20\ -\x35\x2e\x32\x63\x31\x34\x2e\x39\x2d\x36\x2e\x39\x20\x33\x32\x2e\ -\x31\x2d\x36\x2e\x39\x20\x34\x37\x20\x30\x6c\x32\x31\x38\x2e\x36\ -\x20\x31\x30\x31\x63\x38\x2e\x35\x20\x33\x2e\x39\x20\x31\x33\x2e\ -\x39\x20\x31\x32\x2e\x34\x20\x31\x33\x2e\x39\x20\x32\x31\x2e\x38\ -\x73\x2d\x35\x2e\x34\x20\x31\x37\x2e\x39\x2d\x31\x33\x2e\x39\x20\ -\x32\x31\x2e\x38\x6c\x2d\x32\x31\x38\x2e\x36\x20\x31\x30\x31\x63\ -\x2d\x31\x34\x2e\x39\x20\x36\x2e\x39\x2d\x33\x32\x2e\x31\x20\x36\ -\x2e\x39\x2d\x34\x37\x20\x30\x4c\x34\x35\x2e\x39\x20\x31\x34\x39\ -\x2e\x38\x43\x33\x37\x2e\x34\x20\x31\x34\x35\x2e\x38\x20\x33\x32\ -\x20\x31\x33\x37\x2e\x33\x20\x33\x32\x20\x31\x32\x38\x73\x35\x2e\ -\x34\x2d\x31\x37\x2e\x39\x20\x31\x33\x2e\x39\x2d\x32\x31\x2e\x38\ -\x4c\x32\x36\x34\x2e\x35\x20\x35\x2e\x32\x7a\x4d\x34\x37\x36\x2e\ -\x39\x20\x32\x30\x39\x2e\x36\x6c\x35\x33\x2e\x32\x20\x32\x34\x2e\ -\x36\x63\x38\x2e\x35\x20\x33\x2e\x39\x20\x31\x33\x2e\x39\x20\x31\ -\x32\x2e\x34\x20\x31\x33\x2e\x39\x20\x32\x31\x2e\x38\x73\x2d\x35\ -\x2e\x34\x20\x31\x37\x2e\x39\x2d\x31\x33\x2e\x39\x20\x32\x31\x2e\ -\x38\x6c\x2d\x32\x31\x38\x2e\x36\x20\x31\x30\x31\x63\x2d\x31\x34\ -\x2e\x39\x20\x36\x2e\x39\x2d\x33\x32\x2e\x31\x20\x36\x2e\x39\x2d\ -\x34\x37\x20\x30\x4c\x34\x35\x2e\x39\x20\x32\x37\x37\x2e\x38\x43\ -\x33\x37\x2e\x34\x20\x32\x37\x33\x2e\x38\x20\x33\x32\x20\x32\x36\ -\x35\x2e\x33\x20\x33\x32\x20\x32\x35\x36\x73\x35\x2e\x34\x2d\x31\ -\x37\x2e\x39\x20\x31\x33\x2e\x39\x2d\x32\x31\x2e\x38\x6c\x35\x33\ -\x2e\x32\x2d\x32\x34\x2e\x36\x20\x31\x35\x32\x20\x37\x30\x2e\x32\ -\x63\x32\x33\x2e\x34\x20\x31\x30\x2e\x38\x20\x35\x30\x2e\x34\x20\ -\x31\x30\x2e\x38\x20\x37\x33\x2e\x38\x20\x30\x6c\x31\x35\x32\x2d\ -\x37\x30\x2e\x32\x7a\x6d\x2d\x31\x35\x32\x20\x31\x39\x38\x2e\x32\ -\x6c\x31\x35\x32\x2d\x37\x30\x2e\x32\x20\x35\x33\x2e\x32\x20\x32\ -\x34\x2e\x36\x63\x38\x2e\x35\x20\x33\x2e\x39\x20\x31\x33\x2e\x39\ -\x20\x31\x32\x2e\x34\x20\x31\x33\x2e\x39\x20\x32\x31\x2e\x38\x73\ -\x2d\x35\x2e\x34\x20\x31\x37\x2e\x39\x2d\x31\x33\x2e\x39\x20\x32\ -\x31\x2e\x38\x6c\x2d\x32\x31\x38\x2e\x36\x20\x31\x30\x31\x63\x2d\ -\x31\x34\x2e\x39\x20\x36\x2e\x39\x2d\x33\x32\x2e\x31\x20\x36\x2e\ -\x39\x2d\x34\x37\x20\x30\x4c\x34\x35\x2e\x39\x20\x34\x30\x35\x2e\ -\x38\x43\x33\x37\x2e\x34\x20\x34\x30\x31\x2e\x38\x20\x33\x32\x20\ -\x33\x39\x33\x2e\x33\x20\x33\x32\x20\x33\x38\x34\x73\x35\x2e\x34\ -\x2d\x31\x37\x2e\x39\x20\x31\x33\x2e\x39\x2d\x32\x31\x2e\x38\x6c\ -\x35\x33\x2e\x32\x2d\x32\x34\x2e\x36\x20\x31\x35\x32\x20\x37\x30\ -\x2e\x32\x63\x32\x33\x2e\x34\x20\x31\x30\x2e\x38\x20\x35\x30\x2e\ -\x34\x20\x31\x30\x2e\x38\x20\x37\x33\x2e\x38\x20\x30\x7a\x22\x2f\ -\x3e\x3c\x2f\x73\x76\x67\x3e\ -\x00\x00\x02\x42\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x37\x38\x2e\x36\x20\x35\ +\x43\x36\x39\x2e\x31\x2d\x32\x2e\x34\x20\x35\x35\x2e\x36\x2d\x31\ +\x2e\x35\x20\x34\x37\x20\x37\x4c\x37\x20\x34\x37\x63\x2d\x38\x2e\ +\x35\x20\x38\x2e\x35\x2d\x39\x2e\x34\x20\x32\x32\x2d\x32\x2e\x31\ +\x20\x33\x31\x2e\x36\x6c\x38\x30\x20\x31\x30\x34\x63\x34\x2e\x35\ +\x20\x35\x2e\x39\x20\x31\x31\x2e\x36\x20\x39\x2e\x34\x20\x31\x39\ +\x20\x39\x2e\x34\x6c\x35\x34\x2e\x31\x20\x30\x20\x31\x30\x39\x20\ +\x31\x30\x39\x63\x2d\x31\x34\x2e\x37\x20\x32\x39\x2d\x31\x30\x20\ +\x36\x35\x2e\x34\x20\x31\x34\x2e\x33\x20\x38\x39\x2e\x36\x6c\x31\ +\x31\x32\x20\x31\x31\x32\x63\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\ +\x20\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\ +\x30\x6c\x36\x34\x2d\x36\x34\x63\x31\x32\x2e\x35\x2d\x31\x32\x2e\ +\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\ +\x2e\x33\x6c\x2d\x31\x31\x32\x2d\x31\x31\x32\x63\x2d\x32\x34\x2e\ +\x32\x2d\x32\x34\x2e\x32\x2d\x36\x30\x2e\x36\x2d\x32\x39\x2d\x38\ +\x39\x2e\x36\x2d\x31\x34\x2e\x33\x6c\x2d\x31\x30\x39\x2d\x31\x30\ +\x39\x20\x30\x2d\x35\x34\x2e\x31\x63\x30\x2d\x37\x2e\x35\x2d\x33\ +\x2e\x35\x2d\x31\x34\x2e\x35\x2d\x39\x2e\x34\x2d\x31\x39\x4c\x37\ +\x38\x2e\x36\x20\x35\x7a\x4d\x31\x39\x2e\x39\x20\x33\x39\x36\x2e\ +\x31\x43\x37\x2e\x32\x20\x34\x30\x38\x2e\x38\x20\x30\x20\x34\x32\ +\x36\x2e\x31\x20\x30\x20\x34\x34\x34\x2e\x31\x43\x30\x20\x34\x38\ +\x31\x2e\x36\x20\x33\x30\x2e\x34\x20\x35\x31\x32\x20\x36\x37\x2e\ +\x39\x20\x35\x31\x32\x63\x31\x38\x20\x30\x20\x33\x35\x2e\x33\x2d\ +\x37\x2e\x32\x20\x34\x38\x2d\x31\x39\x2e\x39\x4c\x32\x33\x33\x2e\ +\x37\x20\x33\x37\x34\x2e\x33\x63\x2d\x37\x2e\x38\x2d\x32\x30\x2e\ +\x39\x2d\x39\x2d\x34\x33\x2e\x36\x2d\x33\x2e\x36\x2d\x36\x35\x2e\ +\x31\x6c\x2d\x36\x31\x2e\x37\x2d\x36\x31\x2e\x37\x4c\x31\x39\x2e\ +\x39\x20\x33\x39\x36\x2e\x31\x7a\x4d\x35\x31\x32\x20\x31\x34\x34\ +\x63\x30\x2d\x31\x30\x2e\x35\x2d\x31\x2e\x31\x2d\x32\x30\x2e\x37\ +\x2d\x33\x2e\x32\x2d\x33\x30\x2e\x35\x63\x2d\x32\x2e\x34\x2d\x31\ +\x31\x2e\x32\x2d\x31\x36\x2e\x31\x2d\x31\x34\x2e\x31\x2d\x32\x34\ +\x2e\x32\x2d\x36\x6c\x2d\x36\x33\x2e\x39\x20\x36\x33\x2e\x39\x63\ +\x2d\x33\x20\x33\x2d\x37\x2e\x31\x20\x34\x2e\x37\x2d\x31\x31\x2e\ +\x33\x20\x34\x2e\x37\x4c\x33\x35\x32\x20\x31\x37\x36\x63\x2d\x38\ +\x2e\x38\x20\x30\x2d\x31\x36\x2d\x37\x2e\x32\x2d\x31\x36\x2d\x31\ +\x36\x6c\x30\x2d\x35\x37\x2e\x34\x63\x30\x2d\x34\x2e\x32\x20\x31\ +\x2e\x37\x2d\x38\x2e\x33\x20\x34\x2e\x37\x2d\x31\x31\x2e\x33\x6c\ +\x36\x33\x2e\x39\x2d\x36\x33\x2e\x39\x63\x38\x2e\x31\x2d\x38\x2e\ +\x31\x20\x35\x2e\x32\x2d\x32\x31\x2e\x38\x2d\x36\x2d\x32\x34\x2e\ +\x32\x43\x33\x38\x38\x2e\x37\x20\x31\x2e\x31\x20\x33\x37\x38\x2e\ +\x35\x20\x30\x20\x33\x36\x38\x20\x30\x43\x32\x38\x38\x2e\x35\x20\ +\x30\x20\x32\x32\x34\x20\x36\x34\x2e\x35\x20\x32\x32\x34\x20\x31\ +\x34\x34\x6c\x30\x20\x2e\x38\x20\x38\x35\x2e\x33\x20\x38\x35\x2e\ +\x33\x63\x33\x36\x2d\x39\x2e\x31\x20\x37\x35\x2e\x38\x20\x2e\x35\ +\x20\x31\x30\x34\x20\x32\x38\x2e\x37\x4c\x34\x32\x39\x20\x32\x37\ +\x34\x2e\x35\x63\x34\x39\x2d\x32\x33\x20\x38\x33\x2d\x37\x32\x2e\ +\x38\x20\x38\x33\x2d\x31\x33\x30\x2e\x35\x7a\x4d\x35\x36\x20\x34\ +\x33\x32\x61\x32\x34\x20\x32\x34\x20\x30\x20\x31\x20\x31\x20\x34\ +\x38\x20\x30\x20\x32\x34\x20\x32\x34\x20\x30\x20\x31\x20\x31\x20\ +\x2d\x34\x38\x20\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x00\x00\x03\x89\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ \x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -506,34 +257,54 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x36\x34\x20\x33\x32\x43\ -\x32\x38\x2e\x37\x20\x33\x32\x20\x30\x20\x36\x30\x2e\x37\x20\x30\ -\x20\x39\x36\x56\x34\x31\x36\x63\x30\x20\x33\x35\x2e\x33\x20\x32\ -\x38\x2e\x37\x20\x36\x34\x20\x36\x34\x20\x36\x34\x48\x34\x34\x38\ -\x63\x33\x35\x2e\x33\x20\x30\x20\x36\x34\x2d\x32\x38\x2e\x37\x20\ -\x36\x34\x2d\x36\x34\x56\x39\x36\x63\x30\x2d\x33\x35\x2e\x33\x2d\ -\x32\x38\x2e\x37\x2d\x36\x34\x2d\x36\x34\x2d\x36\x34\x48\x36\x34\ -\x7a\x6d\x38\x38\x20\x36\x34\x76\x36\x34\x48\x36\x34\x56\x39\x36\ -\x68\x38\x38\x7a\x6d\x35\x36\x20\x30\x68\x38\x38\x76\x36\x34\x48\ -\x32\x30\x38\x56\x39\x36\x7a\x6d\x32\x34\x30\x20\x30\x76\x36\x34\ -\x48\x33\x36\x30\x56\x39\x36\x68\x38\x38\x7a\x4d\x36\x34\x20\x32\ -\x32\x34\x68\x38\x38\x76\x36\x34\x48\x36\x34\x56\x32\x32\x34\x7a\ -\x6d\x32\x33\x32\x20\x30\x76\x36\x34\x48\x32\x30\x38\x56\x32\x32\ -\x34\x68\x38\x38\x7a\x6d\x36\x34\x20\x30\x68\x38\x38\x76\x36\x34\ -\x48\x33\x36\x30\x56\x32\x32\x34\x7a\x4d\x31\x35\x32\x20\x33\x35\ -\x32\x76\x36\x34\x48\x36\x34\x56\x33\x35\x32\x68\x38\x38\x7a\x6d\ -\x35\x36\x20\x30\x68\x38\x38\x76\x36\x34\x48\x32\x30\x38\x56\x33\ -\x35\x32\x7a\x6d\x32\x34\x30\x20\x30\x76\x36\x34\x48\x33\x36\x30\ -\x56\x33\x35\x32\x68\x38\x38\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\ -\x3e\ -\x00\x00\x03\x65\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x35\x31\x32\x20\x39\x36\ +\x63\x30\x20\x35\x30\x2e\x32\x2d\x35\x39\x2e\x31\x20\x31\x32\x35\ +\x2e\x31\x2d\x38\x34\x2e\x36\x20\x31\x35\x35\x63\x2d\x33\x2e\x38\ +\x20\x34\x2e\x34\x2d\x39\x2e\x34\x20\x36\x2e\x31\x2d\x31\x34\x2e\ +\x35\x20\x35\x4c\x33\x32\x30\x20\x32\x35\x36\x63\x2d\x31\x37\x2e\ +\x37\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\ +\x32\x73\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\x20\x33\x32\x6c\ +\x39\x36\x20\x30\x63\x35\x33\x20\x30\x20\x39\x36\x20\x34\x33\x20\ +\x39\x36\x20\x39\x36\x73\x2d\x34\x33\x20\x39\x36\x2d\x39\x36\x20\ +\x39\x36\x6c\x2d\x32\x37\x36\x2e\x34\x20\x30\x63\x38\x2e\x37\x2d\ +\x39\x2e\x39\x20\x31\x39\x2e\x33\x2d\x32\x32\x2e\x36\x20\x33\x30\ +\x2d\x33\x36\x2e\x38\x63\x36\x2e\x33\x2d\x38\x2e\x34\x20\x31\x32\ +\x2e\x38\x2d\x31\x37\x2e\x36\x20\x31\x39\x2d\x32\x37\x2e\x32\x4c\ +\x34\x31\x36\x20\x34\x34\x38\x63\x31\x37\x2e\x37\x20\x30\x20\x33\ +\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x73\x2d\x31\x34\ +\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x6c\x2d\x39\x36\x20\ +\x30\x63\x2d\x35\x33\x20\x30\x2d\x39\x36\x2d\x34\x33\x2d\x39\x36\ +\x2d\x39\x36\x73\x34\x33\x2d\x39\x36\x20\x39\x36\x2d\x39\x36\x6c\ +\x33\x39\x2e\x38\x20\x30\x63\x2d\x32\x31\x2d\x33\x31\x2e\x35\x2d\ +\x33\x39\x2e\x38\x2d\x36\x37\x2e\x37\x2d\x33\x39\x2e\x38\x2d\x39\ +\x36\x63\x30\x2d\x35\x33\x20\x34\x33\x2d\x39\x36\x20\x39\x36\x2d\ +\x39\x36\x73\x39\x36\x20\x34\x33\x20\x39\x36\x20\x39\x36\x7a\x4d\ +\x31\x31\x37\x2e\x31\x20\x34\x38\x39\x2e\x31\x63\x2d\x33\x2e\x38\ +\x20\x34\x2e\x33\x2d\x37\x2e\x32\x20\x38\x2e\x31\x2d\x31\x30\x2e\ +\x31\x20\x31\x31\x2e\x33\x6c\x2d\x31\x2e\x38\x20\x32\x2d\x2e\x32\ +\x2d\x2e\x32\x63\x2d\x36\x20\x34\x2e\x36\x2d\x31\x34\x2e\x36\x20\ +\x34\x2d\x32\x30\x2d\x31\x2e\x38\x43\x35\x39\x2e\x38\x20\x34\x37\ +\x33\x20\x30\x20\x34\x30\x32\x2e\x35\x20\x30\x20\x33\x35\x32\x63\ +\x30\x2d\x35\x33\x20\x34\x33\x2d\x39\x36\x20\x39\x36\x2d\x39\x36\ +\x73\x39\x36\x20\x34\x33\x20\x39\x36\x20\x39\x36\x63\x30\x20\x33\ +\x30\x2d\x32\x31\x2e\x31\x20\x36\x37\x2d\x34\x33\x2e\x35\x20\x39\ +\x37\x2e\x39\x63\x2d\x31\x30\x2e\x37\x20\x31\x34\x2e\x37\x2d\x32\ +\x31\x2e\x37\x20\x32\x38\x2d\x33\x30\x2e\x38\x20\x33\x38\x2e\x35\ +\x6c\x2d\x2e\x36\x20\x2e\x37\x7a\x4d\x31\x32\x38\x20\x33\x35\x32\ +\x61\x33\x32\x20\x33\x32\x20\x30\x20\x31\x20\x30\x20\x2d\x36\x34\ +\x20\x30\x20\x33\x32\x20\x33\x32\x20\x30\x20\x31\x20\x30\x20\x36\ +\x34\x20\x30\x7a\x4d\x34\x31\x36\x20\x31\x32\x38\x61\x33\x32\x20\ +\x33\x32\x20\x30\x20\x31\x20\x30\x20\x30\x2d\x36\x34\x20\x33\x32\ +\x20\x33\x32\x20\x30\x20\x31\x20\x30\x20\x30\x20\x36\x34\x7a\x22\ +\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x00\x00\x04\xa1\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x30\x20\x30\x20\x35\x37\x36\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -545,52 +316,72 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x34\x30\x20\x34\x38\x43\ -\x32\x36\x2e\x37\x20\x34\x38\x20\x31\x36\x20\x35\x38\x2e\x37\x20\ -\x31\x36\x20\x37\x32\x76\x34\x38\x63\x30\x20\x31\x33\x2e\x33\x20\ -\x31\x30\x2e\x37\x20\x32\x34\x20\x32\x34\x20\x32\x34\x48\x38\x38\ -\x63\x31\x33\x2e\x33\x20\x30\x20\x32\x34\x2d\x31\x30\x2e\x37\x20\ -\x32\x34\x2d\x32\x34\x56\x37\x32\x63\x30\x2d\x31\x33\x2e\x33\x2d\ -\x31\x30\x2e\x37\x2d\x32\x34\x2d\x32\x34\x2d\x32\x34\x48\x34\x30\ -\x7a\x4d\x31\x39\x32\x20\x36\x34\x63\x2d\x31\x37\x2e\x37\x20\x30\ -\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x73\x31\ -\x34\x2e\x33\x20\x33\x32\x20\x33\x32\x20\x33\x32\x48\x34\x38\x30\ -\x63\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\ -\x33\x32\x2d\x33\x32\x73\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\ -\x32\x2d\x33\x32\x48\x31\x39\x32\x7a\x6d\x30\x20\x31\x36\x30\x63\ -\x2d\x31\x37\x2e\x37\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\ -\x33\x32\x20\x33\x32\x73\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\ -\x20\x33\x32\x48\x34\x38\x30\x63\x31\x37\x2e\x37\x20\x30\x20\x33\ -\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x73\x2d\x31\x34\ -\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x48\x31\x39\x32\x7a\ -\x6d\x30\x20\x31\x36\x30\x63\x2d\x31\x37\x2e\x37\x20\x30\x2d\x33\ -\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x73\x31\x34\x2e\ -\x33\x20\x33\x32\x20\x33\x32\x20\x33\x32\x48\x34\x38\x30\x63\x31\ -\x37\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\ -\x2d\x33\x32\x73\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\ -\x33\x32\x48\x31\x39\x32\x7a\x4d\x31\x36\x20\x32\x33\x32\x76\x34\ -\x38\x63\x30\x20\x31\x33\x2e\x33\x20\x31\x30\x2e\x37\x20\x32\x34\ -\x20\x32\x34\x20\x32\x34\x48\x38\x38\x63\x31\x33\x2e\x33\x20\x30\ -\x20\x32\x34\x2d\x31\x30\x2e\x37\x20\x32\x34\x2d\x32\x34\x56\x32\ -\x33\x32\x63\x30\x2d\x31\x33\x2e\x33\x2d\x31\x30\x2e\x37\x2d\x32\ -\x34\x2d\x32\x34\x2d\x32\x34\x48\x34\x30\x63\x2d\x31\x33\x2e\x33\ -\x20\x30\x2d\x32\x34\x20\x31\x30\x2e\x37\x2d\x32\x34\x20\x32\x34\ -\x7a\x4d\x34\x30\x20\x33\x36\x38\x63\x2d\x31\x33\x2e\x33\x20\x30\ -\x2d\x32\x34\x20\x31\x30\x2e\x37\x2d\x32\x34\x20\x32\x34\x76\x34\ -\x38\x63\x30\x20\x31\x33\x2e\x33\x20\x31\x30\x2e\x37\x20\x32\x34\ -\x20\x32\x34\x20\x32\x34\x48\x38\x38\x63\x31\x33\x2e\x33\x20\x30\ -\x20\x32\x34\x2d\x31\x30\x2e\x37\x20\x32\x34\x2d\x32\x34\x56\x33\ -\x39\x32\x63\x30\x2d\x31\x33\x2e\x33\x2d\x31\x30\x2e\x37\x2d\x32\ -\x34\x2d\x32\x34\x2d\x32\x34\x48\x34\x30\x7a\x22\x2f\x3e\x3c\x2f\ -\x73\x76\x67\x3e\ -\x00\x00\x03\x13\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x32\x33\x34\x2e\x37\x20\ +\x34\x32\x2e\x37\x4c\x31\x39\x37\x20\x35\x36\x2e\x38\x63\x2d\x33\ +\x20\x31\x2e\x31\x2d\x35\x20\x34\x2d\x35\x20\x37\x2e\x32\x73\x32\ +\x20\x36\x2e\x31\x20\x35\x20\x37\x2e\x32\x6c\x33\x37\x2e\x37\x20\ +\x31\x34\x2e\x31\x4c\x32\x34\x38\x2e\x38\x20\x31\x32\x33\x63\x31\ +\x2e\x31\x20\x33\x20\x34\x20\x35\x20\x37\x2e\x32\x20\x35\x73\x36\ +\x2e\x31\x2d\x32\x20\x37\x2e\x32\x2d\x35\x6c\x31\x34\x2e\x31\x2d\ +\x33\x37\x2e\x37\x4c\x33\x31\x35\x20\x37\x31\x2e\x32\x63\x33\x2d\ +\x31\x2e\x31\x20\x35\x2d\x34\x20\x35\x2d\x37\x2e\x32\x73\x2d\x32\ +\x2d\x36\x2e\x31\x2d\x35\x2d\x37\x2e\x32\x4c\x32\x37\x37\x2e\x33\ +\x20\x34\x32\x2e\x37\x20\x32\x36\x33\x2e\x32\x20\x35\x63\x2d\x31\ +\x2e\x31\x2d\x33\x2d\x34\x2d\x35\x2d\x37\x2e\x32\x2d\x35\x73\x2d\ +\x36\x2e\x31\x20\x32\x2d\x37\x2e\x32\x20\x35\x4c\x32\x33\x34\x2e\ +\x37\x20\x34\x32\x2e\x37\x7a\x4d\x34\x36\x2e\x31\x20\x33\x39\x35\ +\x2e\x34\x63\x2d\x31\x38\x2e\x37\x20\x31\x38\x2e\x37\x2d\x31\x38\ +\x2e\x37\x20\x34\x39\x2e\x31\x20\x30\x20\x36\x37\x2e\x39\x6c\x33\ +\x34\x2e\x36\x20\x33\x34\x2e\x36\x63\x31\x38\x2e\x37\x20\x31\x38\ +\x2e\x37\x20\x34\x39\x2e\x31\x20\x31\x38\x2e\x37\x20\x36\x37\x2e\ +\x39\x20\x30\x4c\x35\x32\x39\x2e\x39\x20\x31\x31\x36\x2e\x35\x63\ +\x31\x38\x2e\x37\x2d\x31\x38\x2e\x37\x20\x31\x38\x2e\x37\x2d\x34\ +\x39\x2e\x31\x20\x30\x2d\x36\x37\x2e\x39\x4c\x34\x39\x35\x2e\x33\ +\x20\x31\x34\x2e\x31\x63\x2d\x31\x38\x2e\x37\x2d\x31\x38\x2e\x37\ +\x2d\x34\x39\x2e\x31\x2d\x31\x38\x2e\x37\x2d\x36\x37\x2e\x39\x20\ +\x30\x4c\x34\x36\x2e\x31\x20\x33\x39\x35\x2e\x34\x7a\x4d\x34\x38\ +\x34\x2e\x36\x20\x38\x32\x2e\x36\x6c\x2d\x31\x30\x35\x20\x31\x30\ +\x35\x2d\x32\x33\x2e\x33\x2d\x32\x33\x2e\x33\x20\x31\x30\x35\x2d\ +\x31\x30\x35\x20\x32\x33\x2e\x33\x20\x32\x33\x2e\x33\x7a\x4d\x37\ +\x2e\x35\x20\x31\x31\x37\x2e\x32\x43\x33\x20\x31\x31\x38\x2e\x39\ +\x20\x30\x20\x31\x32\x33\x2e\x32\x20\x30\x20\x31\x32\x38\x73\x33\ +\x20\x39\x2e\x31\x20\x37\x2e\x35\x20\x31\x30\x2e\x38\x4c\x36\x34\ +\x20\x31\x36\x30\x6c\x32\x31\x2e\x32\x20\x35\x36\x2e\x35\x63\x31\ +\x2e\x37\x20\x34\x2e\x35\x20\x36\x20\x37\x2e\x35\x20\x31\x30\x2e\ +\x38\x20\x37\x2e\x35\x73\x39\x2e\x31\x2d\x33\x20\x31\x30\x2e\x38\ +\x2d\x37\x2e\x35\x4c\x31\x32\x38\x20\x31\x36\x30\x6c\x35\x36\x2e\ +\x35\x2d\x32\x31\x2e\x32\x63\x34\x2e\x35\x2d\x31\x2e\x37\x20\x37\ +\x2e\x35\x2d\x36\x20\x37\x2e\x35\x2d\x31\x30\x2e\x38\x73\x2d\x33\ +\x2d\x39\x2e\x31\x2d\x37\x2e\x35\x2d\x31\x30\x2e\x38\x4c\x31\x32\ +\x38\x20\x39\x36\x20\x31\x30\x36\x2e\x38\x20\x33\x39\x2e\x35\x43\ +\x31\x30\x35\x2e\x31\x20\x33\x35\x20\x31\x30\x30\x2e\x38\x20\x33\ +\x32\x20\x39\x36\x20\x33\x32\x73\x2d\x39\x2e\x31\x20\x33\x2d\x31\ +\x30\x2e\x38\x20\x37\x2e\x35\x4c\x36\x34\x20\x39\x36\x20\x37\x2e\ +\x35\x20\x31\x31\x37\x2e\x32\x7a\x6d\x33\x35\x32\x20\x32\x35\x36\ +\x63\x2d\x34\x2e\x35\x20\x31\x2e\x37\x2d\x37\x2e\x35\x20\x36\x2d\ +\x37\x2e\x35\x20\x31\x30\x2e\x38\x73\x33\x20\x39\x2e\x31\x20\x37\ +\x2e\x35\x20\x31\x30\x2e\x38\x4c\x34\x31\x36\x20\x34\x31\x36\x6c\ +\x32\x31\x2e\x32\x20\x35\x36\x2e\x35\x63\x31\x2e\x37\x20\x34\x2e\ +\x35\x20\x36\x20\x37\x2e\x35\x20\x31\x30\x2e\x38\x20\x37\x2e\x35\ +\x73\x39\x2e\x31\x2d\x33\x20\x31\x30\x2e\x38\x2d\x37\x2e\x35\x4c\ +\x34\x38\x30\x20\x34\x31\x36\x6c\x35\x36\x2e\x35\x2d\x32\x31\x2e\ +\x32\x63\x34\x2e\x35\x2d\x31\x2e\x37\x20\x37\x2e\x35\x2d\x36\x20\ +\x37\x2e\x35\x2d\x31\x30\x2e\x38\x73\x2d\x33\x2d\x39\x2e\x31\x2d\ +\x37\x2e\x35\x2d\x31\x30\x2e\x38\x4c\x34\x38\x30\x20\x33\x35\x32\ +\x6c\x2d\x32\x31\x2e\x32\x2d\x35\x36\x2e\x35\x63\x2d\x31\x2e\x37\ +\x2d\x34\x2e\x35\x2d\x36\x2d\x37\x2e\x35\x2d\x31\x30\x2e\x38\x2d\ +\x37\x2e\x35\x73\x2d\x39\x2e\x31\x20\x33\x2d\x31\x30\x2e\x38\x20\ +\x37\x2e\x35\x4c\x34\x31\x36\x20\x33\x35\x32\x6c\x2d\x35\x36\x2e\ +\x35\x20\x32\x31\x2e\x32\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\ +\x00\x00\x03\x32\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x34\x34\x38\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x30\x20\x30\x20\x36\x34\x30\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -602,47 +393,49 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x33\x36\x38\x20\x38\x30\ -\x68\x33\x32\x76\x33\x32\x48\x33\x36\x38\x56\x38\x30\x7a\x4d\x33\ -\x35\x32\x20\x33\x32\x63\x2d\x31\x37\x2e\x37\x20\x30\x2d\x33\x32\ -\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x48\x31\x32\x38\x63\ -\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\ -\x32\x2d\x33\x32\x48\x33\x32\x43\x31\x34\x2e\x33\x20\x33\x32\x20\ -\x30\x20\x34\x36\x2e\x33\x20\x30\x20\x36\x34\x76\x36\x34\x63\x30\ -\x20\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\ -\x20\x33\x32\x56\x33\x35\x32\x63\x2d\x31\x37\x2e\x37\x20\x30\x2d\ -\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x76\x36\x34\ -\x63\x30\x20\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x20\x33\x32\x20\ -\x33\x32\x20\x33\x32\x48\x39\x36\x63\x31\x37\x2e\x37\x20\x30\x20\ -\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x48\x33\x32\ -\x30\x63\x30\x20\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x20\x33\x32\ -\x20\x33\x32\x20\x33\x32\x68\x36\x34\x63\x31\x37\x2e\x37\x20\x30\ -\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x56\x33\ -\x38\x34\x63\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\ -\x32\x2d\x33\x32\x2d\x33\x32\x56\x31\x36\x30\x63\x31\x37\x2e\x37\ -\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\ -\x56\x36\x34\x63\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\ -\x33\x32\x2d\x33\x32\x2d\x33\x32\x48\x33\x35\x32\x7a\x4d\x39\x36\ -\x20\x31\x36\x30\x63\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\ -\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x48\x33\x32\x30\x63\x30\x20\ -\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\x20\ -\x33\x32\x56\x33\x35\x32\x63\x2d\x31\x37\x2e\x37\x20\x30\x2d\x33\ -\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x48\x31\x32\x38\ -\x63\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\ -\x33\x32\x2d\x33\x32\x56\x31\x36\x30\x7a\x4d\x34\x38\x20\x34\x30\ -\x30\x48\x38\x30\x76\x33\x32\x48\x34\x38\x56\x34\x30\x30\x7a\x6d\ -\x33\x32\x30\x20\x33\x32\x56\x34\x30\x30\x68\x33\x32\x76\x33\x32\ -\x48\x33\x36\x38\x7a\x4d\x34\x38\x20\x31\x31\x32\x56\x38\x30\x48\ -\x38\x30\x76\x33\x32\x48\x34\x38\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\ -\x67\x3e\ -\x00\x00\x04\x26\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x33\x32\x20\x36\x34\x63\ +\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x20\x31\x34\x2e\x33\x20\x33\ +\x32\x20\x33\x32\x6c\x30\x20\x33\x32\x30\x63\x30\x20\x31\x37\x2e\ +\x37\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x20\x33\x32\x73\ +\x2d\x33\x32\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x4c\x30\ +\x20\x39\x36\x43\x30\x20\x37\x38\x2e\x33\x20\x31\x34\x2e\x33\x20\ +\x36\x34\x20\x33\x32\x20\x36\x34\x7a\x6d\x32\x31\x34\x2e\x36\x20\ +\x37\x33\x2e\x34\x63\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x31\ +\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x4c\ +\x32\x30\x35\x2e\x33\x20\x32\x32\x34\x6c\x32\x32\x39\x2e\x35\x20\ +\x30\x2d\x34\x31\x2e\x34\x2d\x34\x31\x2e\x34\x63\x2d\x31\x32\x2e\ +\x35\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\ +\x20\x30\x2d\x34\x35\x2e\x33\x73\x33\x32\x2e\x38\x2d\x31\x32\x2e\ +\x35\x20\x34\x35\x2e\x33\x20\x30\x6c\x39\x36\x20\x39\x36\x63\x31\ +\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\ +\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x6c\x2d\x39\x36\x20\x39\x36\ +\x63\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\ +\x20\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x73\x2d\x31\x32\ +\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x4c\x34\ +\x33\x34\x2e\x37\x20\x32\x38\x38\x6c\x2d\x32\x32\x39\x2e\x35\x20\ +\x30\x20\x34\x31\x2e\x34\x20\x34\x31\x2e\x34\x63\x31\x32\x2e\x35\ +\x20\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\ +\x30\x20\x34\x35\x2e\x33\x73\x2d\x33\x32\x2e\x38\x20\x31\x32\x2e\ +\x35\x2d\x34\x35\x2e\x33\x20\x30\x6c\x2d\x39\x36\x2d\x39\x36\x63\ +\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\ +\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x6c\x39\x36\x2d\x39\ +\x36\x63\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\ +\x2d\x31\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\x30\x7a\x4d\x36\x34\ +\x30\x20\x39\x36\x6c\x30\x20\x33\x32\x30\x63\x30\x20\x31\x37\x2e\ +\x37\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x20\x33\x32\x73\ +\x2d\x33\x32\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x6c\x30\ +\x2d\x33\x32\x30\x63\x30\x2d\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\ +\x2d\x33\x32\x20\x33\x32\x2d\x33\x32\x73\x33\x32\x20\x31\x34\x2e\ +\x33\x20\x33\x32\x20\x33\x32\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\ +\x3e\ +\x00\x00\x03\x43\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x30\x20\x30\x20\x35\x37\x36\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -654,64 +447,50 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x37\x38\x2e\x36\x20\x35\ -\x43\x36\x39\x2e\x31\x2d\x32\x2e\x34\x20\x35\x35\x2e\x36\x2d\x31\ -\x2e\x35\x20\x34\x37\x20\x37\x4c\x37\x20\x34\x37\x63\x2d\x38\x2e\ -\x35\x20\x38\x2e\x35\x2d\x39\x2e\x34\x20\x32\x32\x2d\x32\x2e\x31\ -\x20\x33\x31\x2e\x36\x6c\x38\x30\x20\x31\x30\x34\x63\x34\x2e\x35\ -\x20\x35\x2e\x39\x20\x31\x31\x2e\x36\x20\x39\x2e\x34\x20\x31\x39\ -\x20\x39\x2e\x34\x68\x35\x34\x2e\x31\x6c\x31\x30\x39\x20\x31\x30\ -\x39\x63\x2d\x31\x34\x2e\x37\x20\x32\x39\x2d\x31\x30\x20\x36\x35\ -\x2e\x34\x20\x31\x34\x2e\x33\x20\x38\x39\x2e\x36\x6c\x31\x31\x32\ -\x20\x31\x31\x32\x63\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\ -\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\x30\x6c\ -\x36\x34\x2d\x36\x34\x63\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\ -\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\ -\x6c\x2d\x31\x31\x32\x2d\x31\x31\x32\x63\x2d\x32\x34\x2e\x32\x2d\ -\x32\x34\x2e\x32\x2d\x36\x30\x2e\x36\x2d\x32\x39\x2d\x38\x39\x2e\ -\x36\x2d\x31\x34\x2e\x33\x6c\x2d\x31\x30\x39\x2d\x31\x30\x39\x56\ -\x31\x30\x34\x63\x30\x2d\x37\x2e\x35\x2d\x33\x2e\x35\x2d\x31\x34\ -\x2e\x35\x2d\x39\x2e\x34\x2d\x31\x39\x4c\x37\x38\x2e\x36\x20\x35\ -\x7a\x4d\x31\x39\x2e\x39\x20\x33\x39\x36\x2e\x31\x43\x37\x2e\x32\ -\x20\x34\x30\x38\x2e\x38\x20\x30\x20\x34\x32\x36\x2e\x31\x20\x30\ -\x20\x34\x34\x34\x2e\x31\x43\x30\x20\x34\x38\x31\x2e\x36\x20\x33\ -\x30\x2e\x34\x20\x35\x31\x32\x20\x36\x37\x2e\x39\x20\x35\x31\x32\ -\x63\x31\x38\x20\x30\x20\x33\x35\x2e\x33\x2d\x37\x2e\x32\x20\x34\ -\x38\x2d\x31\x39\x2e\x39\x4c\x32\x33\x33\x2e\x37\x20\x33\x37\x34\ -\x2e\x33\x63\x2d\x37\x2e\x38\x2d\x32\x30\x2e\x39\x2d\x39\x2d\x34\ -\x33\x2e\x36\x2d\x33\x2e\x36\x2d\x36\x35\x2e\x31\x6c\x2d\x36\x31\ -\x2e\x37\x2d\x36\x31\x2e\x37\x4c\x31\x39\x2e\x39\x20\x33\x39\x36\ -\x2e\x31\x7a\x4d\x35\x31\x32\x20\x31\x34\x34\x63\x30\x2d\x31\x30\ -\x2e\x35\x2d\x31\x2e\x31\x2d\x32\x30\x2e\x37\x2d\x33\x2e\x32\x2d\ -\x33\x30\x2e\x35\x63\x2d\x32\x2e\x34\x2d\x31\x31\x2e\x32\x2d\x31\ -\x36\x2e\x31\x2d\x31\x34\x2e\x31\x2d\x32\x34\x2e\x32\x2d\x36\x6c\ -\x2d\x36\x33\x2e\x39\x20\x36\x33\x2e\x39\x63\x2d\x33\x20\x33\x2d\ -\x37\x2e\x31\x20\x34\x2e\x37\x2d\x31\x31\x2e\x33\x20\x34\x2e\x37\ -\x48\x33\x35\x32\x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x2d\x37\ -\x2e\x32\x2d\x31\x36\x2d\x31\x36\x56\x31\x30\x32\x2e\x36\x63\x30\ -\x2d\x34\x2e\x32\x20\x31\x2e\x37\x2d\x38\x2e\x33\x20\x34\x2e\x37\ -\x2d\x31\x31\x2e\x33\x6c\x36\x33\x2e\x39\x2d\x36\x33\x2e\x39\x63\ -\x38\x2e\x31\x2d\x38\x2e\x31\x20\x35\x2e\x32\x2d\x32\x31\x2e\x38\ -\x2d\x36\x2d\x32\x34\x2e\x32\x43\x33\x38\x38\x2e\x37\x20\x31\x2e\ -\x31\x20\x33\x37\x38\x2e\x35\x20\x30\x20\x33\x36\x38\x20\x30\x43\ -\x32\x38\x38\x2e\x35\x20\x30\x20\x32\x32\x34\x20\x36\x34\x2e\x35\ -\x20\x32\x32\x34\x20\x31\x34\x34\x6c\x30\x20\x2e\x38\x20\x38\x35\ -\x2e\x33\x20\x38\x35\x2e\x33\x63\x33\x36\x2d\x39\x2e\x31\x20\x37\ -\x35\x2e\x38\x20\x2e\x35\x20\x31\x30\x34\x20\x32\x38\x2e\x37\x4c\ -\x34\x32\x39\x20\x32\x37\x34\x2e\x35\x63\x34\x39\x2d\x32\x33\x20\ -\x38\x33\x2d\x37\x32\x2e\x38\x20\x38\x33\x2d\x31\x33\x30\x2e\x35\ -\x7a\x4d\x35\x36\x20\x34\x33\x32\x61\x32\x34\x20\x32\x34\x20\x30\ -\x20\x31\x20\x31\x20\x34\x38\x20\x30\x20\x32\x34\x20\x32\x34\x20\ -\x30\x20\x31\x20\x31\x20\x2d\x34\x38\x20\x30\x7a\x22\x2f\x3e\x3c\ -\x2f\x73\x76\x67\x3e\ -\x00\x00\x03\x2b\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x35\x34\x33\x2e\x38\x20\ +\x32\x38\x37\x2e\x36\x63\x31\x37\x20\x30\x20\x33\x32\x2d\x31\x34\ +\x20\x33\x32\x2d\x33\x32\x2e\x31\x63\x31\x2d\x39\x2d\x33\x2d\x31\ +\x37\x2d\x31\x31\x2d\x32\x34\x4c\x35\x31\x32\x20\x31\x38\x35\x6c\ +\x30\x2d\x31\x32\x31\x63\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\ +\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x6c\x2d\x33\x32\x20\x30\ +\x63\x2d\x31\x37\x2e\x37\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\ +\x2d\x33\x32\x20\x33\x32\x6c\x30\x20\x33\x36\x2e\x37\x4c\x33\x30\ +\x39\x2e\x35\x20\x37\x63\x2d\x36\x2d\x35\x2d\x31\x34\x2d\x37\x2d\ +\x32\x31\x2d\x37\x73\x2d\x31\x35\x20\x31\x2d\x32\x32\x20\x38\x4c\ +\x31\x30\x20\x32\x33\x31\x2e\x35\x63\x2d\x37\x20\x37\x2d\x31\x30\ +\x20\x31\x35\x2d\x31\x30\x20\x32\x34\x63\x30\x20\x31\x38\x20\x31\ +\x34\x20\x33\x32\x2e\x31\x20\x33\x32\x20\x33\x32\x2e\x31\x6c\x33\ +\x32\x20\x30\x20\x30\x20\x36\x39\x2e\x37\x63\x2d\x2e\x31\x20\x2e\ +\x39\x2d\x2e\x31\x20\x31\x2e\x38\x2d\x2e\x31\x20\x32\x2e\x38\x6c\ +\x30\x20\x31\x31\x32\x63\x30\x20\x32\x32\x2e\x31\x20\x31\x37\x2e\ +\x39\x20\x34\x30\x20\x34\x30\x20\x34\x30\x6c\x31\x36\x20\x30\x63\ +\x31\x2e\x32\x20\x30\x20\x32\x2e\x34\x2d\x2e\x31\x20\x33\x2e\x36\ +\x2d\x2e\x32\x63\x31\x2e\x35\x20\x2e\x31\x20\x33\x20\x2e\x32\x20\ +\x34\x2e\x35\x20\x2e\x32\x6c\x33\x31\x2e\x39\x20\x30\x20\x32\x34\ +\x20\x30\x63\x32\x32\x2e\x31\x20\x30\x20\x34\x30\x2d\x31\x37\x2e\ +\x39\x20\x34\x30\x2d\x34\x30\x6c\x30\x2d\x32\x34\x20\x30\x2d\x36\ +\x34\x63\x30\x2d\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x2d\x33\x32\ +\x20\x33\x32\x2d\x33\x32\x6c\x36\x34\x20\x30\x63\x31\x37\x2e\x37\ +\x20\x30\x20\x33\x32\x20\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\ +\x6c\x30\x20\x36\x34\x20\x30\x20\x32\x34\x63\x30\x20\x32\x32\x2e\ +\x31\x20\x31\x37\x2e\x39\x20\x34\x30\x20\x34\x30\x20\x34\x30\x6c\ +\x32\x34\x20\x30\x20\x33\x32\x2e\x35\x20\x30\x63\x31\x2e\x34\x20\ +\x30\x20\x32\x2e\x38\x20\x30\x20\x34\x2e\x32\x2d\x2e\x31\x63\x31\ +\x2e\x31\x20\x2e\x31\x20\x32\x2e\x32\x20\x2e\x31\x20\x33\x2e\x33\ +\x20\x2e\x31\x6c\x31\x36\x20\x30\x63\x32\x32\x2e\x31\x20\x30\x20\ +\x34\x30\x2d\x31\x37\x2e\x39\x20\x34\x30\x2d\x34\x30\x6c\x30\x2d\ +\x31\x36\x2e\x32\x63\x2e\x33\x2d\x32\x2e\x36\x20\x2e\x35\x2d\x35\ +\x2e\x33\x20\x2e\x35\x2d\x38\x2e\x31\x6c\x2d\x2e\x37\x2d\x31\x36\ +\x30\x2e\x32\x20\x33\x32\x20\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\ +\x67\x3e\ +\x00\x00\x03\x98\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ -\x30\x20\x30\x20\x36\x34\x30\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x30\x20\x30\x20\x35\x37\x36\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -723,48 +502,55 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x33\x32\x20\x36\x34\x63\ -\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x20\x31\x34\x2e\x33\x20\x33\ -\x32\x20\x33\x32\x6c\x30\x20\x33\x32\x30\x63\x30\x20\x31\x37\x2e\ -\x37\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x20\x33\x32\x73\ -\x2d\x33\x32\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x56\x39\ -\x36\x43\x30\x20\x37\x38\x2e\x33\x20\x31\x34\x2e\x33\x20\x36\x34\ -\x20\x33\x32\x20\x36\x34\x7a\x6d\x32\x31\x34\x2e\x36\x20\x37\x33\ -\x2e\x34\x63\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x31\x32\x2e\ -\x35\x20\x33\x32\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x4c\x32\x30\ -\x35\x2e\x33\x20\x32\x32\x34\x6c\x32\x32\x39\x2e\x35\x20\x30\x2d\ -\x34\x31\x2e\x34\x2d\x34\x31\x2e\x34\x63\x2d\x31\x32\x2e\x35\x2d\ -\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\ -\x2d\x34\x35\x2e\x33\x73\x33\x32\x2e\x38\x2d\x31\x32\x2e\x35\x20\ -\x34\x35\x2e\x33\x20\x30\x6c\x39\x36\x20\x39\x36\x63\x31\x32\x2e\ -\x35\x20\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\ -\x20\x30\x20\x34\x35\x2e\x33\x6c\x2d\x39\x36\x20\x39\x36\x63\x2d\ -\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x31\ -\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x73\x2d\x31\x32\x2e\x35\ -\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x4c\x34\x33\x34\ -\x2e\x37\x20\x32\x38\x38\x6c\x2d\x32\x32\x39\x2e\x35\x20\x30\x20\ -\x34\x31\x2e\x34\x20\x34\x31\x2e\x34\x63\x31\x32\x2e\x35\x20\x31\ -\x32\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\x20\ -\x34\x35\x2e\x33\x73\x2d\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x2d\ -\x34\x35\x2e\x33\x20\x30\x6c\x2d\x39\x36\x2d\x39\x36\x63\x2d\x31\ -\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x33\x32\ -\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x6c\x39\x36\x2d\x39\x36\x63\ -\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x2d\x31\ -\x32\x2e\x35\x20\x34\x35\x2e\x33\x20\x30\x7a\x4d\x36\x34\x30\x20\ -\x39\x36\x56\x34\x31\x36\x63\x30\x20\x31\x37\x2e\x37\x2d\x31\x34\ -\x2e\x33\x20\x33\x32\x2d\x33\x32\x20\x33\x32\x73\x2d\x33\x32\x2d\ -\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x56\x39\x36\x63\x30\x2d\ -\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x2d\ -\x33\x32\x73\x33\x32\x20\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\ -\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ -\x00\x00\x03\x89\ -\x3c\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x32\x36\x34\x2e\x35\x20\ +\x35\x2e\x32\x63\x31\x34\x2e\x39\x2d\x36\x2e\x39\x20\x33\x32\x2e\ +\x31\x2d\x36\x2e\x39\x20\x34\x37\x20\x30\x6c\x32\x31\x38\x2e\x36\ +\x20\x31\x30\x31\x63\x38\x2e\x35\x20\x33\x2e\x39\x20\x31\x33\x2e\ +\x39\x20\x31\x32\x2e\x34\x20\x31\x33\x2e\x39\x20\x32\x31\x2e\x38\ +\x73\x2d\x35\x2e\x34\x20\x31\x37\x2e\x39\x2d\x31\x33\x2e\x39\x20\ +\x32\x31\x2e\x38\x6c\x2d\x32\x31\x38\x2e\x36\x20\x31\x30\x31\x63\ +\x2d\x31\x34\x2e\x39\x20\x36\x2e\x39\x2d\x33\x32\x2e\x31\x20\x36\ +\x2e\x39\x2d\x34\x37\x20\x30\x4c\x34\x35\x2e\x39\x20\x31\x34\x39\ +\x2e\x38\x43\x33\x37\x2e\x34\x20\x31\x34\x35\x2e\x38\x20\x33\x32\ +\x20\x31\x33\x37\x2e\x33\x20\x33\x32\x20\x31\x32\x38\x73\x35\x2e\ +\x34\x2d\x31\x37\x2e\x39\x20\x31\x33\x2e\x39\x2d\x32\x31\x2e\x38\ +\x4c\x32\x36\x34\x2e\x35\x20\x35\x2e\x32\x7a\x4d\x34\x37\x36\x2e\ +\x39\x20\x32\x30\x39\x2e\x36\x6c\x35\x33\x2e\x32\x20\x32\x34\x2e\ +\x36\x63\x38\x2e\x35\x20\x33\x2e\x39\x20\x31\x33\x2e\x39\x20\x31\ +\x32\x2e\x34\x20\x31\x33\x2e\x39\x20\x32\x31\x2e\x38\x73\x2d\x35\ +\x2e\x34\x20\x31\x37\x2e\x39\x2d\x31\x33\x2e\x39\x20\x32\x31\x2e\ +\x38\x6c\x2d\x32\x31\x38\x2e\x36\x20\x31\x30\x31\x63\x2d\x31\x34\ +\x2e\x39\x20\x36\x2e\x39\x2d\x33\x32\x2e\x31\x20\x36\x2e\x39\x2d\ +\x34\x37\x20\x30\x4c\x34\x35\x2e\x39\x20\x32\x37\x37\x2e\x38\x43\ +\x33\x37\x2e\x34\x20\x32\x37\x33\x2e\x38\x20\x33\x32\x20\x32\x36\ +\x35\x2e\x33\x20\x33\x32\x20\x32\x35\x36\x73\x35\x2e\x34\x2d\x31\ +\x37\x2e\x39\x20\x31\x33\x2e\x39\x2d\x32\x31\x2e\x38\x6c\x35\x33\ +\x2e\x32\x2d\x32\x34\x2e\x36\x20\x31\x35\x32\x20\x37\x30\x2e\x32\ +\x63\x32\x33\x2e\x34\x20\x31\x30\x2e\x38\x20\x35\x30\x2e\x34\x20\ +\x31\x30\x2e\x38\x20\x37\x33\x2e\x38\x20\x30\x6c\x31\x35\x32\x2d\ +\x37\x30\x2e\x32\x7a\x6d\x2d\x31\x35\x32\x20\x31\x39\x38\x2e\x32\ +\x6c\x31\x35\x32\x2d\x37\x30\x2e\x32\x20\x35\x33\x2e\x32\x20\x32\ +\x34\x2e\x36\x63\x38\x2e\x35\x20\x33\x2e\x39\x20\x31\x33\x2e\x39\ +\x20\x31\x32\x2e\x34\x20\x31\x33\x2e\x39\x20\x32\x31\x2e\x38\x73\ +\x2d\x35\x2e\x34\x20\x31\x37\x2e\x39\x2d\x31\x33\x2e\x39\x20\x32\ +\x31\x2e\x38\x6c\x2d\x32\x31\x38\x2e\x36\x20\x31\x30\x31\x63\x2d\ +\x31\x34\x2e\x39\x20\x36\x2e\x39\x2d\x33\x32\x2e\x31\x20\x36\x2e\ +\x39\x2d\x34\x37\x20\x30\x4c\x34\x35\x2e\x39\x20\x34\x30\x35\x2e\ +\x38\x43\x33\x37\x2e\x34\x20\x34\x30\x31\x2e\x38\x20\x33\x32\x20\ +\x33\x39\x33\x2e\x33\x20\x33\x32\x20\x33\x38\x34\x73\x35\x2e\x34\ +\x2d\x31\x37\x2e\x39\x20\x31\x33\x2e\x39\x2d\x32\x31\x2e\x38\x6c\ +\x35\x33\x2e\x32\x2d\x32\x34\x2e\x36\x20\x31\x35\x32\x20\x37\x30\ +\x2e\x32\x63\x32\x33\x2e\x34\x20\x31\x30\x2e\x38\x20\x35\x30\x2e\ +\x34\x20\x31\x30\x2e\x38\x20\x37\x33\x2e\x38\x20\x30\x7a\x22\x2f\ +\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x00\x00\x03\xb3\ +\x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ \x30\x20\x30\x20\x36\x34\x30\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -778,52 +564,55 @@ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ \x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x33\x32\x30\x20\x30\x63\ \x31\x37\x2e\x37\x20\x30\x20\x33\x32\x20\x31\x34\x2e\x33\x20\x33\ -\x32\x20\x33\x32\x56\x39\x36\x48\x34\x37\x32\x63\x33\x39\x2e\x38\ -\x20\x30\x20\x37\x32\x20\x33\x32\x2e\x32\x20\x37\x32\x20\x37\x32\ -\x56\x34\x34\x30\x63\x30\x20\x33\x39\x2e\x38\x2d\x33\x32\x2e\x32\ -\x20\x37\x32\x2d\x37\x32\x20\x37\x32\x48\x31\x36\x38\x63\x2d\x33\ -\x39\x2e\x38\x20\x30\x2d\x37\x32\x2d\x33\x32\x2e\x32\x2d\x37\x32\ -\x2d\x37\x32\x56\x31\x36\x38\x63\x30\x2d\x33\x39\x2e\x38\x20\x33\ -\x32\x2e\x32\x2d\x37\x32\x20\x37\x32\x2d\x37\x32\x48\x32\x38\x38\ -\x56\x33\x32\x63\x30\x2d\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x2d\ -\x33\x32\x20\x33\x32\x2d\x33\x32\x7a\x4d\x32\x30\x38\x20\x33\x38\ -\x34\x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x20\x37\x2e\x32\x2d\ -\x31\x36\x20\x31\x36\x73\x37\x2e\x32\x20\x31\x36\x20\x31\x36\x20\ -\x31\x36\x68\x33\x32\x63\x38\x2e\x38\x20\x30\x20\x31\x36\x2d\x37\ -\x2e\x32\x20\x31\x36\x2d\x31\x36\x73\x2d\x37\x2e\x32\x2d\x31\x36\ -\x2d\x31\x36\x2d\x31\x36\x48\x32\x30\x38\x7a\x6d\x39\x36\x20\x30\ +\x32\x20\x33\x32\x6c\x30\x20\x36\x34\x20\x31\x32\x30\x20\x30\x63\ +\x33\x39\x2e\x38\x20\x30\x20\x37\x32\x20\x33\x32\x2e\x32\x20\x37\ +\x32\x20\x37\x32\x6c\x30\x20\x32\x37\x32\x63\x30\x20\x33\x39\x2e\ +\x38\x2d\x33\x32\x2e\x32\x20\x37\x32\x2d\x37\x32\x20\x37\x32\x6c\ +\x2d\x33\x30\x34\x20\x30\x63\x2d\x33\x39\x2e\x38\x20\x30\x2d\x37\ +\x32\x2d\x33\x32\x2e\x32\x2d\x37\x32\x2d\x37\x32\x6c\x30\x2d\x32\ +\x37\x32\x63\x30\x2d\x33\x39\x2e\x38\x20\x33\x32\x2e\x32\x2d\x37\ +\x32\x20\x37\x32\x2d\x37\x32\x6c\x31\x32\x30\x20\x30\x20\x30\x2d\ +\x36\x34\x63\x30\x2d\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x2d\x33\ +\x32\x20\x33\x32\x2d\x33\x32\x7a\x4d\x32\x30\x38\x20\x33\x38\x34\ \x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x20\x37\x2e\x32\x2d\x31\ \x36\x20\x31\x36\x73\x37\x2e\x32\x20\x31\x36\x20\x31\x36\x20\x31\ -\x36\x68\x33\x32\x63\x38\x2e\x38\x20\x30\x20\x31\x36\x2d\x37\x2e\ -\x32\x20\x31\x36\x2d\x31\x36\x73\x2d\x37\x2e\x32\x2d\x31\x36\x2d\ -\x31\x36\x2d\x31\x36\x48\x33\x30\x34\x7a\x6d\x39\x36\x20\x30\x63\ -\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x20\x37\x2e\x32\x2d\x31\x36\ -\x20\x31\x36\x73\x37\x2e\x32\x20\x31\x36\x20\x31\x36\x20\x31\x36\ -\x68\x33\x32\x63\x38\x2e\x38\x20\x30\x20\x31\x36\x2d\x37\x2e\x32\ -\x20\x31\x36\x2d\x31\x36\x73\x2d\x37\x2e\x32\x2d\x31\x36\x2d\x31\ -\x36\x2d\x31\x36\x48\x34\x30\x30\x7a\x4d\x32\x36\x34\x20\x32\x35\ -\x36\x61\x34\x30\x20\x34\x30\x20\x30\x20\x31\x20\x30\x20\x2d\x38\ -\x30\x20\x30\x20\x34\x30\x20\x34\x30\x20\x30\x20\x31\x20\x30\x20\ -\x38\x30\x20\x30\x7a\x6d\x31\x35\x32\x20\x34\x30\x61\x34\x30\x20\ -\x34\x30\x20\x30\x20\x31\x20\x30\x20\x30\x2d\x38\x30\x20\x34\x30\ -\x20\x34\x30\x20\x30\x20\x31\x20\x30\x20\x30\x20\x38\x30\x7a\x4d\ -\x34\x38\x20\x32\x32\x34\x48\x36\x34\x56\x34\x31\x36\x48\x34\x38\ -\x63\x2d\x32\x36\x2e\x35\x20\x30\x2d\x34\x38\x2d\x32\x31\x2e\x35\ -\x2d\x34\x38\x2d\x34\x38\x56\x32\x37\x32\x63\x30\x2d\x32\x36\x2e\ -\x35\x20\x32\x31\x2e\x35\x2d\x34\x38\x20\x34\x38\x2d\x34\x38\x7a\ -\x6d\x35\x34\x34\x20\x30\x63\x32\x36\x2e\x35\x20\x30\x20\x34\x38\ -\x20\x32\x31\x2e\x35\x20\x34\x38\x20\x34\x38\x76\x39\x36\x63\x30\ -\x20\x32\x36\x2e\x35\x2d\x32\x31\x2e\x35\x20\x34\x38\x2d\x34\x38\ -\x20\x34\x38\x48\x35\x37\x36\x56\x32\x32\x34\x68\x31\x36\x7a\x22\ -\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ -\x00\x00\x05\x2e\ +\x36\x6c\x33\x32\x20\x30\x63\x38\x2e\x38\x20\x30\x20\x31\x36\x2d\ +\x37\x2e\x32\x20\x31\x36\x2d\x31\x36\x73\x2d\x37\x2e\x32\x2d\x31\ +\x36\x2d\x31\x36\x2d\x31\x36\x6c\x2d\x33\x32\x20\x30\x7a\x6d\x39\ +\x36\x20\x30\x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\x20\x37\x2e\ +\x32\x2d\x31\x36\x20\x31\x36\x73\x37\x2e\x32\x20\x31\x36\x20\x31\ +\x36\x20\x31\x36\x6c\x33\x32\x20\x30\x63\x38\x2e\x38\x20\x30\x20\ +\x31\x36\x2d\x37\x2e\x32\x20\x31\x36\x2d\x31\x36\x73\x2d\x37\x2e\ +\x32\x2d\x31\x36\x2d\x31\x36\x2d\x31\x36\x6c\x2d\x33\x32\x20\x30\ +\x7a\x6d\x39\x36\x20\x30\x63\x2d\x38\x2e\x38\x20\x30\x2d\x31\x36\ +\x20\x37\x2e\x32\x2d\x31\x36\x20\x31\x36\x73\x37\x2e\x32\x20\x31\ +\x36\x20\x31\x36\x20\x31\x36\x6c\x33\x32\x20\x30\x63\x38\x2e\x38\ +\x20\x30\x20\x31\x36\x2d\x37\x2e\x32\x20\x31\x36\x2d\x31\x36\x73\ +\x2d\x37\x2e\x32\x2d\x31\x36\x2d\x31\x36\x2d\x31\x36\x6c\x2d\x33\ +\x32\x20\x30\x7a\x4d\x32\x36\x34\x20\x32\x35\x36\x61\x34\x30\x20\ +\x34\x30\x20\x30\x20\x31\x20\x30\x20\x2d\x38\x30\x20\x30\x20\x34\ +\x30\x20\x34\x30\x20\x30\x20\x31\x20\x30\x20\x38\x30\x20\x30\x7a\ +\x6d\x31\x35\x32\x20\x34\x30\x61\x34\x30\x20\x34\x30\x20\x30\x20\ +\x31\x20\x30\x20\x30\x2d\x38\x30\x20\x34\x30\x20\x34\x30\x20\x30\ +\x20\x31\x20\x30\x20\x30\x20\x38\x30\x7a\x4d\x34\x38\x20\x32\x32\ +\x34\x6c\x31\x36\x20\x30\x20\x30\x20\x31\x39\x32\x2d\x31\x36\x20\ +\x30\x63\x2d\x32\x36\x2e\x35\x20\x30\x2d\x34\x38\x2d\x32\x31\x2e\ +\x35\x2d\x34\x38\x2d\x34\x38\x6c\x30\x2d\x39\x36\x63\x30\x2d\x32\ +\x36\x2e\x35\x20\x32\x31\x2e\x35\x2d\x34\x38\x20\x34\x38\x2d\x34\ +\x38\x7a\x6d\x35\x34\x34\x20\x30\x63\x32\x36\x2e\x35\x20\x30\x20\ +\x34\x38\x20\x32\x31\x2e\x35\x20\x34\x38\x20\x34\x38\x6c\x30\x20\ +\x39\x36\x63\x30\x20\x32\x36\x2e\x35\x2d\x32\x31\x2e\x35\x20\x34\ +\x38\x2d\x34\x38\x20\x34\x38\x6c\x2d\x31\x36\x20\x30\x20\x30\x2d\ +\x31\x39\x32\x20\x31\x36\x20\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\ +\x67\x3e\ +\x00\x00\x02\x88\ \x3c\ \x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ \x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ \x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ \x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ \x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ -\x46\x72\x65\x65\x20\x36\x2e\x35\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ \x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ \x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ \x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ @@ -835,72 +624,30 @@ \x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ \x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ \x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ -\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x34\x39\x35\x2e\x39\x20\ -\x31\x36\x36\x2e\x36\x63\x33\x2e\x32\x20\x38\x2e\x37\x20\x2e\x35\ -\x20\x31\x38\x2e\x34\x2d\x36\x2e\x34\x20\x32\x34\x2e\x36\x6c\x2d\ -\x34\x33\x2e\x33\x20\x33\x39\x2e\x34\x63\x31\x2e\x31\x20\x38\x2e\ -\x33\x20\x31\x2e\x37\x20\x31\x36\x2e\x38\x20\x31\x2e\x37\x20\x32\ -\x35\x2e\x34\x73\x2d\x2e\x36\x20\x31\x37\x2e\x31\x2d\x31\x2e\x37\ -\x20\x32\x35\x2e\x34\x6c\x34\x33\x2e\x33\x20\x33\x39\x2e\x34\x63\ -\x36\x2e\x39\x20\x36\x2e\x32\x20\x39\x2e\x36\x20\x31\x35\x2e\x39\ -\x20\x36\x2e\x34\x20\x32\x34\x2e\x36\x63\x2d\x34\x2e\x34\x20\x31\ -\x31\x2e\x39\x2d\x39\x2e\x37\x20\x32\x33\x2e\x33\x2d\x31\x35\x2e\ -\x38\x20\x33\x34\x2e\x33\x6c\x2d\x34\x2e\x37\x20\x38\x2e\x31\x63\ -\x2d\x36\x2e\x36\x20\x31\x31\x2d\x31\x34\x20\x32\x31\x2e\x34\x2d\ -\x32\x32\x2e\x31\x20\x33\x31\x2e\x32\x63\x2d\x35\x2e\x39\x20\x37\ -\x2e\x32\x2d\x31\x35\x2e\x37\x20\x39\x2e\x36\x2d\x32\x34\x2e\x35\ -\x20\x36\x2e\x38\x6c\x2d\x35\x35\x2e\x37\x2d\x31\x37\x2e\x37\x63\ -\x2d\x31\x33\x2e\x34\x20\x31\x30\x2e\x33\x2d\x32\x38\x2e\x32\x20\ -\x31\x38\x2e\x39\x2d\x34\x34\x20\x32\x35\x2e\x34\x6c\x2d\x31\x32\ -\x2e\x35\x20\x35\x37\x2e\x31\x63\x2d\x32\x20\x39\x2e\x31\x2d\x39\ -\x20\x31\x36\x2e\x33\x2d\x31\x38\x2e\x32\x20\x31\x37\x2e\x38\x63\ -\x2d\x31\x33\x2e\x38\x20\x32\x2e\x33\x2d\x32\x38\x20\x33\x2e\x35\ -\x2d\x34\x32\x2e\x35\x20\x33\x2e\x35\x73\x2d\x32\x38\x2e\x37\x2d\ -\x31\x2e\x32\x2d\x34\x32\x2e\x35\x2d\x33\x2e\x35\x63\x2d\x39\x2e\ -\x32\x2d\x31\x2e\x35\x2d\x31\x36\x2e\x32\x2d\x38\x2e\x37\x2d\x31\ -\x38\x2e\x32\x2d\x31\x37\x2e\x38\x6c\x2d\x31\x32\x2e\x35\x2d\x35\ -\x37\x2e\x31\x63\x2d\x31\x35\x2e\x38\x2d\x36\x2e\x35\x2d\x33\x30\ -\x2e\x36\x2d\x31\x35\x2e\x31\x2d\x34\x34\x2d\x32\x35\x2e\x34\x4c\ -\x38\x33\x2e\x31\x20\x34\x32\x35\x2e\x39\x63\x2d\x38\x2e\x38\x20\ -\x32\x2e\x38\x2d\x31\x38\x2e\x36\x20\x2e\x33\x2d\x32\x34\x2e\x35\ -\x2d\x36\x2e\x38\x63\x2d\x38\x2e\x31\x2d\x39\x2e\x38\x2d\x31\x35\ -\x2e\x35\x2d\x32\x30\x2e\x32\x2d\x32\x32\x2e\x31\x2d\x33\x31\x2e\ -\x32\x6c\x2d\x34\x2e\x37\x2d\x38\x2e\x31\x63\x2d\x36\x2e\x31\x2d\ -\x31\x31\x2d\x31\x31\x2e\x34\x2d\x32\x32\x2e\x34\x2d\x31\x35\x2e\ -\x38\x2d\x33\x34\x2e\x33\x63\x2d\x33\x2e\x32\x2d\x38\x2e\x37\x2d\ -\x2e\x35\x2d\x31\x38\x2e\x34\x20\x36\x2e\x34\x2d\x32\x34\x2e\x36\ -\x6c\x34\x33\x2e\x33\x2d\x33\x39\x2e\x34\x43\x36\x34\x2e\x36\x20\ -\x32\x37\x33\x2e\x31\x20\x36\x34\x20\x32\x36\x34\x2e\x36\x20\x36\ -\x34\x20\x32\x35\x36\x73\x2e\x36\x2d\x31\x37\x2e\x31\x20\x31\x2e\ -\x37\x2d\x32\x35\x2e\x34\x4c\x32\x32\x2e\x34\x20\x31\x39\x31\x2e\ -\x32\x63\x2d\x36\x2e\x39\x2d\x36\x2e\x32\x2d\x39\x2e\x36\x2d\x31\ -\x35\x2e\x39\x2d\x36\x2e\x34\x2d\x32\x34\x2e\x36\x63\x34\x2e\x34\ -\x2d\x31\x31\x2e\x39\x20\x39\x2e\x37\x2d\x32\x33\x2e\x33\x20\x31\ -\x35\x2e\x38\x2d\x33\x34\x2e\x33\x6c\x34\x2e\x37\x2d\x38\x2e\x31\ -\x63\x36\x2e\x36\x2d\x31\x31\x20\x31\x34\x2d\x32\x31\x2e\x34\x20\ -\x32\x32\x2e\x31\x2d\x33\x31\x2e\x32\x63\x35\x2e\x39\x2d\x37\x2e\ -\x32\x20\x31\x35\x2e\x37\x2d\x39\x2e\x36\x20\x32\x34\x2e\x35\x2d\ -\x36\x2e\x38\x6c\x35\x35\x2e\x37\x20\x31\x37\x2e\x37\x63\x31\x33\ -\x2e\x34\x2d\x31\x30\x2e\x33\x20\x32\x38\x2e\x32\x2d\x31\x38\x2e\ -\x39\x20\x34\x34\x2d\x32\x35\x2e\x34\x6c\x31\x32\x2e\x35\x2d\x35\ -\x37\x2e\x31\x63\x32\x2d\x39\x2e\x31\x20\x39\x2d\x31\x36\x2e\x33\ -\x20\x31\x38\x2e\x32\x2d\x31\x37\x2e\x38\x43\x32\x32\x37\x2e\x33\ -\x20\x31\x2e\x32\x20\x32\x34\x31\x2e\x35\x20\x30\x20\x32\x35\x36\ -\x20\x30\x73\x32\x38\x2e\x37\x20\x31\x2e\x32\x20\x34\x32\x2e\x35\ -\x20\x33\x2e\x35\x63\x39\x2e\x32\x20\x31\x2e\x35\x20\x31\x36\x2e\ -\x32\x20\x38\x2e\x37\x20\x31\x38\x2e\x32\x20\x31\x37\x2e\x38\x6c\ -\x31\x32\x2e\x35\x20\x35\x37\x2e\x31\x63\x31\x35\x2e\x38\x20\x36\ -\x2e\x35\x20\x33\x30\x2e\x36\x20\x31\x35\x2e\x31\x20\x34\x34\x20\ -\x32\x35\x2e\x34\x6c\x35\x35\x2e\x37\x2d\x31\x37\x2e\x37\x63\x38\ -\x2e\x38\x2d\x32\x2e\x38\x20\x31\x38\x2e\x36\x2d\x2e\x33\x20\x32\ -\x34\x2e\x35\x20\x36\x2e\x38\x63\x38\x2e\x31\x20\x39\x2e\x38\x20\ -\x31\x35\x2e\x35\x20\x32\x30\x2e\x32\x20\x32\x32\x2e\x31\x20\x33\ -\x31\x2e\x32\x6c\x34\x2e\x37\x20\x38\x2e\x31\x63\x36\x2e\x31\x20\ -\x31\x31\x20\x31\x31\x2e\x34\x20\x32\x32\x2e\x34\x20\x31\x35\x2e\ -\x38\x20\x33\x34\x2e\x33\x7a\x4d\x32\x35\x36\x20\x33\x33\x36\x61\ -\x38\x30\x20\x38\x30\x20\x30\x20\x31\x20\x30\x20\x30\x2d\x31\x36\ -\x30\x20\x38\x30\x20\x38\x30\x20\x30\x20\x31\x20\x30\x20\x30\x20\ -\x31\x36\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x36\x34\x20\x36\x34\x63\ +\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\ +\x32\x2d\x33\x32\x53\x30\x20\x34\x36\x2e\x33\x20\x30\x20\x36\x34\ +\x4c\x30\x20\x34\x30\x30\x63\x30\x20\x34\x34\x2e\x32\x20\x33\x35\ +\x2e\x38\x20\x38\x30\x20\x38\x30\x20\x38\x30\x6c\x34\x30\x30\x20\ +\x30\x63\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\ +\x20\x33\x32\x2d\x33\x32\x73\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\ +\x33\x32\x2d\x33\x32\x4c\x38\x30\x20\x34\x31\x36\x63\x2d\x38\x2e\ +\x38\x20\x30\x2d\x31\x36\x2d\x37\x2e\x32\x2d\x31\x36\x2d\x31\x36\ +\x4c\x36\x34\x20\x36\x34\x7a\x6d\x34\x30\x36\x2e\x36\x20\x38\x36\ +\x2e\x36\x63\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\ +\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x73\x2d\x33\ +\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x4c\ +\x33\x32\x30\x20\x32\x31\x30\x2e\x37\x6c\x2d\x35\x37\x2e\x34\x2d\ +\x35\x37\x2e\x34\x63\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\ +\x33\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\ +\x6c\x2d\x31\x31\x32\x20\x31\x31\x32\x63\x2d\x31\x32\x2e\x35\x20\ +\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\ +\x20\x34\x35\x2e\x33\x73\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\ +\x34\x35\x2e\x33\x20\x30\x4c\x32\x34\x30\x20\x32\x32\x31\x2e\x33\ +\x6c\x35\x37\x2e\x34\x20\x35\x37\x2e\x34\x63\x31\x32\x2e\x35\x20\ +\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\ +\x35\x2e\x33\x20\x30\x6c\x31\x32\x38\x2d\x31\x32\x38\x7a\x22\x2f\ +\x3e\x3c\x2f\x73\x76\x67\x3e\ \x00\x00\x18\xaa\ \x3c\ \x3f\x78\x6d\x6c\x20\x76\x65\x72\x73\x69\x6f\x6e\x3d\x22\x31\x2e\ @@ -1298,6 +1045,399 @@ \x73\x73\x22\x20\x2f\x3e\x0a\x20\x20\x20\x20\x20\x20\x3c\x2f\x67\ \x3e\x0a\x20\x20\x20\x20\x3c\x2f\x67\x3e\x0a\x20\x20\x3c\x2f\x67\ \x3e\x0a\x3c\x2f\x73\x76\x67\x3e\x0a\ +\x00\x00\x03\x8d\ +\x3c\ +\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ +\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ +\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ +\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ +\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ +\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ +\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ +\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ +\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ +\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ +\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ +\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ +\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ +\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x34\x30\x20\x34\x38\x43\ +\x32\x36\x2e\x37\x20\x34\x38\x20\x31\x36\x20\x35\x38\x2e\x37\x20\ +\x31\x36\x20\x37\x32\x6c\x30\x20\x34\x38\x63\x30\x20\x31\x33\x2e\ +\x33\x20\x31\x30\x2e\x37\x20\x32\x34\x20\x32\x34\x20\x32\x34\x6c\ +\x34\x38\x20\x30\x63\x31\x33\x2e\x33\x20\x30\x20\x32\x34\x2d\x31\ +\x30\x2e\x37\x20\x32\x34\x2d\x32\x34\x6c\x30\x2d\x34\x38\x63\x30\ +\x2d\x31\x33\x2e\x33\x2d\x31\x30\x2e\x37\x2d\x32\x34\x2d\x32\x34\ +\x2d\x32\x34\x4c\x34\x30\x20\x34\x38\x7a\x4d\x31\x39\x32\x20\x36\ +\x34\x63\x2d\x31\x37\x2e\x37\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\ +\x33\x2d\x33\x32\x20\x33\x32\x73\x31\x34\x2e\x33\x20\x33\x32\x20\ +\x33\x32\x20\x33\x32\x6c\x32\x38\x38\x20\x30\x63\x31\x37\x2e\x37\ +\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\ +\x73\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x4c\ +\x31\x39\x32\x20\x36\x34\x7a\x6d\x30\x20\x31\x36\x30\x63\x2d\x31\ +\x37\x2e\x37\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\ +\x20\x33\x32\x73\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\x20\x33\ +\x32\x6c\x32\x38\x38\x20\x30\x63\x31\x37\x2e\x37\x20\x30\x20\x33\ +\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x73\x2d\x31\x34\ +\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x6c\x2d\x32\x38\x38\ +\x20\x30\x7a\x6d\x30\x20\x31\x36\x30\x63\x2d\x31\x37\x2e\x37\x20\ +\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x73\ +\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\x20\x33\x32\x6c\x32\x38\ +\x38\x20\x30\x63\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\x34\ +\x2e\x33\x20\x33\x32\x2d\x33\x32\x73\x2d\x31\x34\x2e\x33\x2d\x33\ +\x32\x2d\x33\x32\x2d\x33\x32\x6c\x2d\x32\x38\x38\x20\x30\x7a\x4d\ +\x31\x36\x20\x32\x33\x32\x6c\x30\x20\x34\x38\x63\x30\x20\x31\x33\ +\x2e\x33\x20\x31\x30\x2e\x37\x20\x32\x34\x20\x32\x34\x20\x32\x34\ +\x6c\x34\x38\x20\x30\x63\x31\x33\x2e\x33\x20\x30\x20\x32\x34\x2d\ +\x31\x30\x2e\x37\x20\x32\x34\x2d\x32\x34\x6c\x30\x2d\x34\x38\x63\ +\x30\x2d\x31\x33\x2e\x33\x2d\x31\x30\x2e\x37\x2d\x32\x34\x2d\x32\ +\x34\x2d\x32\x34\x6c\x2d\x34\x38\x20\x30\x63\x2d\x31\x33\x2e\x33\ +\x20\x30\x2d\x32\x34\x20\x31\x30\x2e\x37\x2d\x32\x34\x20\x32\x34\ +\x7a\x4d\x34\x30\x20\x33\x36\x38\x63\x2d\x31\x33\x2e\x33\x20\x30\ +\x2d\x32\x34\x20\x31\x30\x2e\x37\x2d\x32\x34\x20\x32\x34\x6c\x30\ +\x20\x34\x38\x63\x30\x20\x31\x33\x2e\x33\x20\x31\x30\x2e\x37\x20\ +\x32\x34\x20\x32\x34\x20\x32\x34\x6c\x34\x38\x20\x30\x63\x31\x33\ +\x2e\x33\x20\x30\x20\x32\x34\x2d\x31\x30\x2e\x37\x20\x32\x34\x2d\ +\x32\x34\x6c\x30\x2d\x34\x38\x63\x30\x2d\x31\x33\x2e\x33\x2d\x31\ +\x30\x2e\x37\x2d\x32\x34\x2d\x32\x34\x2d\x32\x34\x6c\x2d\x34\x38\ +\x20\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x00\x00\x02\xc5\ +\x3c\ +\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ +\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ +\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ +\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ +\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ +\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ +\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ +\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ +\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ +\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ +\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ +\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ +\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ +\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x31\x37\x37\x2e\x39\x20\ +\x34\x39\x34\x2e\x31\x63\x2d\x31\x38\x2e\x37\x20\x31\x38\x2e\x37\ +\x2d\x34\x39\x2e\x31\x20\x31\x38\x2e\x37\x2d\x36\x37\x2e\x39\x20\ +\x30\x4c\x31\x37\x2e\x39\x20\x34\x30\x31\x2e\x39\x63\x2d\x31\x38\ +\x2e\x37\x2d\x31\x38\x2e\x37\x2d\x31\x38\x2e\x37\x2d\x34\x39\x2e\ +\x31\x20\x30\x2d\x36\x37\x2e\x39\x6c\x35\x30\x2e\x37\x2d\x35\x30\ +\x2e\x37\x20\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\ +\x20\x31\x36\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\ +\x73\x36\x2e\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\ +\x6c\x2d\x34\x38\x2d\x34\x38\x20\x34\x31\x2e\x34\x2d\x34\x31\x2e\ +\x34\x20\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\x20\ +\x31\x36\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\x73\ +\x36\x2e\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\x6c\ +\x2d\x34\x38\x2d\x34\x38\x20\x34\x31\x2e\x34\x2d\x34\x31\x2e\x34\ +\x20\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\x20\x31\ +\x36\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\x73\x36\ +\x2e\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\x6c\x2d\ +\x34\x38\x2d\x34\x38\x20\x34\x31\x2e\x34\x2d\x34\x31\x2e\x34\x20\ +\x34\x38\x20\x34\x38\x63\x36\x2e\x32\x20\x36\x2e\x32\x20\x31\x36\ +\x2e\x34\x20\x36\x2e\x32\x20\x32\x32\x2e\x36\x20\x30\x73\x36\x2e\ +\x32\x2d\x31\x36\x2e\x34\x20\x30\x2d\x32\x32\x2e\x36\x6c\x2d\x34\ +\x38\x2d\x34\x38\x20\x35\x30\x2e\x37\x2d\x35\x30\x2e\x37\x63\x31\ +\x38\x2e\x37\x2d\x31\x38\x2e\x37\x20\x34\x39\x2e\x31\x2d\x31\x38\ +\x2e\x37\x20\x36\x37\x2e\x39\x20\x30\x6c\x39\x32\x2e\x31\x20\x39\ +\x32\x2e\x31\x63\x31\x38\x2e\x37\x20\x31\x38\x2e\x37\x20\x31\x38\ +\x2e\x37\x20\x34\x39\x2e\x31\x20\x30\x20\x36\x37\x2e\x39\x4c\x31\ +\x37\x37\x2e\x39\x20\x34\x39\x34\x2e\x31\x7a\x22\x2f\x3e\x3c\x2f\ +\x73\x76\x67\x3e\ +\x00\x00\x02\x88\ +\x3c\ +\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ +\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ +\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ +\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ +\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ +\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ +\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ +\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ +\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ +\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ +\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ +\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ +\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ +\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x36\x34\x20\x33\x32\x43\ +\x32\x38\x2e\x37\x20\x33\x32\x20\x30\x20\x36\x30\x2e\x37\x20\x30\ +\x20\x39\x36\x4c\x30\x20\x34\x31\x36\x63\x30\x20\x33\x35\x2e\x33\ +\x20\x32\x38\x2e\x37\x20\x36\x34\x20\x36\x34\x20\x36\x34\x6c\x33\ +\x38\x34\x20\x30\x63\x33\x35\x2e\x33\x20\x30\x20\x36\x34\x2d\x32\ +\x38\x2e\x37\x20\x36\x34\x2d\x36\x34\x6c\x30\x2d\x33\x32\x30\x63\ +\x30\x2d\x33\x35\x2e\x33\x2d\x32\x38\x2e\x37\x2d\x36\x34\x2d\x36\ +\x34\x2d\x36\x34\x4c\x36\x34\x20\x33\x32\x7a\x6d\x38\x38\x20\x36\ +\x34\x6c\x30\x20\x36\x34\x2d\x38\x38\x20\x30\x20\x30\x2d\x36\x34\ +\x20\x38\x38\x20\x30\x7a\x6d\x35\x36\x20\x30\x6c\x38\x38\x20\x30\ +\x20\x30\x20\x36\x34\x2d\x38\x38\x20\x30\x20\x30\x2d\x36\x34\x7a\ +\x6d\x32\x34\x30\x20\x30\x6c\x30\x20\x36\x34\x2d\x38\x38\x20\x30\ +\x20\x30\x2d\x36\x34\x20\x38\x38\x20\x30\x7a\x4d\x36\x34\x20\x32\ +\x32\x34\x6c\x38\x38\x20\x30\x20\x30\x20\x36\x34\x2d\x38\x38\x20\ +\x30\x20\x30\x2d\x36\x34\x7a\x6d\x32\x33\x32\x20\x30\x6c\x30\x20\ +\x36\x34\x2d\x38\x38\x20\x30\x20\x30\x2d\x36\x34\x20\x38\x38\x20\ +\x30\x7a\x6d\x36\x34\x20\x30\x6c\x38\x38\x20\x30\x20\x30\x20\x36\ +\x34\x2d\x38\x38\x20\x30\x20\x30\x2d\x36\x34\x7a\x4d\x31\x35\x32\ +\x20\x33\x35\x32\x6c\x30\x20\x36\x34\x2d\x38\x38\x20\x30\x20\x30\ +\x2d\x36\x34\x20\x38\x38\x20\x30\x7a\x6d\x35\x36\x20\x30\x6c\x38\ +\x38\x20\x30\x20\x30\x20\x36\x34\x2d\x38\x38\x20\x30\x20\x30\x2d\ +\x36\x34\x7a\x6d\x32\x34\x30\x20\x30\x6c\x30\x20\x36\x34\x2d\x38\ +\x38\x20\x30\x20\x30\x2d\x36\x34\x20\x38\x38\x20\x30\x7a\x22\x2f\ +\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x00\x00\x05\x2e\ +\x3c\ +\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ +\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ +\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ +\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ +\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ +\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ +\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ +\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ +\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ +\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ +\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ +\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ +\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ +\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x34\x39\x35\x2e\x39\x20\ +\x31\x36\x36\x2e\x36\x63\x33\x2e\x32\x20\x38\x2e\x37\x20\x2e\x35\ +\x20\x31\x38\x2e\x34\x2d\x36\x2e\x34\x20\x32\x34\x2e\x36\x6c\x2d\ +\x34\x33\x2e\x33\x20\x33\x39\x2e\x34\x63\x31\x2e\x31\x20\x38\x2e\ +\x33\x20\x31\x2e\x37\x20\x31\x36\x2e\x38\x20\x31\x2e\x37\x20\x32\ +\x35\x2e\x34\x73\x2d\x2e\x36\x20\x31\x37\x2e\x31\x2d\x31\x2e\x37\ +\x20\x32\x35\x2e\x34\x6c\x34\x33\x2e\x33\x20\x33\x39\x2e\x34\x63\ +\x36\x2e\x39\x20\x36\x2e\x32\x20\x39\x2e\x36\x20\x31\x35\x2e\x39\ +\x20\x36\x2e\x34\x20\x32\x34\x2e\x36\x63\x2d\x34\x2e\x34\x20\x31\ +\x31\x2e\x39\x2d\x39\x2e\x37\x20\x32\x33\x2e\x33\x2d\x31\x35\x2e\ +\x38\x20\x33\x34\x2e\x33\x6c\x2d\x34\x2e\x37\x20\x38\x2e\x31\x63\ +\x2d\x36\x2e\x36\x20\x31\x31\x2d\x31\x34\x20\x32\x31\x2e\x34\x2d\ +\x32\x32\x2e\x31\x20\x33\x31\x2e\x32\x63\x2d\x35\x2e\x39\x20\x37\ +\x2e\x32\x2d\x31\x35\x2e\x37\x20\x39\x2e\x36\x2d\x32\x34\x2e\x35\ +\x20\x36\x2e\x38\x6c\x2d\x35\x35\x2e\x37\x2d\x31\x37\x2e\x37\x63\ +\x2d\x31\x33\x2e\x34\x20\x31\x30\x2e\x33\x2d\x32\x38\x2e\x32\x20\ +\x31\x38\x2e\x39\x2d\x34\x34\x20\x32\x35\x2e\x34\x6c\x2d\x31\x32\ +\x2e\x35\x20\x35\x37\x2e\x31\x63\x2d\x32\x20\x39\x2e\x31\x2d\x39\ +\x20\x31\x36\x2e\x33\x2d\x31\x38\x2e\x32\x20\x31\x37\x2e\x38\x63\ +\x2d\x31\x33\x2e\x38\x20\x32\x2e\x33\x2d\x32\x38\x20\x33\x2e\x35\ +\x2d\x34\x32\x2e\x35\x20\x33\x2e\x35\x73\x2d\x32\x38\x2e\x37\x2d\ +\x31\x2e\x32\x2d\x34\x32\x2e\x35\x2d\x33\x2e\x35\x63\x2d\x39\x2e\ +\x32\x2d\x31\x2e\x35\x2d\x31\x36\x2e\x32\x2d\x38\x2e\x37\x2d\x31\ +\x38\x2e\x32\x2d\x31\x37\x2e\x38\x6c\x2d\x31\x32\x2e\x35\x2d\x35\ +\x37\x2e\x31\x63\x2d\x31\x35\x2e\x38\x2d\x36\x2e\x35\x2d\x33\x30\ +\x2e\x36\x2d\x31\x35\x2e\x31\x2d\x34\x34\x2d\x32\x35\x2e\x34\x4c\ +\x38\x33\x2e\x31\x20\x34\x32\x35\x2e\x39\x63\x2d\x38\x2e\x38\x20\ +\x32\x2e\x38\x2d\x31\x38\x2e\x36\x20\x2e\x33\x2d\x32\x34\x2e\x35\ +\x2d\x36\x2e\x38\x63\x2d\x38\x2e\x31\x2d\x39\x2e\x38\x2d\x31\x35\ +\x2e\x35\x2d\x32\x30\x2e\x32\x2d\x32\x32\x2e\x31\x2d\x33\x31\x2e\ +\x32\x6c\x2d\x34\x2e\x37\x2d\x38\x2e\x31\x63\x2d\x36\x2e\x31\x2d\ +\x31\x31\x2d\x31\x31\x2e\x34\x2d\x32\x32\x2e\x34\x2d\x31\x35\x2e\ +\x38\x2d\x33\x34\x2e\x33\x63\x2d\x33\x2e\x32\x2d\x38\x2e\x37\x2d\ +\x2e\x35\x2d\x31\x38\x2e\x34\x20\x36\x2e\x34\x2d\x32\x34\x2e\x36\ +\x6c\x34\x33\x2e\x33\x2d\x33\x39\x2e\x34\x43\x36\x34\x2e\x36\x20\ +\x32\x37\x33\x2e\x31\x20\x36\x34\x20\x32\x36\x34\x2e\x36\x20\x36\ +\x34\x20\x32\x35\x36\x73\x2e\x36\x2d\x31\x37\x2e\x31\x20\x31\x2e\ +\x37\x2d\x32\x35\x2e\x34\x4c\x32\x32\x2e\x34\x20\x31\x39\x31\x2e\ +\x32\x63\x2d\x36\x2e\x39\x2d\x36\x2e\x32\x2d\x39\x2e\x36\x2d\x31\ +\x35\x2e\x39\x2d\x36\x2e\x34\x2d\x32\x34\x2e\x36\x63\x34\x2e\x34\ +\x2d\x31\x31\x2e\x39\x20\x39\x2e\x37\x2d\x32\x33\x2e\x33\x20\x31\ +\x35\x2e\x38\x2d\x33\x34\x2e\x33\x6c\x34\x2e\x37\x2d\x38\x2e\x31\ +\x63\x36\x2e\x36\x2d\x31\x31\x20\x31\x34\x2d\x32\x31\x2e\x34\x20\ +\x32\x32\x2e\x31\x2d\x33\x31\x2e\x32\x63\x35\x2e\x39\x2d\x37\x2e\ +\x32\x20\x31\x35\x2e\x37\x2d\x39\x2e\x36\x20\x32\x34\x2e\x35\x2d\ +\x36\x2e\x38\x6c\x35\x35\x2e\x37\x20\x31\x37\x2e\x37\x63\x31\x33\ +\x2e\x34\x2d\x31\x30\x2e\x33\x20\x32\x38\x2e\x32\x2d\x31\x38\x2e\ +\x39\x20\x34\x34\x2d\x32\x35\x2e\x34\x6c\x31\x32\x2e\x35\x2d\x35\ +\x37\x2e\x31\x63\x32\x2d\x39\x2e\x31\x20\x39\x2d\x31\x36\x2e\x33\ +\x20\x31\x38\x2e\x32\x2d\x31\x37\x2e\x38\x43\x32\x32\x37\x2e\x33\ +\x20\x31\x2e\x32\x20\x32\x34\x31\x2e\x35\x20\x30\x20\x32\x35\x36\ +\x20\x30\x73\x32\x38\x2e\x37\x20\x31\x2e\x32\x20\x34\x32\x2e\x35\ +\x20\x33\x2e\x35\x63\x39\x2e\x32\x20\x31\x2e\x35\x20\x31\x36\x2e\ +\x32\x20\x38\x2e\x37\x20\x31\x38\x2e\x32\x20\x31\x37\x2e\x38\x6c\ +\x31\x32\x2e\x35\x20\x35\x37\x2e\x31\x63\x31\x35\x2e\x38\x20\x36\ +\x2e\x35\x20\x33\x30\x2e\x36\x20\x31\x35\x2e\x31\x20\x34\x34\x20\ +\x32\x35\x2e\x34\x6c\x35\x35\x2e\x37\x2d\x31\x37\x2e\x37\x63\x38\ +\x2e\x38\x2d\x32\x2e\x38\x20\x31\x38\x2e\x36\x2d\x2e\x33\x20\x32\ +\x34\x2e\x35\x20\x36\x2e\x38\x63\x38\x2e\x31\x20\x39\x2e\x38\x20\ +\x31\x35\x2e\x35\x20\x32\x30\x2e\x32\x20\x32\x32\x2e\x31\x20\x33\ +\x31\x2e\x32\x6c\x34\x2e\x37\x20\x38\x2e\x31\x63\x36\x2e\x31\x20\ +\x31\x31\x20\x31\x31\x2e\x34\x20\x32\x32\x2e\x34\x20\x31\x35\x2e\ +\x38\x20\x33\x34\x2e\x33\x7a\x4d\x32\x35\x36\x20\x33\x33\x36\x61\ +\x38\x30\x20\x38\x30\x20\x30\x20\x31\x20\x30\x20\x30\x2d\x31\x36\ +\x30\x20\x38\x30\x20\x38\x30\x20\x30\x20\x31\x20\x30\x20\x30\x20\ +\x31\x36\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x00\x00\x03\xbb\ +\x3c\ +\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ +\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ +\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ +\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ +\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ +\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ +\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ +\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ +\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ +\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ +\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ +\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ +\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ +\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x32\x37\x38\x2e\x36\x20\ +\x39\x2e\x34\x63\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x33\ +\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x6c\ +\x2d\x36\x34\x20\x36\x34\x63\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\ +\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\x20\x34\x35\ +\x2e\x33\x73\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\x35\x2e\ +\x33\x20\x30\x6c\x39\x2e\x34\x2d\x39\x2e\x34\x4c\x32\x32\x34\x20\ +\x32\x32\x34\x6c\x2d\x31\x31\x34\x2e\x37\x20\x30\x20\x39\x2e\x34\ +\x2d\x39\x2e\x34\x63\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x31\ +\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x73\ +\x2d\x33\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\ +\x30\x6c\x2d\x36\x34\x20\x36\x34\x63\x2d\x31\x32\x2e\x35\x20\x31\ +\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\x20\ +\x34\x35\x2e\x33\x6c\x36\x34\x20\x36\x34\x63\x31\x32\x2e\x35\x20\ +\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\x34\ +\x35\x2e\x33\x20\x30\x73\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x20\ +\x30\x2d\x34\x35\x2e\x33\x6c\x2d\x39\x2e\x34\x2d\x39\x2e\x34\x4c\ +\x32\x32\x34\x20\x32\x38\x38\x6c\x30\x20\x31\x31\x34\x2e\x37\x2d\ +\x39\x2e\x34\x2d\x39\x2e\x34\x63\x2d\x31\x32\x2e\x35\x2d\x31\x32\ +\x2e\x35\x2d\x33\x32\x2e\x38\x2d\x31\x32\x2e\x35\x2d\x34\x35\x2e\ +\x33\x20\x30\x73\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x30\ +\x20\x34\x35\x2e\x33\x6c\x36\x34\x20\x36\x34\x63\x31\x32\x2e\x35\ +\x20\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\x20\ +\x34\x35\x2e\x33\x20\x30\x6c\x36\x34\x2d\x36\x34\x63\x31\x32\x2e\ +\x35\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\ +\x20\x30\x2d\x34\x35\x2e\x33\x73\x2d\x33\x32\x2e\x38\x2d\x31\x32\ +\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x6c\x2d\x39\x2e\x34\x20\x39\ +\x2e\x34\x4c\x32\x38\x38\x20\x32\x38\x38\x6c\x31\x31\x34\x2e\x37\ +\x20\x30\x2d\x39\x2e\x34\x20\x39\x2e\x34\x63\x2d\x31\x32\x2e\x35\ +\x20\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\ +\x30\x20\x34\x35\x2e\x33\x73\x33\x32\x2e\x38\x20\x31\x32\x2e\x35\ +\x20\x34\x35\x2e\x33\x20\x30\x6c\x36\x34\x2d\x36\x34\x63\x31\x32\ +\x2e\x35\x2d\x31\x32\x2e\x35\x20\x31\x32\x2e\x35\x2d\x33\x32\x2e\ +\x38\x20\x30\x2d\x34\x35\x2e\x33\x6c\x2d\x36\x34\x2d\x36\x34\x63\ +\x2d\x31\x32\x2e\x35\x2d\x31\x32\x2e\x35\x2d\x33\x32\x2e\x38\x2d\ +\x31\x32\x2e\x35\x2d\x34\x35\x2e\x33\x20\x30\x73\x2d\x31\x32\x2e\ +\x35\x20\x33\x32\x2e\x38\x20\x30\x20\x34\x35\x2e\x33\x6c\x39\x2e\ +\x34\x20\x39\x2e\x34\x4c\x32\x38\x38\x20\x32\x32\x34\x6c\x30\x2d\ +\x31\x31\x34\x2e\x37\x20\x39\x2e\x34\x20\x39\x2e\x34\x63\x31\x32\ +\x2e\x35\x20\x31\x32\x2e\x35\x20\x33\x32\x2e\x38\x20\x31\x32\x2e\ +\x35\x20\x34\x35\x2e\x33\x20\x30\x73\x31\x32\x2e\x35\x2d\x33\x32\ +\x2e\x38\x20\x30\x2d\x34\x35\x2e\x33\x6c\x2d\x36\x34\x2d\x36\x34\ +\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\x00\x00\x03\x51\ +\x3c\ +\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ +\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ +\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ +\x30\x20\x30\x20\x34\x34\x38\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ +\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ +\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ +\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ +\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ +\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ +\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ +\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ +\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ +\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ +\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x33\x36\x38\x20\x38\x30\ +\x6c\x33\x32\x20\x30\x20\x30\x20\x33\x32\x2d\x33\x32\x20\x30\x20\ +\x30\x2d\x33\x32\x7a\x4d\x33\x35\x32\x20\x33\x32\x63\x2d\x31\x37\ +\x2e\x37\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\ +\x33\x32\x4c\x31\x32\x38\x20\x36\x34\x63\x30\x2d\x31\x37\x2e\x37\ +\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x4c\x33\ +\x32\x20\x33\x32\x43\x31\x34\x2e\x33\x20\x33\x32\x20\x30\x20\x34\ +\x36\x2e\x33\x20\x30\x20\x36\x34\x6c\x30\x20\x36\x34\x63\x30\x20\ +\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\x20\ +\x33\x32\x6c\x30\x20\x31\x39\x32\x63\x2d\x31\x37\x2e\x37\x20\x30\ +\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\x6c\x30\ +\x20\x36\x34\x63\x30\x20\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x20\ +\x33\x32\x20\x33\x32\x20\x33\x32\x6c\x36\x34\x20\x30\x63\x31\x37\ +\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\ +\x33\x32\x6c\x31\x39\x32\x20\x30\x63\x30\x20\x31\x37\x2e\x37\x20\ +\x31\x34\x2e\x33\x20\x33\x32\x20\x33\x32\x20\x33\x32\x6c\x36\x34\ +\x20\x30\x63\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\x34\x2e\ +\x33\x20\x33\x32\x2d\x33\x32\x6c\x30\x2d\x36\x34\x63\x30\x2d\x31\ +\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\ +\x32\x6c\x30\x2d\x31\x39\x32\x63\x31\x37\x2e\x37\x20\x30\x20\x33\ +\x32\x2d\x31\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x6c\x30\x2d\x36\ +\x34\x63\x30\x2d\x31\x37\x2e\x37\x2d\x31\x34\x2e\x33\x2d\x33\x32\ +\x2d\x33\x32\x2d\x33\x32\x6c\x2d\x36\x34\x20\x30\x7a\x4d\x39\x36\ +\x20\x31\x36\x30\x63\x31\x37\x2e\x37\x20\x30\x20\x33\x32\x2d\x31\ +\x34\x2e\x33\x20\x33\x32\x2d\x33\x32\x6c\x31\x39\x32\x20\x30\x63\ +\x30\x20\x31\x37\x2e\x37\x20\x31\x34\x2e\x33\x20\x33\x32\x20\x33\ +\x32\x20\x33\x32\x6c\x30\x20\x31\x39\x32\x63\x2d\x31\x37\x2e\x37\ +\x20\x30\x2d\x33\x32\x20\x31\x34\x2e\x33\x2d\x33\x32\x20\x33\x32\ +\x6c\x2d\x31\x39\x32\x20\x30\x63\x30\x2d\x31\x37\x2e\x37\x2d\x31\ +\x34\x2e\x33\x2d\x33\x32\x2d\x33\x32\x2d\x33\x32\x6c\x30\x2d\x31\ +\x39\x32\x7a\x4d\x34\x38\x20\x34\x30\x30\x6c\x33\x32\x20\x30\x20\ +\x30\x20\x33\x32\x2d\x33\x32\x20\x30\x20\x30\x2d\x33\x32\x7a\x6d\ +\x33\x32\x30\x20\x33\x32\x6c\x30\x2d\x33\x32\x20\x33\x32\x20\x30\ +\x20\x30\x20\x33\x32\x2d\x33\x32\x20\x30\x7a\x4d\x34\x38\x20\x31\ +\x31\x32\x6c\x30\x2d\x33\x32\x20\x33\x32\x20\x30\x20\x30\x20\x33\ +\x32\x2d\x33\x32\x20\x30\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ +\ +\x00\x00\x02\x6e\ +\x3c\ +\x73\x76\x67\x20\x78\x6d\x6c\x6e\x73\x3d\x22\x68\x74\x74\x70\x3a\ +\x2f\x2f\x77\x77\x77\x2e\x77\x33\x2e\x6f\x72\x67\x2f\x32\x30\x30\ +\x30\x2f\x73\x76\x67\x22\x20\x76\x69\x65\x77\x42\x6f\x78\x3d\x22\ +\x30\x20\x30\x20\x35\x31\x32\x20\x35\x31\x32\x22\x3e\x3c\x21\x2d\ +\x2d\x21\x20\x46\x6f\x6e\x74\x20\x41\x77\x65\x73\x6f\x6d\x65\x20\ +\x46\x72\x65\x65\x20\x36\x2e\x37\x2e\x32\x20\x62\x79\x20\x40\x66\ +\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\x20\x2d\x20\x68\x74\x74\ +\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\x6d\x65\ +\x2e\x63\x6f\x6d\x20\x4c\x69\x63\x65\x6e\x73\x65\x20\x2d\x20\x68\ +\x74\x74\x70\x73\x3a\x2f\x2f\x66\x6f\x6e\x74\x61\x77\x65\x73\x6f\ +\x6d\x65\x2e\x63\x6f\x6d\x2f\x6c\x69\x63\x65\x6e\x73\x65\x2f\x66\ +\x72\x65\x65\x20\x28\x49\x63\x6f\x6e\x73\x3a\x20\x43\x43\x20\x42\ +\x59\x20\x34\x2e\x30\x2c\x20\x46\x6f\x6e\x74\x73\x3a\x20\x53\x49\ +\x4c\x20\x4f\x46\x4c\x20\x31\x2e\x31\x2c\x20\x43\x6f\x64\x65\x3a\ +\x20\x4d\x49\x54\x20\x4c\x69\x63\x65\x6e\x73\x65\x29\x20\x43\x6f\ +\x70\x79\x72\x69\x67\x68\x74\x20\x32\x30\x32\x34\x20\x46\x6f\x6e\ +\x74\x69\x63\x6f\x6e\x73\x2c\x20\x49\x6e\x63\x2e\x20\x2d\x2d\x3e\ +\x3c\x70\x61\x74\x68\x20\x64\x3d\x22\x4d\x34\x39\x38\x2e\x31\x20\ +\x35\x2e\x36\x63\x31\x30\x2e\x31\x20\x37\x20\x31\x35\x2e\x34\x20\ +\x31\x39\x2e\x31\x20\x31\x33\x2e\x35\x20\x33\x31\x2e\x32\x6c\x2d\ +\x36\x34\x20\x34\x31\x36\x63\x2d\x31\x2e\x35\x20\x39\x2e\x37\x2d\ +\x37\x2e\x34\x20\x31\x38\x2e\x32\x2d\x31\x36\x20\x32\x33\x73\x2d\ +\x31\x38\x2e\x39\x20\x35\x2e\x34\x2d\x32\x38\x20\x31\x2e\x36\x4c\ +\x32\x38\x34\x20\x34\x32\x37\x2e\x37\x6c\x2d\x36\x38\x2e\x35\x20\ +\x37\x34\x2e\x31\x63\x2d\x38\x2e\x39\x20\x39\x2e\x37\x2d\x32\x32\ +\x2e\x39\x20\x31\x32\x2e\x39\x2d\x33\x35\x2e\x32\x20\x38\x2e\x31\ +\x53\x31\x36\x30\x20\x34\x39\x33\x2e\x32\x20\x31\x36\x30\x20\x34\ +\x38\x30\x6c\x30\x2d\x38\x33\x2e\x36\x63\x30\x2d\x34\x20\x31\x2e\ +\x35\x2d\x37\x2e\x38\x20\x34\x2e\x32\x2d\x31\x30\x2e\x38\x4c\x33\ +\x33\x31\x2e\x38\x20\x32\x30\x32\x2e\x38\x63\x35\x2e\x38\x2d\x36\ +\x2e\x33\x20\x35\x2e\x36\x2d\x31\x36\x2d\x2e\x34\x2d\x32\x32\x73\ +\x2d\x31\x35\x2e\x37\x2d\x36\x2e\x34\x2d\x32\x32\x2d\x2e\x37\x4c\ +\x31\x30\x36\x20\x33\x36\x30\x2e\x38\x20\x31\x37\x2e\x37\x20\x33\ +\x31\x36\x2e\x36\x43\x37\x2e\x31\x20\x33\x31\x31\x2e\x33\x20\x2e\ +\x33\x20\x33\x30\x30\x2e\x37\x20\x30\x20\x32\x38\x38\x2e\x39\x73\ +\x35\x2e\x39\x2d\x32\x32\x2e\x38\x20\x31\x36\x2e\x31\x2d\x32\x38\ +\x2e\x37\x6c\x34\x34\x38\x2d\x32\x35\x36\x63\x31\x30\x2e\x37\x2d\ +\x36\x2e\x31\x20\x32\x33\x2e\x39\x2d\x35\x2e\x35\x20\x33\x34\x20\ +\x31\x2e\x34\x7a\x22\x2f\x3e\x3c\x2f\x73\x76\x67\x3e\ ' qt_resource_name = b'\ @@ -1305,137 +1445,151 @@ \x00\x6f\xa6\x53\ \x00\x69\ \x00\x63\x00\x6f\x00\x6e\x00\x73\ -\x00\x04\ -\x00\x06\xf6\x35\ -\x00\x68\ -\x00\x6f\x00\x6d\x00\x65\ -\x00\x04\ -\x00\x07\x46\xc5\ -\x00\x6d\ -\x00\x6f\x00\x76\x00\x65\ +\x00\x05\ +\x00\x77\x95\x85\ +\x00\x70\ +\x00\x72\x00\x6f\x00\x62\x00\x65\ \x00\x04\ \x00\x07\x98\xc5\ \x00\x73\ \x00\x61\x00\x76\x00\x65\ +\x00\x08\ +\x06\x91\xdc\xa7\ +\x00\x77\ +\x00\x6f\x00\x72\x00\x6b\x00\x66\x00\x6c\x00\x6f\x00\x77\ +\x00\x0d\ +\x0a\xef\x61\xc2\ +\x00\x72\ +\x00\x65\x00\x63\x00\x6f\x00\x6e\x00\x73\x00\x74\x00\x72\x00\x75\x00\x63\x00\x74\x00\x6f\x00\x72\ \x00\x04\ \x00\x07\x99\x7e\ \x00\x73\ \x00\x63\x00\x61\x00\x6e\ \x00\x08\ +\x06\x89\x2d\x83\ +\x00\x73\ +\x00\x70\x00\x61\x00\x72\x00\x6b\x00\x6c\x00\x65\x00\x73\ +\x00\x09\ +\x0b\x69\x49\xa5\ +\x00\x61\ +\x00\x75\x00\x74\x00\x6f\x00\x73\x00\x63\x00\x61\x00\x6c\x00\x65\ +\x00\x04\ +\x00\x06\xf6\x35\ +\x00\x68\ +\x00\x6f\x00\x6d\x00\x65\ +\x00\x06\ +\x07\x59\x0b\xa4\ +\x00\x6f\ +\x00\x62\x00\x6a\x00\x65\x00\x63\x00\x74\ +\x00\x08\ +\x0c\xb6\x35\xa5\ +\x00\x61\ +\x00\x75\x00\x74\x00\x6f\x00\x6d\x00\x61\x00\x74\x00\x65\ +\x00\x08\ \x00\x48\x34\xa4\ \x00\x6c\ \x00\x69\x00\x6e\x00\x65\x00\x2d\x00\x63\x00\x75\x00\x74\ -\x00\x05\ -\x00\x77\x95\x85\ +\x00\x09\ +\x0f\x9f\xb4\xa3\ \x00\x70\ -\x00\x72\x00\x6f\x00\x62\x00\x65\ +\x00\x74\x00\x79\x00\x63\x00\x68\x00\x6f\x00\x64\x00\x75\x00\x73\ +\x00\x08\ +\x09\x5b\xb4\x53\ +\x00\x70\ +\x00\x72\x00\x6f\x00\x64\x00\x75\x00\x63\x00\x74\x00\x73\ \x00\x05\ \x00\x79\xc2\xc2\ \x00\x72\ \x00\x75\x00\x6c\x00\x65\x00\x72\ \x00\x08\ -\x06\x91\xdc\xa7\ -\x00\x77\ -\x00\x6f\x00\x72\x00\x6b\x00\x66\x00\x6c\x00\x6f\x00\x77\ -\x00\x06\ -\x07\x59\x0b\xa4\ -\x00\x6f\ -\x00\x62\x00\x6a\x00\x65\x00\x63\x00\x74\ -\x00\x08\ \x08\xba\xc7\x93\ \x00\x70\ \x00\x61\x00\x74\x00\x74\x00\x65\x00\x72\x00\x6e\x00\x73\ \x00\x08\ -\x09\x5b\xb4\x53\ -\x00\x70\ -\x00\x72\x00\x6f\x00\x64\x00\x75\x00\x63\x00\x74\x00\x73\ +\x0c\xbb\x0b\xc3\ +\x00\x73\ +\x00\x65\x00\x74\x00\x74\x00\x69\x00\x6e\x00\x67\x00\x73\ +\x00\x04\ +\x00\x07\x46\xc5\ +\x00\x6d\ +\x00\x6f\x00\x76\x00\x65\ \x00\x09\ \x0a\xa8\xbf\x45\ \x00\x72\ \x00\x65\x00\x63\x00\x74\x00\x61\x00\x6e\x00\x67\x00\x6c\x00\x65\ -\x00\x0d\ -\x0a\xef\x61\xc2\ -\x00\x72\ -\x00\x65\x00\x63\x00\x6f\x00\x6e\x00\x73\x00\x74\x00\x72\x00\x75\x00\x63\x00\x74\x00\x6f\x00\x72\ -\x00\x09\ -\x0b\x69\x49\xa5\ -\x00\x61\ -\x00\x75\x00\x74\x00\x6f\x00\x73\x00\x63\x00\x61\x00\x6c\x00\x65\ -\x00\x08\ -\x0c\xb6\x35\xa5\ -\x00\x61\ -\x00\x75\x00\x74\x00\x6f\x00\x6d\x00\x61\x00\x74\x00\x65\ -\x00\x08\ -\x0c\xbb\x0b\xc3\ +\x00\x04\ +\x00\x07\x9c\x44\ \x00\x73\ -\x00\x65\x00\x74\x00\x74\x00\x69\x00\x6e\x00\x67\x00\x73\ -\x00\x09\ -\x0f\x9f\xb4\xa3\ -\x00\x70\ -\x00\x74\x00\x79\x00\x63\x00\x68\x00\x6f\x00\x64\x00\x75\x00\x73\ +\x00\x65\x00\x6e\x00\x64\ ' qt_resource_struct_v1 = b'\ \x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\ -\x00\x00\x00\x00\x00\x02\x00\x00\x00\x11\x00\x00\x00\x02\ +\x00\x00\x00\x00\x00\x02\x00\x00\x00\x13\x00\x00\x00\x02\ +\x00\x00\x00\xa0\x00\x00\x00\x00\x00\x01\x00\x00\x19\x2f\ +\x00\x00\x01\x56\x00\x00\x00\x00\x00\x01\x00\x00\x4d\x1b\ +\x00\x00\x00\x20\x00\x00\x00\x00\x00\x01\x00\x00\x03\x5f\ +\x00\x00\x00\x64\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xc7\ +\x00\x00\x01\x7c\x00\x00\x00\x00\x00\x01\x00\x00\x54\x2f\ +\x00\x00\x00\xd6\x00\x00\x00\x00\x00\x01\x00\x00\x23\xc9\ \x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\ -\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x01\x00\x00\x03\x22\ -\x00\x00\x00\x2c\x00\x00\x00\x00\x00\x01\x00\x00\x06\xca\ -\x00\x00\x00\x3a\x00\x00\x00\x00\x00\x01\x00\x00\x09\x90\ -\x00\x00\x00\x48\x00\x00\x00\x00\x00\x01\x00\x00\x0d\x0c\ -\x00\x00\x00\x5e\x00\x00\x00\x00\x00\x01\x00\x00\x0f\x8d\ -\x00\x00\x00\x6e\x00\x00\x00\x00\x00\x01\x00\x00\x12\xec\ -\x00\x00\x00\x7e\x00\x00\x00\x00\x00\x01\x00\x00\x15\xb5\ -\x00\x00\x00\x94\x00\x00\x00\x00\x00\x01\x00\x00\x18\xfb\ -\x00\x00\x00\xa6\x00\x00\x00\x00\x00\x01\x00\x00\x1c\x97\ -\x00\x00\x00\xbc\x00\x00\x00\x00\x00\x01\x00\x00\x1e\xdd\ -\x00\x00\x00\xd2\x00\x00\x00\x00\x00\x01\x00\x00\x22\x46\ -\x00\x00\x00\xea\x00\x00\x00\x00\x00\x01\x00\x00\x25\x5d\ -\x00\x00\x01\x0a\x00\x00\x00\x00\x00\x01\x00\x00\x29\x87\ -\x00\x00\x01\x22\x00\x00\x00\x00\x00\x01\x00\x00\x2c\xb6\ -\x00\x00\x01\x38\x00\x00\x00\x00\x00\x01\x00\x00\x30\x43\ -\x00\x00\x01\x4e\x00\x00\x00\x00\x00\x01\x00\x00\x35\x75\ +\x00\x00\x01\x1a\x00\x00\x00\x00\x00\x01\x00\x00\x42\x94\ +\x00\x00\x00\x72\x00\x00\x00\x00\x00\x01\x00\x00\x11\x54\ +\x00\x00\x00\x2e\x00\x00\x00\x00\x00\x01\x00\x00\x06\x4d\ +\x00\x00\x00\xae\x00\x00\x00\x00\x00\x01\x00\x00\x1c\x76\ +\x00\x00\x01\x2a\x00\x00\x00\x00\x00\x01\x00\x00\x45\x5d\ +\x00\x00\x01\x04\x00\x00\x00\x00\x00\x01\x00\x00\x3f\x03\ +\x00\x00\x01\x64\x00\x00\x00\x00\x00\x01\x00\x00\x50\xda\ +\x00\x00\x00\x44\x00\x00\x00\x00\x00\x01\x00\x00\x09\x93\ +\x00\x00\x00\x88\x00\x00\x00\x00\x00\x01\x00\x00\x15\xf9\ +\x00\x00\x00\xc0\x00\x00\x00\x00\x00\x01\x00\x00\x20\x12\ +\x00\x00\x01\x40\x00\x00\x00\x00\x00\x01\x00\x00\x47\xe9\ +\x00\x00\x00\xec\x00\x00\x00\x00\x00\x01\x00\x00\x26\x55\ ' qt_resource_struct_v2 = b'\ \x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\ \x00\x00\x00\x00\x00\x00\x00\x00\ -\x00\x00\x00\x00\x00\x02\x00\x00\x00\x11\x00\x00\x00\x02\ +\x00\x00\x00\x00\x00\x02\x00\x00\x00\x13\x00\x00\x00\x02\ \x00\x00\x00\x00\x00\x00\x00\x00\ +\x00\x00\x00\xa0\x00\x00\x00\x00\x00\x01\x00\x00\x19\x2f\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x01\x56\x00\x00\x00\x00\x00\x01\x00\x00\x4d\x1b\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\x20\x00\x00\x00\x00\x00\x01\x00\x00\x03\x5f\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\x64\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xc7\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x01\x7c\x00\x00\x00\x00\x00\x01\x00\x00\x54\x2f\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\xd6\x00\x00\x00\x00\x00\x01\x00\x00\x23\xc9\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ \x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x01\x00\x00\x03\x22\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x2c\x00\x00\x00\x00\x00\x01\x00\x00\x06\xca\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x3a\x00\x00\x00\x00\x00\x01\x00\x00\x09\x90\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x48\x00\x00\x00\x00\x00\x01\x00\x00\x0d\x0c\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x5e\x00\x00\x00\x00\x00\x01\x00\x00\x0f\x8d\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x6e\x00\x00\x00\x00\x00\x01\x00\x00\x12\xec\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x7e\x00\x00\x00\x00\x00\x01\x00\x00\x15\xb5\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\x94\x00\x00\x00\x00\x00\x01\x00\x00\x18\xfb\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\xa6\x00\x00\x00\x00\x00\x01\x00\x00\x1c\x97\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\xbc\x00\x00\x00\x00\x00\x01\x00\x00\x1e\xdd\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\xd2\x00\x00\x00\x00\x00\x01\x00\x00\x22\x46\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x00\xea\x00\x00\x00\x00\x00\x01\x00\x00\x25\x5d\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x01\x0a\x00\x00\x00\x00\x00\x01\x00\x00\x29\x87\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x01\x22\x00\x00\x00\x00\x00\x01\x00\x00\x2c\xb6\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x01\x38\x00\x00\x00\x00\x00\x01\x00\x00\x30\x43\ -\x00\x00\x01\x8e\xa0\x29\xf0\xd8\ -\x00\x00\x01\x4e\x00\x00\x00\x00\x00\x01\x00\x00\x35\x75\ -\x00\x00\x01\x8e\x9f\x32\xd5\x0a\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x01\x1a\x00\x00\x00\x00\x00\x01\x00\x00\x42\x94\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\x72\x00\x00\x00\x00\x00\x01\x00\x00\x11\x54\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\x2e\x00\x00\x00\x00\x00\x01\x00\x00\x06\x4d\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\xae\x00\x00\x00\x00\x00\x01\x00\x00\x1c\x76\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x01\x2a\x00\x00\x00\x00\x00\x01\x00\x00\x45\x5d\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x01\x04\x00\x00\x00\x00\x00\x01\x00\x00\x3f\x03\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x01\x64\x00\x00\x00\x00\x00\x01\x00\x00\x50\xda\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\x44\x00\x00\x00\x00\x00\x01\x00\x00\x09\x93\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\x88\x00\x00\x00\x00\x00\x01\x00\x00\x15\xf9\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\xc0\x00\x00\x00\x00\x00\x01\x00\x00\x20\x12\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x01\x40\x00\x00\x00\x00\x00\x01\x00\x00\x47\xe9\ +\x00\x00\x01\x93\xd1\x62\x88\xe0\ +\x00\x00\x00\xec\x00\x00\x00\x00\x00\x01\x00\x00\x26\x55\ +\x00\x00\x01\x93\x8e\x95\x84\x0f\ ' qt_version = [int(v) for v in QtCore.qVersion().split('.')] @@ -1447,13 +1601,13 @@ qt_resource_struct = qt_resource_struct_v2 -def qInitResources() -> None: +def qInitResources() -> None: # noqa: N802 QtCore.qRegisterResourceData( rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data ) -def qCleanupResources() -> None: +def qCleanupResources() -> None: # noqa: N802 QtCore.qUnregisterResourceData( rcc_version, qt_resource_struct, qt_resource_name, qt_resource_data ) diff --git a/src/ptychodus/view/resources.qrc b/src/ptychodus/view/resources.qrc index 969758da..c80c8c55 100644 --- a/src/ptychodus/view/resources.qrc +++ b/src/ptychodus/view/resources.qrc @@ -1,22 +1,24 @@ - Font-Awesome-6.5.2/svgs/solid/robot.svg - Font-Awesome-6.5.2/svgs/solid/arrows-left-right-to-line.svg - Font-Awesome-6.5.2/svgs/solid/house-chimney.svg - Font-Awesome-6.5.2/svgs/solid/chart-line.svg - Font-Awesome-6.5.2/svgs/solid/arrows-up-down-left-right.svg - Font-Awesome-6.5.2/svgs/solid/layer-group.svg - Font-Awesome-6.5.2/svgs/solid/table-cells.svg - Font-Awesome-6.5.2/svgs/solid/circle-radiation.svg - Font-Awesome-6.5.2/svgs/solid/list.svg + Font-Awesome-6.7.2/svgs/solid/robot.svg + Font-Awesome-6.7.2/svgs/solid/arrows-left-right-to-line.svg + Font-Awesome-6.7.2/svgs/solid/house-chimney.svg + Font-Awesome-6.7.2/svgs/solid/chart-line.svg + Font-Awesome-6.7.2/svgs/solid/arrows-up-down-left-right.svg + Font-Awesome-6.7.2/svgs/solid/layer-group.svg + Font-Awesome-6.7.2/svgs/solid/table-cells.svg + Font-Awesome-6.7.2/svgs/solid/circle-radiation.svg + Font-Awesome-6.7.2/svgs/solid/list.svg ptychodus.svg - Font-Awesome-6.5.2/svgs/solid/screwdriver-wrench.svg - Font-Awesome-6.5.2/svgs/solid/vector-square.svg - Font-Awesome-6.5.2/svgs/solid/ruler.svg - Font-Awesome-6.5.2/svgs/regular/floppy-disk.svg - Font-Awesome-6.5.2/svgs/solid/route.svg - Font-Awesome-6.5.2/svgs/solid/gear.svg - Font-Awesome-6.5.2/svgs/solid/toolbox.svg + Font-Awesome-6.7.2/svgs/solid/screwdriver-wrench.svg + Font-Awesome-6.7.2/svgs/solid/vector-square.svg + Font-Awesome-6.7.2/svgs/solid/ruler.svg + Font-Awesome-6.7.2/svgs/regular/floppy-disk.svg + Font-Awesome-6.7.2/svgs/solid/route.svg + Font-Awesome-6.7.2/svgs/solid/paper-plane.svg + Font-Awesome-6.7.2/svgs/solid/gear.svg + Font-Awesome-6.7.2/svgs/solid/wand-magic-sparkles.svg + Font-Awesome-6.7.2/svgs/solid/toolbox.svg diff --git a/src/ptychodus/view/scan.py b/src/ptychodus/view/scan.py index 3d1653f9..e2427790 100644 --- a/src/ptychodus/view/scan.py +++ b/src/ptychodus/view/scan.py @@ -11,18 +11,18 @@ class ScanPlotView(QWidget): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) self.figure = Figure() - self.figureCanvas = FigureCanvasQTAgg(self.figure) - self.navigationToolbar = NavigationToolbar(self.figureCanvas, self) + self.figure_canvas = FigureCanvasQTAgg(self.figure) + self.navigation_toolbar = NavigationToolbar(self.figure_canvas, self) self.axes = self.figure.add_subplot(111) @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ScanPlotView: + def create_instance(cls, parent: QWidget | None = None) -> ScanPlotView: view = cls(parent) layout = QVBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(view.navigationToolbar) - layout.addWidget(view.figureCanvas) + layout.addWidget(view.navigation_toolbar) + layout.addWidget(view.figure_canvas) view.setLayout(layout) return view diff --git a/src/ptychodus/view/settings.py b/src/ptychodus/view/settings.py index 7a3bfb55..552ad9d0 100644 --- a/src/ptychodus/view/settings.py +++ b/src/ptychodus/view/settings.py @@ -6,17 +6,17 @@ class SettingsButtonBox(QWidget): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) - self.openButton = QPushButton('Open') - self.saveButton = QPushButton('Save') + self.open_button = QPushButton('Open') + self.save_button = QPushButton('Save') @classmethod - def createInstance(cls, parent: QWidget | None = None) -> SettingsButtonBox: + def create_instance(cls, parent: QWidget | None = None) -> SettingsButtonBox: view = cls(parent) layout = QHBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(view.openButton) - layout.addWidget(view.saveButton) + layout.addWidget(view.open_button) + layout.addWidget(view.save_button) view.setLayout(layout) return view @@ -25,16 +25,16 @@ def createInstance(cls, parent: QWidget | None = None) -> SettingsButtonBox: class SettingsView(QWidget): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) - self.listView = QListView() - self.buttonBox = SettingsButtonBox.createInstance() + self.list_view = QListView() + self.button_box = SettingsButtonBox.create_instance() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> SettingsView: + def create_instance(cls, parent: QWidget | None = None) -> SettingsView: view = cls(parent) layout = QVBoxLayout() - layout.addWidget(view.listView) - layout.addWidget(view.buttonBox) + layout.addWidget(view.list_view) + layout.addWidget(view.button_box) view.setLayout(layout) return view diff --git a/src/ptychodus/view/test.py b/src/ptychodus/view/test.py index 14f340b5..3e4b48bb 100644 --- a/src/ptychodus/view/test.py +++ b/src/ptychodus/view/test.py @@ -6,8 +6,8 @@ FOUND_QT = True -class Tester: - def printQtStatus(self) -> None: +class QtTester: + def print_qt_status(self) -> None: if FOUND_QT: print('Qt FOUND!') else: @@ -15,12 +15,12 @@ def printQtStatus(self) -> None: if FOUND_QT: - def getWidget(self) -> QWidget: + def get_widget(self) -> QWidget: return QWidget() if __name__ == '__main__': - tester = Tester() - tester.printQtStatus() - widget = tester.getWidget() + tester = QtTester() + tester.print_qt_status() + widget = tester.get_widget() print(widget) diff --git a/src/ptychodus/view/visualization.py b/src/ptychodus/view/visualization.py index 07450446..d6265bc8 100644 --- a/src/ptychodus/view/visualization.py +++ b/src/ptychodus/view/visualization.py @@ -38,8 +38,8 @@ class ImageItemEvents(QObject): - rectangleFinished = pyqtSignal(QRectF) - lineCutFinished = pyqtSignal(QLineF) + rectangle_finished = pyqtSignal(QRectF) + line_cut_finished = pyqtSignal(QLineF) class ImageMouseTool(Enum): @@ -50,35 +50,35 @@ class ImageMouseTool(Enum): class ImageItem(QGraphicsPixmapItem): - def __init__(self, events: ImageItemEvents, statusBar: QStatusBar) -> None: + def __init__(self, events: ImageItemEvents, status_bar: QStatusBar) -> None: super().__init__() self._events = events - self._statusBar = statusBar + self._status_bar = status_bar self._product: VisualizationProduct | None = None - self._mouseTool = ImageMouseTool.MOVE_TOOL - self._lineItem = QGraphicsLineItem(self) - self._lineItem.hide() - self._rectangleItem = QGraphicsRectItem(self) - self._rectangleItem.hide() - self._rectangleOrigin = QPointF() + self._mouse_tool = ImageMouseTool.MOVE_TOOL + self._line_item = QGraphicsLineItem(self) + self._line_item.hide() + self._rectangle_item = QGraphicsRectItem(self) + self._rectangle_item.hide() + self._rectangle_origin = QPointF() self.setTransformationMode(Qt.TransformationMode.FastTransformation) self.setAcceptedMouseButtons(Qt.MouseButton.LeftButton) self.setAcceptHoverEvents(True) - def getProduct(self) -> VisualizationProduct | None: + def get_product(self) -> VisualizationProduct | None: return self._product - def setProduct(self, product: VisualizationProduct) -> None: - imageRGBAf = product.getImageRGBA() + def set_product(self, product: VisualizationProduct) -> None: + image_rgba_f = product.get_image_rgba() # NOTE .copy() ensures imageRGBAi is not a view - imageRGBAi = numpy.multiply(imageRGBAf, 255).astype(numpy.uint8).copy() + image_rgba_i = numpy.multiply(image_rgba_f, 255).astype(numpy.uint8).copy() try: image = QImage( - imageRGBAi.data, - imageRGBAi.shape[1], - imageRGBAi.shape[0], - imageRGBAi.strides[0], + image_rgba_i.data, + image_rgba_i.shape[1], + image_rgba_i.shape[0], + image_rgba_i.strides[0], QImage.Format.Format_RGBA8888, ) pixmap = QPixmap.fromImage(image) @@ -89,141 +89,141 @@ def setProduct(self, product: VisualizationProduct) -> None: self._product = product self.setPixmap(pixmap) - def clearProduct(self) -> None: + def clear_product(self) -> None: self._product = None self.setPixmap(QPixmap()) - def setMouseTool(self, mouseTool: ImageMouseTool) -> None: - self._mouseTool = mouseTool + def set_mouse_tool(self, mouse_tool: ImageMouseTool) -> None: + self._mouse_tool = mouse_tool - def hoverEnterEvent(self, event: QGraphicsSceneHoverEvent) -> None: + def hoverEnterEvent(self, event: QGraphicsSceneHoverEvent) -> None: # noqa: N802 app = QApplication.instance() if app: cursor = Qt.CursorShape.CrossCursor - if self._mouseTool == ImageMouseTool.MOVE_TOOL: + if self._mouse_tool == ImageMouseTool.MOVE_TOOL: cursor = Qt.CursorShape.OpenHandCursor app.setOverrideCursor(cursor) # type: ignore super().hoverEnterEvent(event) - def hoverMoveEvent(self, event: QGraphicsSceneHoverEvent) -> None: + def hoverMoveEvent(self, event: QGraphicsSceneHoverEvent) -> None: # noqa: N802 pos = event.pos() if self._product is not None: - infoText = self._product.getInfoText(pos.x(), pos.y()) - self._statusBar.showMessage(infoText) + info_text = self._product.get_info_text(pos.x(), pos.y()) + self._status_bar.showMessage(info_text) super().hoverMoveEvent(event) - def hoverLeaveEvent(self, event: QGraphicsSceneHoverEvent) -> None: + def hoverLeaveEvent(self, event: QGraphicsSceneHoverEvent) -> None: # noqa: N802 app = QApplication.instance() if app: app.restoreOverrideCursor() # type: ignore - self._statusBar.clearMessage() + self._status_bar.clearMessage() super().hoverLeaveEvent(event) - def _changeOverrideCursor(self, cursor: Qt.CursorShape) -> None: + def _change_override_cursor(self, cursor: Qt.CursorShape) -> None: app = QApplication.instance() if app: app.changeOverrideCursor(cursor) # type: ignore @staticmethod - def _createPen(color: Qt.GlobalColor) -> QPen: + def _create_pen(color: Qt.GlobalColor) -> QPen: pen = QPen(color) pen.setCapStyle(Qt.PenCapStyle.FlatCap) pen.setJoinStyle(Qt.PenJoinStyle.MiterJoin) pen.setCosmetic(True) return pen - def mousePressEvent(self, event: QGraphicsSceneMouseEvent) -> None: - if self._mouseTool == ImageMouseTool.MOVE_TOOL: - self._changeOverrideCursor(Qt.CursorShape.ClosedHandCursor) - elif self._mouseTool == ImageMouseTool.RULER_TOOL: + def mousePressEvent(self, event: QGraphicsSceneMouseEvent) -> None: # noqa: N802 + if self._mouse_tool == ImageMouseTool.MOVE_TOOL: + self._change_override_cursor(Qt.CursorShape.ClosedHandCursor) + elif self._mouse_tool == ImageMouseTool.RULER_TOOL: line = QLineF(event.pos(), event.pos()) self.prepareGeometryChange() - self._lineItem.setLine(line) - self._lineItem.setPen(self._createPen(Qt.GlobalColor.cyan)) - self._lineItem.show() - elif self._mouseTool == ImageMouseTool.RECTANGLE_TOOL: - self._rectangleOrigin = event.pos() - rect = QRectF(self._rectangleOrigin, QSizeF()) + self._line_item.setLine(line) + self._line_item.setPen(self._create_pen(Qt.GlobalColor.cyan)) + self._line_item.show() + elif self._mouse_tool == ImageMouseTool.RECTANGLE_TOOL: + self._rectangle_origin = event.pos() + rect = QRectF(self._rectangle_origin, QSizeF()) self.prepareGeometryChange() - self._rectangleItem.setRect(rect) - self._rectangleItem.setPen(self._createPen(Qt.GlobalColor.cyan)) - self._rectangleItem.show() - elif self._mouseTool == ImageMouseTool.LINE_CUT_TOOL: + self._rectangle_item.setRect(rect) + self._rectangle_item.setPen(self._create_pen(Qt.GlobalColor.cyan)) + self._rectangle_item.show() + elif self._mouse_tool == ImageMouseTool.LINE_CUT_TOOL: line = QLineF(event.pos(), event.pos()) self.prepareGeometryChange() - self._lineItem.setLine(line) - self._lineItem.setPen(self._createPen(Qt.GlobalColor.magenta)) - self._lineItem.show() + self._line_item.setLine(line) + self._line_item.setPen(self._create_pen(Qt.GlobalColor.magenta)) + self._line_item.show() - def mouseMoveEvent(self, event: QGraphicsSceneMouseEvent) -> None: - if self._mouseTool == ImageMouseTool.MOVE_TOOL: + def mouseMoveEvent(self, event: QGraphicsSceneMouseEvent) -> None: # noqa: N802 + if self._mouse_tool == ImageMouseTool.MOVE_TOOL: self.setPos(self.scenePos() + event.scenePos() - event.lastScenePos()) - elif self._mouseTool == ImageMouseTool.RULER_TOOL: - origin = self._lineItem.line().p1() + elif self._mouse_tool == ImageMouseTool.RULER_TOOL: + origin = self._line_item.line().p1() line = QLineF(origin, event.pos()) self.prepareGeometryChange() - self._lineItem.setLine(line) + self._line_item.setLine(line) message1 = f'{line.length():.1f} pixels, {line.angle():.2f}\u00b0' message2 = f'{line.dx():.1f} \u00d7 {line.dy():.1f}' - self._statusBar.showMessage(f'{message1} ({message2})') - elif self._mouseTool == ImageMouseTool.RECTANGLE_TOOL: - rect = QRectF(self._rectangleOrigin, event.pos()).normalized() + self._status_bar.showMessage(f'{message1} ({message2})') + elif self._mouse_tool == ImageMouseTool.RECTANGLE_TOOL: + rect = QRectF(self._rectangle_origin, event.pos()).normalized() center = rect.center() self.prepareGeometryChange() - self._rectangleItem.setRect(rect) + self._rectangle_item.setRect(rect) message1 = f'{rect.width():.1f} \u00d7 {rect.height():.1f}' message2 = f'{center.x():.1f}, {center.y():.1f}' - self._statusBar.showMessage(f'Rectangle: {message1} (Center: {message2})') - elif self._mouseTool == ImageMouseTool.LINE_CUT_TOOL: - origin = self._lineItem.line().p1() + self._status_bar.showMessage(f'Rectangle: {message1} (Center: {message2})') + elif self._mouse_tool == ImageMouseTool.LINE_CUT_TOOL: + origin = self._line_item.line().p1() line = QLineF(origin, event.pos()) self.prepareGeometryChange() - self._lineItem.setLine(line) + self._line_item.setLine(line) message1 = f'{line.length():.1f} pixels, {line.angle():.2f}\u00b0' message2 = f'{line.dx():.1f} \u00d7 {line.dy():.1f}' - self._statusBar.showMessage(f'{message1} ({message2})') - - def mouseReleaseEvent(self, event: QGraphicsSceneMouseEvent) -> None: - if self._mouseTool == ImageMouseTool.MOVE_TOOL: - self._changeOverrideCursor(Qt.CursorShape.OpenHandCursor) - elif self._mouseTool == ImageMouseTool.RULER_TOOL: - self._lineItem.setLine(QLineF()) - self._lineItem.hide() - elif self._mouseTool == ImageMouseTool.RECTANGLE_TOOL: - self._events.rectangleFinished.emit(self._rectangleItem.rect()) - self._rectangleItem.setRect(QRectF()) - self._rectangleItem.hide() - elif self._mouseTool == ImageMouseTool.LINE_CUT_TOOL: - self._events.lineCutFinished.emit(self._lineItem.line()) - self._lineItem.setLine(QLineF()) - self._lineItem.hide() + self._status_bar.showMessage(f'{message1} ({message2})') + + def mouseReleaseEvent(self, event: QGraphicsSceneMouseEvent) -> None: # noqa: N802 + if self._mouse_tool == ImageMouseTool.MOVE_TOOL: + self._change_override_cursor(Qt.CursorShape.OpenHandCursor) + elif self._mouse_tool == ImageMouseTool.RULER_TOOL: + self._line_item.setLine(QLineF()) + self._line_item.hide() + elif self._mouse_tool == ImageMouseTool.RECTANGLE_TOOL: + self._events.rectangle_finished.emit(self._rectangle_item.rect()) + self._rectangle_item.setRect(QRectF()) + self._rectangle_item.hide() + elif self._mouse_tool == ImageMouseTool.LINE_CUT_TOOL: + self._events.line_cut_finished.emit(self._line_item.line()) + self._line_item.setLine(QLineF()) + self._line_item.hide() class LineCutDialog(QDialog): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) self.figure = Figure() - self.figureCanvas = FigureCanvasQTAgg(self.figure) - self.navigationToolbar = NavigationToolbar(self.figureCanvas, self) + self.figure_canvas = FigureCanvasQTAgg(self.figure) + self.navigation_toolbar = NavigationToolbar(self.figure_canvas, self) self.axes = self.figure.add_subplot(111) @classmethod - def createInstance(cls, parent: QWidget | None = None) -> LineCutDialog: + def create_instance(cls, parent: QWidget | None = None) -> LineCutDialog: view = cls(parent) view.setWindowTitle('Line-Cut Dialog') layout = QVBoxLayout() - layout.addWidget(view.navigationToolbar) - layout.addWidget(view.figureCanvas) + layout.addWidget(view.navigation_toolbar) + layout.addWidget(view.figure_canvas) view.setLayout(layout) return view @@ -231,35 +231,35 @@ def createInstance(cls, parent: QWidget | None = None) -> LineCutDialog: class RectangleView(QGroupBox): @staticmethod - def _createReadOnlyLineEdit() -> QLineEdit: - lineEdit = QLineEdit() + def _create_read_only_line_edit() -> QLineEdit: + line_edit = QLineEdit() - palette = lineEdit.palette() + palette = line_edit.palette() palette.setColor(QPalette.Base, palette.color(QPalette.Window)) palette.setColor(QPalette.Text, palette.color(QPalette.WindowText)) - lineEdit.setPalette(palette) + line_edit.setPalette(palette) - lineEdit.setFocusPolicy(Qt.NoFocus) - lineEdit.setReadOnly(True) + line_edit.setFocusPolicy(Qt.NoFocus) + line_edit.setReadOnly(True) - return lineEdit + return line_edit def __init__(self, parent: QWidget | None) -> None: super().__init__('Rectangle', parent) - self.centerXLineEdit = RectangleView._createReadOnlyLineEdit() - self.centerYLineEdit = RectangleView._createReadOnlyLineEdit() - self.widthLineEdit = RectangleView._createReadOnlyLineEdit() - self.heightLineEdit = RectangleView._createReadOnlyLineEdit() + self.center_x_line_edit = RectangleView._create_read_only_line_edit() + self.center_y_line_edit = RectangleView._create_read_only_line_edit() + self.width_line_edit = RectangleView._create_read_only_line_edit() + self.height_line_edit = RectangleView._create_read_only_line_edit() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> RectangleView: + def create_instance(cls, parent: QWidget | None = None) -> RectangleView: view = cls(parent) layout = QFormLayout() - layout.addRow('Center X:', view.centerXLineEdit) - layout.addRow('Center Y:', view.centerYLineEdit) - layout.addRow('Width:', view.widthLineEdit) - layout.addRow('Height:', view.heightLineEdit) + layout.addRow('Center X:', view.center_x_line_edit) + layout.addRow('Center Y:', view.center_y_line_edit) + layout.addRow('Width:', view.width_line_edit) + layout.addRow('Height:', view.height_line_edit) view.setLayout(layout) return view @@ -269,63 +269,63 @@ class HistogramDialog(QDialog): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) self.figure = Figure() - self.figureCanvas = FigureCanvasQTAgg(self.figure) - self.navigationToolbar = NavigationToolbar(self.figureCanvas, self) + self.figure_canvas = FigureCanvasQTAgg(self.figure) + self.navigation_toolbar = NavigationToolbar(self.figure_canvas, self) self.axes = self.figure.add_subplot(111) - self.rectangleView = RectangleView.createInstance() + self.rectangle_view = RectangleView.create_instance() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> HistogramDialog: + def create_instance(cls, parent: QWidget | None = None) -> HistogramDialog: view = cls(parent) view.setWindowTitle('Histogram') layout = QVBoxLayout() - layout.addWidget(view.navigationToolbar) - layout.addWidget(view.figureCanvas, 1) - layout.addWidget(view.rectangleView) + layout.addWidget(view.navigation_toolbar) + layout.addWidget(view.figure_canvas, 1) + layout.addWidget(view.rectangle_view) view.setLayout(layout) return view class VisualizationView(QGraphicsView): - def wheelEvent(self, event: QWheelEvent) -> None: - oldPosition = self.mapToScene(event.pos()) + def wheelEvent(self, event: QWheelEvent) -> None: # noqa: N802 + old_position = self.mapToScene(event.pos()) - zoomBase = 1.25 - zoom = zoomBase if event.angleDelta().y() > 0 else 1.0 / zoomBase + zoom_base = 1.25 + zoom = zoom_base if event.angleDelta().y() > 0 else 1.0 / zoom_base self.scale(zoom, zoom) - newPosition = self.mapToScene(event.pos()) + new_position = self.mapToScene(event.pos()) - deltaPosition = newPosition - oldPosition - self.translate(deltaPosition.x(), deltaPosition.y()) + delta_position = new_position - old_position + self.translate(delta_position.x(), delta_position.y()) class VisualizationWidget(QGroupBox): def __init__(self, title: str, parent: QWidget | None) -> None: super().__init__(title, parent) - self.toolBar = QToolBar('Tools') - self.homeAction = QAction(QIcon(':/icons/home'), 'Home') - self.saveAction = QAction(QIcon(':/icons/save'), 'Save Image') - self.autoscaleAction = QAction(QIcon(':/icons/autoscale'), 'Autoscale Color Axis') - self.visualizationView = VisualizationView() + self.tool_bar = QToolBar('Tools') + self.home_action = QAction(QIcon(':/icons/home'), 'Home') + self.save_action = QAction(QIcon(':/icons/save'), 'Save Image') + self.autoscale_action = QAction(QIcon(':/icons/autoscale'), 'Autoscale Color Axis') + self.visualization_view = VisualizationView() @classmethod - def createInstance(cls, title: str, parent: QWidget | None = None) -> VisualizationWidget: + def create_instance(cls, title: str, parent: QWidget | None = None) -> VisualizationWidget: view = cls(title, parent) view.setAlignment(Qt.AlignHCenter) - view.toolBar.setFloatable(False) - view.toolBar.setIconSize(QSize(32, 32)) - view.toolBar.addAction(view.homeAction) - view.toolBar.addAction(view.saveAction) - view.toolBar.addAction(view.autoscaleAction) + view.tool_bar.setFloatable(False) + view.tool_bar.setIconSize(QSize(32, 32)) + view.tool_bar.addAction(view.home_action) + view.tool_bar.addAction(view.save_action) + view.tool_bar.addAction(view.autoscale_action) layout = QVBoxLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(view.toolBar) - layout.addWidget(view.visualizationView) + layout.addWidget(view.tool_bar) + layout.addWidget(view.visualization_view) view.setLayout(layout) return view @@ -335,22 +335,22 @@ class VisualizationParametersView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Visualization', parent) - self.rendererComboBox = QComboBox() - self.transformationComboBox = QComboBox() - self.variantComboBox = QComboBox() - self.minDisplayValueLineEdit = DecimalLineEdit.createInstance() - self.maxDisplayValueLineEdit = DecimalLineEdit.createInstance() + self.renderer_combo_box = QComboBox() + self.transformation_combo_box = QComboBox() + self.variant_combo_box = QComboBox() + self.min_display_value_line_edit = DecimalLineEdit.create_instance() + self.max_display_value_line_edit = DecimalLineEdit.create_instance() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> VisualizationParametersView: + def create_instance(cls, parent: QWidget | None = None) -> VisualizationParametersView: view = cls(parent) layout = QFormLayout() - layout.addRow('Renderer:', view.rendererComboBox) - layout.addRow('Transform:', view.transformationComboBox) - layout.addRow('Variant:', view.variantComboBox) - layout.addRow('Min Display Value:', view.minDisplayValueLineEdit) - layout.addRow('Max Display Value:', view.maxDisplayValueLineEdit) + layout.addRow('Renderer:', view.renderer_combo_box) + layout.addRow('Transform:', view.transformation_combo_box) + layout.addRow('Variant:', view.variant_combo_box) + layout.addRow('Min Display Value:', view.min_display_value_line_edit) + layout.addRow('Max Display Value:', view.max_display_value_line_edit) view.setLayout(layout) return view diff --git a/src/ptychodus/view/widgets/__init__.py b/src/ptychodus/view/widgets/__init__.py index 2cde33ad..da4d2ecb 100644 --- a/src/ptychodus/view/widgets/__init__.py +++ b/src/ptychodus/view/widgets/__init__.py @@ -1,12 +1,12 @@ -from .angleWidget import AngleWidget -from .groupBox import BottomTitledGroupBox, GroupBoxWithPresets -from .comboBoxItemDelegate import ComboBoxItemDelegate -from .decimalLineEdit import DecimalLineEdit -from .decimalSlider import DecimalSlider -from .exceptionDialog import ExceptionDialog -from .lengthWidget import LengthWidget -from .progressBarItemDelegate import ProgressBarItemDelegate -from .uuidLineEdit import UUIDLineEdit +from .angle_widget import AngleWidget +from .group_box import BottomTitledGroupBox, GroupBoxWithPresets +from .combo_box_item_delegate import ComboBoxItemDelegate +from .decimal_line_edit import DecimalLineEdit +from .decimal_slider import DecimalSlider +from .exception_dialog import ExceptionDialog +from .length_widget import LengthWidget +from .progress_bar_item_delegate import ProgressBarItemDelegate +from .uuid_line_edit import UUIDLineEdit __all__ = [ 'AngleWidget', diff --git a/src/ptychodus/view/widgets/angleWidget.py b/src/ptychodus/view/widgets/angleWidget.py deleted file mode 100644 index 9ac23354..00000000 --- a/src/ptychodus/view/widgets/angleWidget.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations -from decimal import Decimal - -from PyQt5.QtCore import pyqtSignal -from PyQt5.QtWidgets import QComboBox, QHBoxLayout, QWidget -import numpy - -from .decimalLineEdit import DecimalLineEdit - - -class AngleWidget(QWidget): - angleChanged = pyqtSignal(Decimal) - - def __init__(self, parent: QWidget | None) -> None: - super().__init__(parent) - self.angleInTurns = Decimal() - self.angleLineEdit = DecimalLineEdit.createInstance(isSigned=False) - self.unitsComboBox = QComboBox() - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> AngleWidget: - widget = cls(parent) - - widget.angleLineEdit.valueChanged.connect(widget._setAngleInTurnsFromWidgets) - - widget.unitsComboBox.addItem('turn', Decimal(1)) - widget.unitsComboBox.addItem('deg', Decimal(360)) - widget.unitsComboBox.addItem('rad', 2 * Decimal.from_float(numpy.pi)) - widget.unitsComboBox.activated.connect(widget._updateDisplay) - - layout = QHBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(widget.angleLineEdit) - layout.addWidget(widget.unitsComboBox) - widget.setLayout(layout) - - return widget - - def isReadOnly(self) -> bool: - return self.angleLineEdit.isReadOnly() - - def setReadOnly(self, enable: bool) -> None: - self.angleLineEdit.setReadOnly(enable) - - def getAngleInTurns(self) -> Decimal: - return self.angleInTurns - - def setAngleInTurns(self, angleInTurns: Decimal) -> None: - self.angleInTurns = angleInTurns - self._updateDisplay() - self.angleChanged.emit(self.getAngleInTurns()) - - def _setAngleInTurnsFromWidgets(self, angle: Decimal) -> None: - self.angleInTurns = angle / self.unitsComboBox.currentData() - self.angleChanged.emit(self.angleInTurns) - - def _updateDisplay(self) -> None: - angleInDisplayUnits = self.angleInTurns * self.unitsComboBox.currentData() - self.angleLineEdit.setValue(angleInDisplayUnits) diff --git a/src/ptychodus/view/widgets/angle_widget.py b/src/ptychodus/view/widgets/angle_widget.py new file mode 100644 index 00000000..f60edae2 --- /dev/null +++ b/src/ptychodus/view/widgets/angle_widget.py @@ -0,0 +1,59 @@ +from __future__ import annotations +from decimal import Decimal + +from PyQt5.QtCore import pyqtSignal +from PyQt5.QtWidgets import QComboBox, QHBoxLayout, QWidget +import numpy + +from .decimal_line_edit import DecimalLineEdit + + +class AngleWidget(QWidget): + angle_changed = pyqtSignal(Decimal) + + def __init__(self, parent: QWidget | None) -> None: + super().__init__(parent) + self.angle_in_turns = Decimal() + self.angle_line_edit = DecimalLineEdit.create_instance(is_signed=False) + self.units_combo_box = QComboBox() + + @classmethod + def create_instance(cls, parent: QWidget | None = None) -> AngleWidget: + widget = cls(parent) + + widget.angle_line_edit.value_changed.connect(widget._set_angle_in_turns_from_widgets) + + widget.units_combo_box.addItem('turn', Decimal(1)) + widget.units_combo_box.addItem('deg', Decimal(360)) + widget.units_combo_box.addItem('rad', 2 * Decimal.from_float(numpy.pi)) + widget.units_combo_box.activated.connect(widget._update_display) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(widget.angle_line_edit) + layout.addWidget(widget.units_combo_box) + widget.setLayout(layout) + + return widget + + def is_read_only(self) -> bool: + return self.angle_line_edit.is_read_only() + + def set_read_only(self, enable: bool) -> None: + self.angle_line_edit.set_read_only(enable) + + def get_angle_in_turns(self) -> Decimal: + return self.angle_in_turns + + def set_angle_in_turns(self, angle_in_turns: Decimal) -> None: + self.angle_in_turns = angle_in_turns + self._update_display() + self.angle_changed.emit(self.get_angle_in_turns()) + + def _set_angle_in_turns_from_widgets(self, angle: Decimal) -> None: + self.angle_in_turns = angle / self.units_combo_box.currentData() + self.angle_changed.emit(self.angle_in_turns) + + def _update_display(self) -> None: + angle_in_display_units = self.angle_in_turns * self.units_combo_box.currentData() + self.angle_line_edit.set_value(angle_in_display_units) diff --git a/src/ptychodus/view/widgets/comboBoxItemDelegate.py b/src/ptychodus/view/widgets/combo_box_item_delegate.py similarity index 73% rename from src/ptychodus/view/widgets/comboBoxItemDelegate.py rename to src/ptychodus/view/widgets/combo_box_item_delegate.py index 7dbaa735..e938ddf9 100644 --- a/src/ptychodus/view/widgets/comboBoxItemDelegate.py +++ b/src/ptychodus/view/widgets/combo_box_item_delegate.py @@ -21,10 +21,10 @@ class ComboBoxItemDelegate(QStyledItemDelegate): def __init__(self, model: QAbstractItemModel, parent: QObject | None = None) -> None: super().__init__(parent) self._model = model - self._paintComboBox = False + self._paint_combo_box = False def paint(self, painter: QPainter, option: QStyleOptionViewItem, index: QModelIndex) -> None: - if self._paintComboBox and index.flags() & Qt.ItemFlag.ItemIsEditable: + if self._paint_combo_box and index.flags() & Qt.ItemFlag.ItemIsEditable: opt = QStyleOptionComboBox() opt.rect = option.rect opt.currentText = index.data(Qt.DisplayRole) @@ -33,38 +33,38 @@ def paint(self, painter: QPainter, option: QStyleOptionViewItem, index: QModelIn else: super().paint(painter, option, index) - def createEditor( + def createEditor( # noqa: N802 self, parent: QWidget, option: QStyleOptionViewItem, index: QModelIndex ) -> QWidget: - comboBox = QComboBox(parent) - comboBox.activated.connect(self._commitDataAndCloseEditor) - comboBox.setModel(self._model) - return comboBox + combo_box = QComboBox(parent) + combo_box.activated.connect(self._commit_data_and_close_editor) + combo_box.setModel(self._model) + return combo_box - def setEditorData(self, editor: QWidget, index: QModelIndex) -> None: + def setEditorData(self, editor: QWidget, index: QModelIndex) -> None: # noqa: N802 if isinstance(editor, QComboBox): - currentText = str(index.data(Qt.EditRole)) - comboBoxIndex = editor.findText(currentText) + current_text = str(index.data(Qt.EditRole)) + combo_box_index = editor.findText(current_text) - if comboBoxIndex >= 0: - editor.setCurrentIndex(comboBoxIndex) + if combo_box_index >= 0: + editor.setCurrentIndex(combo_box_index) editor.showPopup() else: super().setEditorData(editor, index) - def setModelData(self, editor: QWidget, model: QAbstractItemModel, index: QModelIndex) -> None: + def setModelData(self, editor: QWidget, model: QAbstractItemModel, index: QModelIndex) -> None: # noqa: N802 if isinstance(editor, QComboBox): model.setData(index, editor.currentText(), Qt.EditRole) else: super().setModelData(editor, model, index) - def updateEditorGeometry( + def updateEditorGeometry( # noqa: N802 self, editor: QWidget, option: QStyleOptionViewItem, index: QModelIndex ) -> None: editor.setGeometry(option.rect) - def _commitDataAndCloseEditor(self) -> None: + def _commit_data_and_close_editor(self) -> None: editor = self.sender() if isinstance(editor, QComboBox): diff --git a/src/ptychodus/view/widgets/decimalLineEdit.py b/src/ptychodus/view/widgets/decimalLineEdit.py deleted file mode 100644 index 47f7b28d..00000000 --- a/src/ptychodus/view/widgets/decimalLineEdit.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import annotations -from decimal import Decimal -import logging - -from PyQt5.QtCore import pyqtSignal -from PyQt5.QtGui import QDoubleValidator -from PyQt5.QtWidgets import QHBoxLayout, QLineEdit, QWidget - -logger = logging.getLogger(__name__) - - -class DecimalLineEdit(QWidget): - valueChanged = pyqtSignal(Decimal) - - def __init__(self, parent: QWidget | None) -> None: - super().__init__(parent) - self._validator = QDoubleValidator() - self._lineEdit = QLineEdit() - self._value = Decimal() - self._minimum: Decimal | None = None - self._maximum: Decimal | None = None - - @classmethod - def createInstance( - cls, *, isSigned: bool = False, parent: QWidget | None = None - ) -> DecimalLineEdit: - widget = cls(parent) - - widget._lineEdit.setValidator(widget._validator) - widget._lineEdit.editingFinished.connect(widget._setValueFromLineEdit) - widget._setValueToLineEditAndEmitValueChanged() - - layout = QHBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(widget._lineEdit) - widget.setLayout(layout) - - if not isSigned: - widget._validator.setBottom(0.0) - - return widget - - def isReadOnly(self) -> bool: - return self._lineEdit.isReadOnly() - - def setReadOnly(self, enable: bool) -> None: - self._lineEdit.setReadOnly(enable) - - def getValue(self) -> Decimal: - if self._minimum is not None and self._value < self._minimum: - return self._minimum - - if self._maximum is not None and self._value > self._maximum: - return self._maximum - - return self._value - - def setValue(self, value: Decimal) -> None: - if value != self._value: - self._value = value - self._setValueToLineEditAndEmitValueChanged() - - def getMinimum(self) -> Decimal | None: - return self._minimum - - def setMinimum(self, value: Decimal) -> None: - valueBefore = self.getValue() - self._minimum = value - valueAfter = self.getValue() - - if valueBefore != valueAfter: - self._setValueToLineEditAndEmitValueChanged() - - def getMaximum(self) -> Decimal | None: - return self._maximum - - def setMaximum(self, value: Decimal) -> None: - valueBefore = self.getValue() - self._maximum = value - valueAfter = self.getValue() - - if valueBefore != valueAfter: - self._setValueToLineEditAndEmitValueChanged() - - def _setValueFromLineEdit(self) -> None: - decimalText = self._lineEdit.text() - - try: - self._value = Decimal(decimalText) - except ValueError: - logger.error(f'Failed to parse value "{decimalText}"') - else: - self._emitValueChanged() - - def _setValueToLineEditAndEmitValueChanged(self) -> None: - self._lineEdit.setText(str(self.getValue())) - self._emitValueChanged() - - def _emitValueChanged(self) -> None: - self.valueChanged.emit(self._value) diff --git a/src/ptychodus/view/widgets/decimal_line_edit.py b/src/ptychodus/view/widgets/decimal_line_edit.py new file mode 100644 index 00000000..bcad9536 --- /dev/null +++ b/src/ptychodus/view/widgets/decimal_line_edit.py @@ -0,0 +1,100 @@ +from __future__ import annotations +from decimal import Decimal +import logging + +from PyQt5.QtCore import pyqtSignal +from PyQt5.QtGui import QDoubleValidator +from PyQt5.QtWidgets import QHBoxLayout, QLineEdit, QWidget + +logger = logging.getLogger(__name__) + + +class DecimalLineEdit(QWidget): + value_changed = pyqtSignal(Decimal) + + def __init__(self, parent: QWidget | None) -> None: + super().__init__(parent) + self._validator = QDoubleValidator() + self._line_edit = QLineEdit() + self._value = Decimal() + self._minimum: Decimal | None = None + self._maximum: Decimal | None = None + + @classmethod + def create_instance( + cls, *, is_signed: bool = False, parent: QWidget | None = None + ) -> DecimalLineEdit: + widget = cls(parent) + + widget._line_edit.setValidator(widget._validator) + widget._line_edit.editingFinished.connect(widget._set_value_from_line_edit) + widget._set_value_to_line_edit_and_emit_value_changed() + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(widget._line_edit) + widget.setLayout(layout) + + if not is_signed: + widget._validator.setBottom(0.0) + + return widget + + def is_read_only(self) -> bool: + return self._line_edit.isReadOnly() + + def set_read_only(self, enable: bool) -> None: + self._line_edit.setReadOnly(enable) + + def get_value(self) -> Decimal: + if self._minimum is not None and self._value < self._minimum: + return self._minimum + + if self._maximum is not None and self._value > self._maximum: + return self._maximum + + return self._value + + def set_value(self, value: Decimal) -> None: + if value != self._value: + self._value = value + self._set_value_to_line_edit_and_emit_value_changed() + + def get_minimum(self) -> Decimal | None: + return self._minimum + + def set_minimum(self, value: Decimal) -> None: + value_before = self.get_value() + self._minimum = value + value_after = self.get_value() + + if value_before != value_after: + self._set_value_to_line_edit_and_emit_value_changed() + + def get_maximum(self) -> Decimal | None: + return self._maximum + + def set_maximum(self, value: Decimal) -> None: + value_before = self.get_value() + self._maximum = value + value_after = self.get_value() + + if value_before != value_after: + self._set_value_to_line_edit_and_emit_value_changed() + + def _set_value_from_line_edit(self) -> None: + decimal_text = self._line_edit.text() + + try: + self._value = Decimal(decimal_text) + except ValueError: + logger.error(f'Failed to parse value "{decimal_text}"') + else: + self._emit_value_changed() + + def _set_value_to_line_edit_and_emit_value_changed(self) -> None: + self._line_edit.setText(str(self.get_value())) + self._emit_value_changed() + + def _emit_value_changed(self) -> None: + self.value_changed.emit(self._value) diff --git a/src/ptychodus/view/widgets/decimalSlider.py b/src/ptychodus/view/widgets/decimal_slider.py similarity index 66% rename from src/ptychodus/view/widgets/decimalSlider.py rename to src/ptychodus/view/widgets/decimal_slider.py index 525b7569..f2194c2a 100644 --- a/src/ptychodus/view/widgets/decimalSlider.py +++ b/src/ptychodus/view/widgets/decimal_slider.py @@ -10,7 +10,7 @@ class DecimalSlider(QWidget): - valueChanged = pyqtSignal(Decimal) + value_changed = pyqtSignal(Decimal) def __init__(self, slider: QSlider, parent: QWidget | None) -> None: super().__init__(parent) @@ -21,21 +21,21 @@ def __init__(self, slider: QSlider, parent: QWidget | None) -> None: self._maximum = Decimal() @classmethod - def createInstance( + def create_instance( cls, orientation: Qt.Orientation, parent: QWidget | None = None, *, - numberOfTicks: int = 1000, + num_ticks: int = 1000, ) -> DecimalSlider: slider = QSlider(orientation) - slider.setRange(0, numberOfTicks) + slider.setRange(0, num_ticks) slider.setTickPosition(QSlider.TickPosition.TicksBelow) slider.setTickInterval(100) widget = cls(slider, parent) - slider.valueChanged.connect(lambda value: widget._setValueFromSlider()) - widget.setValueAndRange(Decimal(1) / 2, Interval[Decimal](Decimal(0), Decimal(1))) + slider.valueChanged.connect(lambda value: widget._set_value_from_slider()) + widget.set_value_and_range(Decimal(1) / 2, Interval[Decimal](Decimal(0), Decimal(1))) layout = QHBoxLayout() layout.setContentsMargins(0, 0, 0, 0) @@ -45,39 +45,39 @@ def createInstance( return widget - def getValue(self) -> Decimal: + def get_value(self) -> Decimal: return self._value - def setValue(self, value: Decimal) -> None: - if self._setValueToSlider(value): - self._emitValueChanged() + def set_value(self, value: Decimal) -> None: + if self._set_value_to_slider(value): + self._emit_value_changed() - def setValueAndRange( + def set_value_and_range( self, value: Decimal, range_: Interval[Decimal], - blockValueChangedSignal: bool = False, + block_value_changed_signal: bool = False, ) -> None: - shouldEmit = False + should_emit = False if range_.upper <= range_.lower: raise ValueError(f'maximum <= minimum ({range_.upper} <= {range_.lower})') if range_.lower != self._minimum: self._minimum = range_.lower - shouldEmit = True + should_emit = True if range_.upper != self._maximum: self._maximum = range_.upper - shouldEmit = True + should_emit = True - if self._setValueToSlider(value): - shouldEmit = True + if self._set_value_to_slider(value): + should_emit = True - if not blockValueChangedSignal and shouldEmit: - self._emitValueChanged() + if not block_value_changed_signal and should_emit: + self._emit_value_changed() - def _setValueFromSlider(self) -> None: + def _set_value_from_slider(self) -> None: upper = Decimal(self._slider.value() - self._slider.minimum()) lower = Decimal(self._slider.maximum() - self._slider.minimum()) alpha = upper / lower @@ -85,11 +85,11 @@ def _setValueFromSlider(self) -> None: if value != self._value: self._value = value - self._updateLabel() - self._emitValueChanged() + self._update_label() + self._emit_value_changed() - def _setValueToSlider(self, value: Decimal) -> bool: - shouldEmit = False + def _set_value_to_slider(self, value: Decimal) -> bool: + should_emit = False alpha = (Decimal(value) - self._minimum) / (self._maximum - self._minimum) ivaluef = (1 - alpha) * self._slider.minimum() + alpha * self._slider.maximum() @@ -108,13 +108,13 @@ def _setValueToSlider(self, value: Decimal) -> bool: if value != self._value: self._value = Decimal(value) - self._updateLabel() - shouldEmit = True + self._update_label() + should_emit = True - return shouldEmit + return should_emit - def _updateLabel(self) -> None: + def _update_label(self) -> None: self._label.setText(f'{self._value:.3f}') - def _emitValueChanged(self) -> None: - self.valueChanged.emit(self._value) + def _emit_value_changed(self) -> None: + self.value_changed.emit(self._value) diff --git a/src/ptychodus/view/widgets/exceptionDialog.py b/src/ptychodus/view/widgets/exceptionDialog.py deleted file mode 100644 index a390c302..00000000 --- a/src/ptychodus/view/widgets/exceptionDialog.py +++ /dev/null @@ -1,15 +0,0 @@ -import traceback - -from PyQt5.QtWidgets import QMessageBox - - -class ExceptionDialog(QMessageBox): - @classmethod - def showException(cls, actor: str, exception: Exception) -> None: - dialog = cls() - dialog.setWindowTitle('Exception Dialog') - dialog.setIcon(QMessageBox.Icon.Critical) - dialog.setText(f'{actor} raised a {exception.__class__.__name__}!') - dialog.setInformativeText(str(exception)) - dialog.setDetailedText(traceback.format_exc()) - _ = dialog.open() diff --git a/src/ptychodus/view/widgets/exception_dialog.py b/src/ptychodus/view/widgets/exception_dialog.py new file mode 100644 index 00000000..e455bfc5 --- /dev/null +++ b/src/ptychodus/view/widgets/exception_dialog.py @@ -0,0 +1,43 @@ +from typing import Final +import traceback + +from PyQt5.QtCore import QEvent +from PyQt5.QtWidgets import QMessageBox, QSizePolicy, QTextEdit + + +class ExceptionDialog(QMessageBox): + MIN_SIZE: Final[int] = 0 + MAX_SIZE: Final[int] = 16777215 + + @classmethod + def show_exception(cls, actor: str, exception: Exception) -> None: + dialog = cls() + dialog.setSizeGripEnabled(True) + dialog.setWindowTitle('Exception Dialog') + dialog.setIcon(QMessageBox.Icon.Critical) + dialog.setText(f'{actor} raised a {exception.__class__.__name__}!') + dialog.setInformativeText(str(exception)) + dialog.setDetailedText(traceback.format_exc()) + _ = dialog.exec() + + def event(self, event: QEvent) -> bool: + result = super().event(event) + + if event.type() == QEvent.LayoutRequest or event.type() == QEvent.Resize: + self.setMinimumHeight(self.MIN_SIZE) + self.setMaximumHeight(self.MAX_SIZE) + self.setMinimumWidth(self.MIN_SIZE) + self.setMaximumWidth(self.MAX_SIZE) + self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + text_edit = self.findChild(QTextEdit) + + if text_edit is not None: + # make the detailed text expandable + text_edit.setMinimumHeight(self.MIN_SIZE) + text_edit.setMaximumHeight(self.MAX_SIZE) + text_edit.setMinimumWidth(self.MIN_SIZE) + text_edit.setMaximumWidth(self.MAX_SIZE) + text_edit.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + return result diff --git a/src/ptychodus/view/widgets/groupBox.py b/src/ptychodus/view/widgets/group_box.py similarity index 64% rename from src/ptychodus/view/widgets/groupBox.py rename to src/ptychodus/view/widgets/group_box.py index f4c0220c..9ffe2004 100644 --- a/src/ptychodus/view/widgets/groupBox.py +++ b/src/ptychodus/view/widgets/group_box.py @@ -15,29 +15,29 @@ class GroupBoxWithPresets(QWidget): def __init__(self, title: str, parent: QWidget | None = None) -> None: super().__init__(parent) - self._titleLabel = QLabel(title) + self._title_label = QLabel(title) - self.presetsMenu = QMenu() - self._presetsButton = QToolButton() - self._presetsButton.setText('Presets ') - self._presetsButton.setToolButtonStyle(Qt.ToolButtonTextOnly) - self._presetsButton.setMenu(self.presetsMenu) - self._presetsButton.setPopupMode(QToolButton.InstantPopup) + self.presets_menu = QMenu() + self._presets_button = QToolButton() + self._presets_button.setText('Presets ') + self._presets_button.setToolButtonStyle(Qt.ToolButtonTextOnly) + self._presets_button.setMenu(self.presets_menu) + self._presets_button.setPopupMode(QToolButton.InstantPopup) self.contents = QWidget() - frameLayout = QVBoxLayout() - frameLayout.addWidget(self.contents) + frame_layout = QVBoxLayout() + frame_layout.addWidget(self.contents) self._frame = QFrame() self._frame.setFrameShape(QFrame.StyledPanel) self._frame.setFrameShadow(QFrame.Plain) self._frame.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Maximum) - self._frame.setLayout(frameLayout) + self._frame.setLayout(frame_layout) layout = QGridLayout() layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._titleLabel, 0, 0) - layout.addWidget(self._presetsButton, 0, 1, Qt.AlignLeft) + layout.addWidget(self._title_label, 0, 0) + layout.addWidget(self._presets_button, 0, 1, Qt.AlignLeft) layout.addWidget(self._frame, 1, 0, 1, 2) layout.setColumnStretch(1, 1) self.setLayout(layout) diff --git a/src/ptychodus/view/widgets/lengthWidget.py b/src/ptychodus/view/widgets/lengthWidget.py deleted file mode 100644 index 081aaeae..00000000 --- a/src/ptychodus/view/widgets/lengthWidget.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations -from decimal import Decimal, ROUND_FLOOR - -from PyQt5.QtCore import pyqtSignal -from PyQt5.QtWidgets import QComboBox, QHBoxLayout, QWidget - -from .decimalLineEdit import DecimalLineEdit - - -class LengthWidget(QWidget): - lengthChanged = pyqtSignal(Decimal) - - def __init__(self, isSigned: bool, parent: QWidget | None) -> None: - super().__init__(parent) - self.lengthInMeters = Decimal() - self.lineEdit = DecimalLineEdit.createInstance(isSigned=isSigned) - self.unitsComboBox = QComboBox() - - @classmethod - def createInstance( - cls, *, isSigned: bool = False, parent: QWidget | None = None - ) -> LengthWidget: - widget = cls(isSigned, parent) - - if not isSigned: - widget.lineEdit.setMinimum(Decimal()) - - widget.lineEdit.valueChanged.connect(widget._setLengthInMetersFromWidgets) - - widget.unitsComboBox.addItem('m', 0) - widget.unitsComboBox.addItem('mm', -3) - widget.unitsComboBox.addItem('\u00b5m', -6) - widget.unitsComboBox.addItem('nm', -9) - widget.unitsComboBox.addItem('\u212b', -10) - widget.unitsComboBox.addItem('pm', -12) - widget.unitsComboBox.activated.connect(widget._updateDisplay) - - layout = QHBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(widget.lineEdit) - layout.addWidget(widget.unitsComboBox) - widget.setLayout(layout) - - return widget - - def isReadOnly(self) -> bool: - return self.lineEdit.isReadOnly() - - def setReadOnly(self, enable: bool) -> None: - self.lineEdit.setReadOnly(enable) - - def getLengthInMeters(self) -> Decimal: - return self.lengthInMeters - - def setLengthInMeters(self, lengthInMeters: Decimal) -> None: - self.lengthInMeters = lengthInMeters - - if not lengthInMeters.is_zero(): - exponent = 3 * int( - (abs(lengthInMeters).log10() / 3).to_integral_exact(rounding=ROUND_FLOOR) - ) - index = self.unitsComboBox.findData(exponent) - - if index != -1: - self.unitsComboBox.setCurrentIndex(index) - - self._updateDisplay() - - @property - def _scaleToMeters(self) -> Decimal: - exponent = self.unitsComboBox.currentData() - return Decimal(f'1e{exponent:+d}') - - def _setLengthInMetersFromWidgets(self, magnitude: Decimal) -> None: - self.lengthInMeters = magnitude * self._scaleToMeters - self.lengthChanged.emit(self.lengthInMeters) - - def _updateDisplay(self) -> None: - lengthInDisplayUnits = self.lengthInMeters / self._scaleToMeters - self.lineEdit.setValue(lengthInDisplayUnits) diff --git a/src/ptychodus/view/widgets/length_widget.py b/src/ptychodus/view/widgets/length_widget.py new file mode 100644 index 00000000..a8e7a360 --- /dev/null +++ b/src/ptychodus/view/widgets/length_widget.py @@ -0,0 +1,78 @@ +from __future__ import annotations +from decimal import Decimal, ROUND_FLOOR + +from PyQt5.QtCore import pyqtSignal +from PyQt5.QtWidgets import QComboBox, QHBoxLayout, QWidget + +from .decimal_line_edit import DecimalLineEdit + + +class LengthWidget(QWidget): + length_changed = pyqtSignal(Decimal) + + def __init__(self, is_signed: bool, parent: QWidget | None) -> None: + super().__init__(parent) + self.length_m = Decimal() + self.line_edit = DecimalLineEdit.create_instance(is_signed=is_signed) + self.units_combo_box = QComboBox() + + @classmethod + def create_instance( + cls, *, is_signed: bool = False, parent: QWidget | None = None + ) -> LengthWidget: + widget = cls(is_signed, parent) + + if not is_signed: + widget.line_edit.set_minimum(Decimal()) + + widget.line_edit.value_changed.connect(widget._set_length_m_from_widgets) + + widget.units_combo_box.addItem('m', 0) + widget.units_combo_box.addItem('mm', -3) + widget.units_combo_box.addItem('\u00b5m', -6) + widget.units_combo_box.addItem('nm', -9) + widget.units_combo_box.addItem('\u212b', -10) + widget.units_combo_box.addItem('pm', -12) + widget.units_combo_box.activated.connect(widget._update_display) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(widget.line_edit) + layout.addWidget(widget.units_combo_box) + widget.setLayout(layout) + + return widget + + def is_read_only(self) -> bool: + return self.line_edit.is_read_only() + + def set_read_only(self, enable: bool) -> None: + self.line_edit.set_read_only(enable) + + def get_length_m(self) -> Decimal: + return self.length_m + + def set_length_m(self, length_m: Decimal) -> None: + self.length_m = length_m + + if not length_m.is_zero(): + exponent = 3 * int((abs(length_m).log10() / 3).to_integral_exact(rounding=ROUND_FLOOR)) + index = self.units_combo_box.findData(exponent) + + if index != -1: + self.units_combo_box.setCurrentIndex(index) + + self._update_display() + + @property + def _scale_to_meters(self) -> Decimal: + exponent = self.units_combo_box.currentData() + return Decimal(f'1e{exponent:+d}') + + def _set_length_m_from_widgets(self, magnitude: Decimal) -> None: + self.length_m = magnitude * self._scale_to_meters + self.length_changed.emit(self.length_m) + + def _update_display(self) -> None: + length_in_display_units = self.length_m / self._scale_to_meters + self.line_edit.set_value(length_in_display_units) diff --git a/src/ptychodus/view/widgets/progressBarItemDelegate.py b/src/ptychodus/view/widgets/progress_bar_item_delegate.py similarity index 87% rename from src/ptychodus/view/widgets/progressBarItemDelegate.py rename to src/ptychodus/view/widgets/progress_bar_item_delegate.py index 05fdda44..26123503 100644 --- a/src/ptychodus/view/widgets/progressBarItemDelegate.py +++ b/src/ptychodus/view/widgets/progress_bar_item_delegate.py @@ -15,6 +15,7 @@ class ProgressBarItemDelegate(QStyledItemDelegate): def paint(self, painter: QPainter, option: QStyleOptionViewItem, index: QModelIndex) -> None: + text = index.data(Qt.ItemDataRole.DisplayRole) progress = index.data(Qt.ItemDataRole.UserRole) if progress is None: @@ -24,7 +25,7 @@ def paint(self, painter: QPainter, option: QStyleOptionViewItem, index: QModelIn opt.rect = option.rect opt.minimum = 0 opt.maximum = 100 - opt.progress = progress - opt.text = f'{progress}%' + opt.progress = int(progress) + opt.text = text opt.textVisible = True QApplication.style().drawControl(QStyle.ControlElement.CE_ProgressBar, opt, painter) diff --git a/src/ptychodus/view/widgets/uuidLineEdit.py b/src/ptychodus/view/widgets/uuid_line_edit.py similarity index 79% rename from src/ptychodus/view/widgets/uuidLineEdit.py rename to src/ptychodus/view/widgets/uuid_line_edit.py index 2e643d91..0d3b17b0 100644 --- a/src/ptychodus/view/widgets/uuidLineEdit.py +++ b/src/ptychodus/view/widgets/uuid_line_edit.py @@ -5,11 +5,11 @@ class UUIDLineEdit(QLineEdit): @staticmethod - def _createValidator() -> QRegularExpressionValidator: + def _create_validator() -> QRegularExpressionValidator: hexre = '[0-9A-Fa-f]' uuidre = f'{hexre}{{8}}-{hexre}{{4}}-{hexre}{{4}}-{hexre}{{4}}-{hexre}{{12}}' return QRegularExpressionValidator(QRegularExpression(uuidre)) def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.setValidator(UUIDLineEdit._createValidator()) + self.setValidator(UUIDLineEdit._create_validator()) diff --git a/src/ptychodus/view/workflow.py b/src/ptychodus/view/workflow.py index bd197125..76856c33 100644 --- a/src/ptychodus/view/workflow.py +++ b/src/ptychodus/view/workflow.py @@ -24,13 +24,13 @@ class WorkflowAuthorizationDialog(QDialog): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) self.label = QLabel() - self.lineEdit = QLineEdit() - self.buttonBox = QDialogButtonBox() - self.okButton = self.buttonBox.addButton(QDialogButtonBox.StandardButton.Ok) - self.cancelButton = self.buttonBox.addButton(QDialogButtonBox.StandardButton.Cancel) + self.line_edit = QLineEdit() + self.button_box = QDialogButtonBox() + self.ok_button = self.button_box.addButton(QDialogButtonBox.StandardButton.Ok) + self.cancel_button = self.button_box.addButton(QDialogButtonBox.StandardButton.Cancel) @classmethod - def createInstance(cls, parent: QWidget | None = None) -> WorkflowAuthorizationDialog: + def create_instance(cls, parent: QWidget | None = None) -> WorkflowAuthorizationDialog: view = cls(parent) view.setWindowTitle('Authorize Workflow') @@ -38,18 +38,18 @@ def createInstance(cls, parent: QWidget | None = None) -> WorkflowAuthorizationD view.label.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) view.label.setOpenExternalLinks(True) - view.buttonBox.clicked.connect(view._handleButtonBoxClicked) + view.button_box.clicked.connect(view._handle_button_box_clicked) layout = QVBoxLayout() layout.addWidget(view.label) - layout.addWidget(view.lineEdit) - layout.addWidget(view.buttonBox) + layout.addWidget(view.line_edit) + layout.addWidget(view.button_box) view.setLayout(layout) return view - def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: - if self.buttonBox.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: + def _handle_button_box_clicked(self, button: QAbstractButton) -> None: + if self.button_box.buttonRole(button) == QDialogButtonBox.ButtonRole.AcceptRole: self.accept() else: self.reject() @@ -58,18 +58,18 @@ def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: class WorkflowInputDataView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Input Data', parent) - self.endpointIDLineEdit = UUIDLineEdit() - self.globusPathLineEdit = QLineEdit() - self.posixPathLineEdit = QLineEdit() + self.endpoint_id_line_edit = UUIDLineEdit() + self.globus_path_line_edit = QLineEdit() + self.posix_path_line_edit = QLineEdit() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> WorkflowInputDataView: + def create_instance(cls, parent: QWidget | None = None) -> WorkflowInputDataView: view = cls(parent) layout = QFormLayout() - layout.addRow('Endpoint ID:', view.endpointIDLineEdit) - layout.addRow('Globus Path:', view.globusPathLineEdit) - layout.addRow('POSIX Path:', view.posixPathLineEdit) + layout.addRow('Endpoint ID:', view.endpoint_id_line_edit) + layout.addRow('Globus Path:', view.globus_path_line_edit) + layout.addRow('POSIX Path:', view.posix_path_line_edit) view.setLayout(layout) return view @@ -78,20 +78,20 @@ def createInstance(cls, parent: QWidget | None = None) -> WorkflowInputDataView: class WorkflowOutputDataView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Output Data', parent) - self.roundTripCheckBox = QCheckBox('Round Trip') - self.endpointIDLineEdit = UUIDLineEdit() - self.globusPathLineEdit = QLineEdit() - self.posixPathLineEdit = QLineEdit() + self.round_trip_check_box = QCheckBox('Round Trip') + self.endpoint_id_line_edit = UUIDLineEdit() + self.globus_path_line_edit = QLineEdit() + self.posix_path_line_edit = QLineEdit() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> WorkflowOutputDataView: + def create_instance(cls, parent: QWidget | None = None) -> WorkflowOutputDataView: view = cls(parent) layout = QFormLayout() - layout.addRow(view.roundTripCheckBox) - layout.addRow('Endpoint ID:', view.endpointIDLineEdit) - layout.addRow('Globus Path:', view.globusPathLineEdit) - layout.addRow('POSIX Path:', view.posixPathLineEdit) + layout.addRow(view.round_trip_check_box) + layout.addRow('Endpoint ID:', view.endpoint_id_line_edit) + layout.addRow('Globus Path:', view.globus_path_line_edit) + layout.addRow('POSIX Path:', view.posix_path_line_edit) view.setLayout(layout) return view @@ -100,20 +100,20 @@ def createInstance(cls, parent: QWidget | None = None) -> WorkflowOutputDataView class WorkflowComputeView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Compute', parent) - self.computeEndpointIDLineEdit = UUIDLineEdit() - self.dataEndpointIDLineEdit = UUIDLineEdit() - self.dataGlobusPathLineEdit = QLineEdit() - self.dataPosixPathLineEdit = QLineEdit() + self.compute_endpoint_id_line_edit = UUIDLineEdit() + self.data_endpoint_id_line_edit = UUIDLineEdit() + self.data_globus_path_line_edit = QLineEdit() + self.data_posix_path_line_edit = QLineEdit() @classmethod - def createInstance(cls, parent: QWidget | None = None) -> WorkflowComputeView: + def create_instance(cls, parent: QWidget | None = None) -> WorkflowComputeView: view = cls(parent) layout = QFormLayout() - layout.addRow('Compute Endpoint ID:', view.computeEndpointIDLineEdit) - layout.addRow('Data Endpoint ID:', view.dataEndpointIDLineEdit) - layout.addRow('Data Globus Path:', view.dataGlobusPathLineEdit) - layout.addRow('Data POSIX Path:', view.dataPosixPathLineEdit) + layout.addRow('Compute Endpoint ID:', view.compute_endpoint_id_line_edit) + layout.addRow('Data Endpoint ID:', view.data_endpoint_id_line_edit) + layout.addRow('Data Globus Path:', view.data_globus_path_line_edit) + layout.addRow('Data POSIX Path:', view.data_posix_path_line_edit) view.setLayout(layout) return view @@ -122,22 +122,22 @@ def createInstance(cls, parent: QWidget | None = None) -> WorkflowComputeView: class WorkflowExecutionView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Execution', parent) - self.productComboBox = QComboBox() - self.inputDataView = WorkflowInputDataView.createInstance() - self.computeView = WorkflowComputeView.createInstance() - self.outputDataView = WorkflowOutputDataView.createInstance() - self.executeButton = QPushButton('Execute') + self.product_combo_box = QComboBox() + self.input_data_view = WorkflowInputDataView.create_instance() + self.compute_view = WorkflowComputeView.create_instance() + self.output_data_view = WorkflowOutputDataView.create_instance() + self.execute_button = QPushButton('Execute') @classmethod - def createInstance(cls, parent: QWidget | None = None) -> WorkflowExecutionView: + def create_instance(cls, parent: QWidget | None = None) -> WorkflowExecutionView: view = cls(parent) layout = QFormLayout() - layout.addRow('Product:', view.productComboBox) - layout.addRow(view.inputDataView) - layout.addRow(view.computeView) - layout.addRow(view.outputDataView) - layout.addRow(view.executeButton) + layout.addRow('Product:', view.product_combo_box) + layout.addRow(view.input_data_view) + layout.addRow(view.compute_view) + layout.addRow(view.output_data_view) + layout.addRow(view.execute_button) view.setLayout(layout) return view @@ -146,17 +146,17 @@ def createInstance(cls, parent: QWidget | None = None) -> WorkflowExecutionView: class WorkflowStatusView(QGroupBox): def __init__(self, parent: QWidget | None) -> None: super().__init__('Status', parent) - self.autoRefreshCheckBox = QCheckBox('Auto Refresh [sec]:') - self.autoRefreshSpinBox = QSpinBox() - self.refreshButton = QPushButton('Refresh') + self.auto_refresh_check_box = QCheckBox('Auto Refresh [sec]:') + self.auto_refresh_spin_box = QSpinBox() + self.refresh_button = QPushButton('Refresh') @classmethod - def createInstance(cls, parent: QWidget | None = None) -> WorkflowStatusView: + def create_instance(cls, parent: QWidget | None = None) -> WorkflowStatusView: view = cls(parent) layout = QFormLayout() - layout.addRow(view.autoRefreshCheckBox, view.autoRefreshSpinBox) - layout.addRow(view.refreshButton) + layout.addRow(view.auto_refresh_check_box, view.auto_refresh_spin_box) + layout.addRow(view.refresh_button) view.setLayout(layout) return view @@ -165,17 +165,17 @@ def createInstance(cls, parent: QWidget | None = None) -> WorkflowStatusView: class WorkflowParametersView(QWidget): def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) - self.executionView = WorkflowExecutionView.createInstance() - self.statusView = WorkflowStatusView.createInstance() - self.authorizationDialog = WorkflowAuthorizationDialog.createInstance(self) + self.execution_view = WorkflowExecutionView.create_instance() + self.status_view = WorkflowStatusView.create_instance() + self.authorization_dialog = WorkflowAuthorizationDialog.create_instance(self) @classmethod - def createInstance(cls, parent: QWidget | None = None) -> WorkflowParametersView: + def create_instance(cls, parent: QWidget | None = None) -> WorkflowParametersView: view = cls(parent) layout = QVBoxLayout() - layout.addWidget(view.executionView) - layout.addWidget(view.statusView) + layout.addWidget(view.execution_view) + layout.addWidget(view.status_view) layout.addStretch() view.setLayout(layout) diff --git a/tests/aperture.py b/tests/aperture.py index b611708f..3a1d9745 100644 --- a/tests/aperture.py +++ b/tests/aperture.py @@ -19,18 +19,18 @@ def get_fresnel_number(self, z_m: float) -> float: lower = self._wavelength_m * z_m return upper / lower - def _integral1D(self, r_m: RealArrayType, z_m: float) -> ComplexArrayType: - sqrt2NF = numpy.sqrt(2 * self.get_fresnel_number(z_m)) + def _integral1d(self, r_m: RealArrayType, z_m: float) -> ComplexArrayType: + sqrt2NF = numpy.sqrt(2 * self.get_fresnel_number(z_m)) # noqa: N806 xi = 2 * r_m / self._width_m - Sm, Cm = fresnel(sqrt2NF * (1 - xi)) - Sp, Cp = fresnel(sqrt2NF * (1 + xi)) + Sm, Cm = fresnel(sqrt2NF * (1 - xi)) # noqa: N806 + Sp, Cp = fresnel(sqrt2NF * (1 + xi)) # noqa: N806 return (Cm + Cp + 1j * (Sm + Sp)) / numpy.sqrt(2) def diffract(self, x_m: RealArrayType, y_m: RealArrayType, z_m: float) -> ComplexArrayType: """Fresnel diffraction; see Goodman p.100""" assert x_m.shape == y_m.shape - Ix = self._integral1D(x_m, z_m) - Iy = self._integral1D(y_m, z_m) + Ix = self._integral1d(x_m, z_m) # noqa: N806 + Iy = self._integral1d(y_m, z_m) # noqa: N806 return Ix * Iy * numpy.exp(2j * numpy.pi * z_m / self._wavelength_m) / 1j @@ -49,8 +49,8 @@ def diffract(self, x_m: RealArrayType, y_m: RealArrayType, z_m: float) -> Comple assert x_m.shape == y_m.shape twopi = 2 * numpy.pi - sqrtLZ = numpy.sqrt(self._wavelength_m * z_m) - sqrtNF = numpy.sqrt(self.get_fresnel_number(z_m)) + sqrtLZ = numpy.sqrt(self._wavelength_m * z_m) # noqa: N806 + sqrtNF = numpy.sqrt(self.get_fresnel_number(z_m)) # noqa: N806 rhop = rho / sqrtLZ rp = numpy.hypot(x_m, y_m) / sqrtLZ diff --git a/tests/test_phase_unwrap.py b/tests/test_phase_unwrap.py new file mode 100644 index 00000000..e3a34c51 --- /dev/null +++ b/tests/test_phase_unwrap.py @@ -0,0 +1,25 @@ +import os + +import numpy as np +import matplotlib.pyplot as plt + +import ptychodus.model.phase_unwrapper as pu + + +def test_phase_unwrap() -> None: + phase_unwrapper = pu.PhaseUnwrapper( + image_grad_method='fourier_differentiation', + image_integration_method='fourier', + ) + img = np.load(os.path.join('data', 'phase_unwrap', 'recon_20241220_epoch_400.npy')) + img = img[0] + + phase = phase_unwrapper.unwrap(img) + + plt.figure() + plt.imshow(phase) + plt.show() + + +if __name__ == '__main__': + test_phase_unwrap() diff --git a/tests/test_propagation.py b/tests/test_propagation.py index be0d2913..4fe15b31 100644 --- a/tests/test_propagation.py +++ b/tests/test_propagation.py @@ -1,5 +1,5 @@ -if __name__ == "__main__": +if __name__ == '__main__': import matplotlib - matplotlib.use("Agg") + matplotlib.use('Agg') import matplotlib.pyplot as plt diff --git a/tests/test_zernike.py b/tests/test_zernike.py index f93fcf36..d809edb8 100644 --- a/tests/test_zernike.py +++ b/tests/test_zernike.py @@ -2,11 +2,11 @@ def test_indexing() -> None: idx = 0 for n in range(10): - print("") + print('') for m in range(-n, n + 1, 2): idx_calc = (n * (n + 2) + m) // 2 - print(f"{n=} {m=:+d} {idx=} {idx_calc=}") + print(f'{n=} {m=:+d} {idx=} {idx_calc=}') assert idx == idx_calc idx += 1 @@ -15,7 +15,7 @@ def test_pyramid() -> None: import numpy import matplotlib - matplotlib.use("Agg") + matplotlib.use('Agg') import matplotlib.colors import matplotlib.pyplot as plt @@ -25,9 +25,9 @@ def test_pyramid() -> None: num_pixels = 256 max_radial_degree = 6 - Y, X = numpy.mgrid[:num_pixels, :num_pixels] - X = (X - (num_pixels - 1) / 2) / (num_pixels / 2) - Y = (Y - (num_pixels - 1) / 2) / (num_pixels / 2) + Y, X = numpy.mgrid[:num_pixels, :num_pixels] # noqa: N806 + X = (X - (num_pixels - 1) / 2) / (num_pixels / 2) # noqa: N806 + Y = (Y - (num_pixels - 1) / 2) / (num_pixels / 2) # noqa: N806 distance = numpy.hypot(Y, X) angle = numpy.arctan2(Y, X) @@ -41,16 +41,16 @@ def test_pyramid() -> None: for radial_degree in range(max_radial_degree): for angular_frequency in range(-radial_degree, radial_degree + 1, 2): polynomial = ZernikePolynomial(radial_degree, angular_frequency) - Z = polynomial(distance, angle, undefined_value=numpy.nan) + Z = polynomial(distance, angle, undefined_value=numpy.nan) # noqa: N806 row = radial_degree col = max_radial_degree + angular_frequency ax = fig.add_subplot(gs[row : row + 1, col : col + 2]) - ax.pcolormesh(X, Y, Z, norm=matplotlib.colors.CenteredNorm(), cmap="seismic") - ax.set_aspect("equal") + ax.pcolormesh(X, Y, Z, norm=matplotlib.colors.CenteredNorm(), cmap='seismic') + ax.set_aspect('equal') ax.set_title(str(polynomial)) - ax.axis("off") + ax.axis('off') - plt.savefig("zernike_pyramid.png", bbox_inches="tight", dpi=my_dpi) + plt.savefig('zernike_pyramid.png', bbox_inches='tight', dpi=my_dpi) plt.close(fig)