diff --git a/common/batch.py b/common/batch.py index cfd0ac0..dea7aa4 100644 --- a/common/batch.py +++ b/common/batch.py @@ -46,7 +46,7 @@ def batch_size(self) -> int: if not isinstance(tensor, torch.Tensor): continue return tensor.shape[0] - raise Exception("Could not determine batch size from tensors.") + raise RuntimeError("Could not determine batch size from tensors.") @dataclass diff --git a/common/checkpointing/snapshot.py b/common/checkpointing/snapshot.py index 2703efd..6b10059 100644 --- a/common/checkpointing/snapshot.py +++ b/common/checkpointing/snapshot.py @@ -189,7 +189,7 @@ def get_checkpoint( checkpoints = get_checkpoints(save_dir) if not checkpoints: if not missing_ok: - raise Exception(f"No checkpoints found at {save_dir}") + raise RuntimeError(f"No checkpoints found at {save_dir}") else: logging.info(f"No checkpoints found for restoration at {save_dir}.") return "" @@ -204,7 +204,7 @@ def get_checkpoint( chosen_checkpoint = checkpoint break else: - raise Exception(f"Desired checkpoint at {global_step} not found in {save_dir}") + raise RuntimeError(f"Desired checkpoint at {global_step} not found in {save_dir}") return chosen_checkpoint diff --git a/reader/dataset.py b/reader/dataset.py index 6e811cc..96df85c 100644 --- a/reader/dataset.py +++ b/reader/dataset.py @@ -67,7 +67,7 @@ def _validate_columns(self): columns = set(self._dataset_kwargs.get("columns", [])) wrong_columns = set(columns) - set(self._schema.names) if wrong_columns: - raise Exception(f"Specified columns {list(wrong_columns)} not in schema.") + raise RuntimeError(f"Specified columns {list(wrong_columns)} not in schema.") def serve(self): self.reader = _Reader(location=self.LOCATION, ds=self) diff --git a/reader/dds.py b/reader/dds.py index 7b49893..a15927e 100644 --- a/reader/dds.py +++ b/reader/dds.py @@ -25,7 +25,7 @@ def maybe_start_dataset_service(): return if packaging.version.parse(tf.__version__) < packaging.version.parse("2.5"): - raise Exception(f"maybe_distribute_dataset requires TF >= 2.5; got {tf.__version__}") + raise RuntimeError(f"maybe_distribute_dataset requires TF >= 2.5; got {tf.__version__}") if env.is_dispatcher(): logging.info(f"env.get_reader_port() = {env.get_reader_port()}") diff --git a/reader/utils.py b/reader/utils.py index fc0e34c..03aefe9 100644 --- a/reader/utils.py +++ b/reader/utils.py @@ -71,7 +71,7 @@ def get_imputation_value(pa_type): pa.string(): pa.scalar("", type=pa.string()), } if pa_type not in type_map: - raise Exception(f"Imputation for type {pa_type} not supported.") + raise RuntimeError(f"Imputation for type {pa_type} not supported.") return type_map[pa_type] def _impute(array: pa.array) -> pa.array: diff --git a/tools/pq.py b/tools/pq.py index 24c6345..457a2fc 100644 --- a/tools/pq.py +++ b/tools/pq.py @@ -64,7 +64,7 @@ def __iter__(self): def _head(self): total_read = self._num * self.bytes_per_row if total_read >= int(500e6): - raise Exception( + raise RuntimeError( "Sorry you're trying to read more than 500 MB " f"into memory ({total_read} bytes)." ) return self._ds.head(self._num, columns=self._columns) @@ -75,7 +75,7 @@ def bytes_per_row(self) -> int: for t in self._ds.schema.types: try: nbits += t.bit_width - except: + except Exception: # Just estimate size if it is variable nbits += 8 return nbits // 8