-
Notifications
You must be signed in to change notification settings - Fork 0
refactor(asr-worker): save transcriptions to artifact filesystem #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ClemDoum
wants to merge
2
commits into
main
Choose a base branch
from
feature(asr-worker)/save-results
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,59 +1,184 @@ | ||
| import torchaudio | ||
| from caul.configs.parakeet import ParakeetConfig | ||
| from caul.model_handlers.helpers import ParakeetModelHandlerResult | ||
| from caul.tasks.preprocessing.helpers import PreprocessedInput | ||
| import uuid | ||
| from collections.abc import Generator | ||
| from contextlib import contextmanager | ||
| from pathlib import Path | ||
|
|
||
| from caul.configs import ParakeetConfig | ||
| from caul.model_handlers.asr_model_handler import ASRModelHandler | ||
| from caul.model_handlers.objects import ASRModelHandlerResult | ||
| from caul.tasks.preprocessing.objects import PreprocessedInput | ||
| from config import ASRWorkerConfig | ||
| from datashare_python.types_ import ProgressRateHandler | ||
| from datashare_python.utils import ( | ||
| ActivityWithProgress, | ||
| activity_defn, | ||
| debuggable_name, | ||
| read_artifact_metadata, | ||
| safe_dir, | ||
| to_raw_progress, | ||
| write_artifact_metadata, | ||
| ) | ||
| from datashare_python.utils import artifacts_dir as get_artifacts_dir | ||
| from pydantic import TypeAdapter | ||
| from temporalio import activity | ||
|
|
||
| from .constants import ( | ||
| POSTPROCESS_ACTIVITY, | ||
| PREPROCESS_ACTIVITY, | ||
| RUN_INFERENCE_ACTIVITY, | ||
| TRANSCRIPTION_JSON, | ||
| TRANSCRIPTION_METADATA_KEY, | ||
| ) | ||
| from .models import Transcription | ||
|
|
||
# Relative progress weights of the three pipeline activities: preprocessing
# counts for ~5x a baseline activity and inference for ~10x preprocessing,
# so overall workflow progress reflects where the time is actually spent.
_BASE_WEIGHT = 1.0
_PREPROCESS_WEIGHT = 5 * _BASE_WEIGHT
_INFERENCE_WEIGHT = 10 * _PREPROCESS_WEIGHT


# Reusable adapter to coerce JSON-deserialized activity payloads back into
# list[Path] (Temporal hands them back as plain strings).
_LIST_OF_PATH_ADAPTER = TypeAdapter(list[Path])
|
|
||
|
|
||
# TODO: update caul to provide context managers rather than load/shutdown
@contextmanager
def _handler(config: ParakeetConfig) -> Generator[ASRModelHandler, None, None]:
    """Build the ASR model handler from *config* and guarantee its shutdown.

    ``shutdown`` runs even when ``startup`` raises, mirroring the
    load/shutdown contract of the underlying handler.
    """
    handler = config.handler_from_config()
    try:
        handler.startup()
        yield handler
    finally:
        handler.shutdown()
|
|
||
|
|
||
class ASRActivities(ActivityWithProgress):
    """Temporal activities implementing the ASR pipeline.

    The pipeline is split into three activities (preprocess -> infer ->
    postprocess). Intermediate results are exchanged through the worker's
    ``workdir`` as JSON files and referenced by paths *relative* to it, so
    activity payloads stay small; final transcriptions are written to the
    artifact filesystem rooted at ``artifacts_root``.
    """

    # TODO: pass this at runtime
    _handler_config = ParakeetConfig(return_tensors=False)

    @activity_defn(name=PREPROCESS_ACTIVITY, progress_weight=_PREPROCESS_WEIGHT)
    def preprocess(self, paths: list[Path]) -> list[Path]:
        """Preprocess audio files and persist each preprocessed segment.

        :param paths: audio file paths relative to the worker's audio root
        :return: paths of the serialized preprocessed segments, relative to
            the worker's ``workdir``
        """
        # TODO: this shouldn't be necessary, fix this bug
        paths = _LIST_OF_PATH_ADAPTER.validate_python(paths)
        worker_config = ASRWorkerConfig()
        audio_root = worker_config.audios_root
        workdir = worker_config.workdir
        # TODO: load from config passed at runtime with caching
        # TODO: avoid loading the full handler we just need preprocessing
        with _handler(self._handler_config) as asr_handler:
            preprocessor = asr_handler.preprocessor
            # TODO: implement a caching strategy here, we could avoid processing
            # files which have already been preprocessed
            to_process = [str(audio_root / p) for p in paths]
            batches = []
            # TODO: handle progress here
            # Fix: pass the list itself; str(to_process) handed the
            # preprocessor the Python repr of the list, not the paths
            for batch in preprocessor.process(to_process, output_dir=workdir):
                for preprocessed_input in batch:
                    uuid_name = uuid.uuid4().hex[:20]
                    segment_dir = safe_dir(uuid_name)
                    # TODO: find a more debuggable name for this
                    segment_path = (
                        workdir / segment_dir / f"{uuid_name}-preprocessed.json"
                    )
                    segment_path.parent.mkdir(parents=True, exist_ok=True)
                    # Fix: pydantic's model_dump_json takes no positional
                    # target path; serialize to a string and write it to the
                    # file, the same way ``infer`` persists its results
                    segment_path.write_text(preprocessed_input.model_dump_json())
                    batches.append(segment_path.relative_to(workdir))
        return batches

    @activity_defn(name=RUN_INFERENCE_ACTIVITY, progress_weight=_INFERENCE_WEIGHT)
    def infer(
        self,
        preprocessed_inputs: list[Path],
        *,
        progress: ProgressRateHandler | None = None,
    ) -> list[Path]:
        """Run ASR inference on preprocessed segments and persist transcripts.

        :param preprocessed_inputs: segment paths relative to ``workdir``
        :param progress: optional rate handler notified after each result
        :return: transcript file paths relative to ``workdir``
        """
        preprocessed_inputs = _LIST_OF_PATH_ADAPTER.validate_python(preprocessed_inputs)
        worker_config = ASRWorkerConfig()
        workdir = worker_config.workdir
        # TODO: load from config passed at runtime with caching
        # TODO: avoid loading the full handler we just need inference
        with _handler(self._handler_config) as asr_handler:
            inference_runner = asr_handler.inference_handler
            # TODO: extract this into a function to improve testability
            paths = []
            if progress is not None:
                progress = to_raw_progress(
                    progress, max_progress=len(preprocessed_inputs)
                )
            abs_paths = [workdir / rel_path for rel_path in preprocessed_inputs]
            # Fix: model_validate_json expects the JSON payload, not a Path;
            # read each segment file before validating. The generator keeps
            # only one segment in memory at a time.
            audios = (
                PreprocessedInput.model_validate_json(f.read_text())
                for f in abs_paths
            )
            for res_i, (path, asr_res) in enumerate(
                zip(preprocessed_inputs, inference_runner.process(audios), strict=True)
            ):
                filename = f"{debuggable_name(path)}-transcript.json"
                transcript_path = workdir / safe_dir(filename) / filename
                transcript_path.parent.mkdir(parents=True, exist_ok=True)
                transcript_path.write_text(asr_res.model_dump_json())
                paths.append(transcript_path.relative_to(workdir))
                if progress is not None:
                    # NOTE(review): reports the 0-based index against
                    # max_progress == len(inputs), so the final update is
                    # n-1 of n — confirm whether ``res_i + 1`` is intended
                    self._event_loop.run_until_complete(progress(res_i))
        return paths

    @activity_defn(name=POSTPROCESS_ACTIVITY, progress_weight=_BASE_WEIGHT)
    def postprocess(
        self,
        inference_results: list[Path],
        input_paths: list[Path],
        project: str,
        *,
        progress: ProgressRateHandler | None = None,
    ) -> None:
        """Postprocess inference results and write transcriptions as artifacts.

        :param inference_results: transcript paths relative to ``workdir``
        :param input_paths: original audio paths, aligned with
            ``inference_results`` (same order, same length)
        :param project: Datashare project the artifacts belong to
        :param progress: optional rate handler notified after each write
        """
        inference_results = _LIST_OF_PATH_ADAPTER.validate_python(inference_results)
        input_paths = _LIST_OF_PATH_ADAPTER.validate_python(input_paths)
        worker_config = ASRWorkerConfig()
        artifacts_root = worker_config.artifacts_root
        # TODO: load from config passed at runtime with caching
        # TODO: avoid loading the full handler we just need postprocessing
        with _handler(self._handler_config) as asr_handler:
            post_processor = asr_handler.postprocessor
            if progress is not None:
                progress = to_raw_progress(progress, max_progress=len(input_paths))
            with post_processor:
                transcriptions = post_processor.process(inference_results)
                # Strict is important here ! Results must stay aligned with
                # their original input files, and a length mismatch means the
                # pipeline dropped or duplicated an item.
                for i, (original, asr_result) in enumerate(
                    zip(input_paths, transcriptions, strict=True)
                ):
                    t_path = write_transcription(
                        asr_result,
                        original.name,
                        artifacts_root=artifacts_root,
                        project=project,
                    )
                    activity.logger.debug("wrote transcription for %s", t_path)
                    if progress is not None:
                        # NOTE(review): same 0-based progress reporting as in
                        # ``infer`` — confirm whether ``i + 1`` is intended
                        self._event_loop.run_until_complete(progress(i))
|
|
||
|
|
||
class ASRActivities:
    """Contains activity definitions as well as reference to models"""

    def __init__(self):
        # TODO: Eventually this may include whisper, which will
        # then require passing language_map
        self.asr_handler = ParakeetConfig(return_tensors=False).handler_from_config()

        # load models
        self.asr_handler.startup()

    @activity.defn(name="asr.transcription.preprocess")
    async def preprocess(self, inputs: list[str]) -> list[list[PreprocessedInput]]:
        """Preprocess transcription inputs

        :param inputs: list of file paths
        :return: list of caul.tasks.preprocessing.helpers.PreprocessedInput
        """
        preprocessor = self.asr_handler.preprocessor
        return preprocessor.process(inputs)

    @activity.defn(name="asr.transcription.infer")
    async def infer(
        self, inputs: list[PreprocessedInput]
    ) -> list[ParakeetModelHandlerResult]:
        """Transcribe audio files.

        :param inputs: list of preprocessed inputs
        :return: list of inference handler results
        """
        normalize = self.asr_handler.preprocessor.normalize
        for preprocessed in inputs:
            # Load each preprocessed audio file, normalize the waveform and
            # attach the resulting tensor back onto the input
            waveform, sample_rate = torchaudio.load(
                preprocessed.metadata.preprocessed_file_path
            )
            preprocessed.tensor = normalize(waveform, sample_rate)

        return self.asr_handler.inference_handler.process(inputs)

    @activity.defn(name="asr.transcription.postprocess")
    async def postprocess(
        self, inputs: list[ParakeetModelHandlerResult]
    ) -> list[ParakeetModelHandlerResult]:
        """Postprocess and reorder transcriptions

        :param inputs: list of inference handler results
        :return: list of parakeet inference handler results
        """
        return self.asr_handler.postprocessor.process(inputs)
def write_transcription(
    asr_result: ASRModelHandlerResult,
    transcribed_filename: str,
    *,
    artifacts_root: Path,
    project: str,
) -> Path:
    """Write *asr_result* to the artifact tree and register it in the metadata.

    :param asr_result: result produced by the ASR model handler
    :param transcribed_filename: name of the original transcribed file, used
        to locate its artifact directory
    :param artifacts_root: root of the artifact filesystem
    :param project: Datashare project the artifact belongs to
    :return: path of the written transcription JSON file
    """
    transcription = Transcription.from_asr_handler_result(asr_result)
    target_dir = artifacts_root / get_artifacts_dir(
        project, filename=transcribed_filename
    )
    target_dir.mkdir(parents=True, exist_ok=True)
    # TODO: if transcriptions are too large we could also serialize them
    # as jsonl
    transcription_path = target_dir / TRANSCRIPTION_JSON
    transcription_path.write_text(transcription.model_dump_json())
    # Merge into the existing artifact metadata when present, otherwise start
    # a fresh mapping
    try:
        metadata = read_artifact_metadata(
            artifacts_root, project, filename=transcribed_filename
        )
    except FileNotFoundError:
        metadata = {}
    metadata[TRANSCRIPTION_METADATA_KEY] = transcription_path.name
    write_artifact_metadata(
        metadata, artifacts_root, project=project, filename=transcribed_filename
    )
    return transcription_path
|
|
||
|
|
||
| REGISTRY = [ASRActivities.preprocess, ASRActivities.infer, ASRActivities.postprocess] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from pathlib import Path | ||
| from typing import ClassVar | ||
|
|
||
| import datashare_python | ||
| from datashare_python.config import WorkerConfig | ||
| from pydantic import Field | ||
|
|
||
| _ALL_LOGGERS = [datashare_python.__name__, __name__, "__main__"] | ||
|
|
||
|
|
||
class ASRWorkerConfig(WorkerConfig):
    """Settings for the ASR worker: filesystem roots for audio inputs,
    artifact outputs and intermediate scratch files."""

    # NOTE(review): on pydantic v2, ClassVar-annotated attributes are excluded
    # from the model, so wrapping the value in Field() may leave a FieldInfo
    # object as the class attribute instead of the list — confirm WorkerConfig
    # handles this as intended.
    loggers: ClassVar[list[str]] = Field(_ALL_LOGGERS, frozen=True)

    # Root directory containing the audio files to transcribe
    audios_root: Path
    # Root of the artifact filesystem where transcriptions are written
    artifacts_root: Path
    # Scratch directory for intermediate preprocessed/inference JSON files
    workdir: Path
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be worth moving the artifact-related functionality into its own module in
icij-common? It's probably an edge case, but there may be times when we want to interact with artifacts outside of Datashare (say, if they're saved to shared storage).