Skip to content

Commit 70b0f69

Browse files
Completes OPEN-3480 Add an arg to the add methods of the Python API for when users want to bypass warnings / overwrite
1 parent 5f1b3b3 commit 70b0f69

File tree

1 file changed

+47
-13
lines changed

1 file changed

+47
-13
lines changed

openlayer/__init__.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def add_model(
219219
model_package_dir: str,
220220
task_type: TaskType,
221221
sample_data: pd.DataFrame = None,
222+
force: bool = False,
222223
project_id: str = None,
223224
):
224225
"""Adds a model to a project's staging area.
@@ -234,6 +235,11 @@ def add_model(
234235
235236
.. important::
236237
The sample_data must be a dataframe with at least two rows.
238+
force : bool
239+
If :obj:`add_model` is called when there is already a model in the staging area,
240+
when ``force=True``, the existing staged model will be overwritten by the new
241+
one. When ``force=False``, the user will be prompted to confirm the
242+
overwrite.
237243
238244
Examples
239245
--------
@@ -330,7 +336,10 @@ def add_model(
330336
) from None
331337

332338
self._stage_resource(
333-
resource_name="model", resource_dir=model_package_dir, project_id=project_id
339+
resource_name="model",
340+
resource_dir=model_package_dir,
341+
project_id=project_id,
342+
force=force,
334343
)
335344

336345
def add_dataset(
@@ -348,6 +357,7 @@ def add_dataset(
348357
sep: str = ",",
349358
dataset_config_file_path: Optional[str] = None,
350359
project_id: str = None,
360+
force: bool = False,
351361
):
352362
r"""Adds a dataset to a project's staging area (from a csv).
353363
@@ -393,6 +403,10 @@ class probabilities. For example, for a binary classification
393403
The language of the dataset in ISO 639-1 (alpha-2 code) format.
394404
sep : str, default ','
395405
Delimiter to use. E.g. `'\\t'`.
406+
force : bool
407+
If :obj:`add_dataset` is called when there is already a dataset of the same type in the
408+
staging area, when ``force=True``, the existing staged dataset will be overwritten by the new
409+
one. When ``force=False``, the user will be prompted to confirm the overwrite.
396410
397411
Notes
398412
-----
@@ -550,6 +564,7 @@ class probabilities. For example, for a binary classification
550564
resource_name=dataset_type.value,
551565
resource_dir=temp_dir,
552566
project_id=project_id,
567+
force=force,
553568
)
554569

555570
def add_dataframe(
@@ -566,6 +581,7 @@ def add_dataframe(
566581
language: str = "en",
567582
project_id: str = None,
568583
dataset_config_file_path: Optional[str] = None,
584+
force: bool = False,
569585
):
570586
r"""Adds a dataset to a project's staging area (from a pandas DataFrame).
571587
@@ -609,6 +625,10 @@ class probabilities. For example, for a binary classification
609625
:obj:`TaskType.TabularClassification` or :obj:`TaskType.TabularRegression`.
610626
language : str, default 'en'
611627
The language of the dataset in ISO 639-1 (alpha-2 code) format.
628+
force : bool
629+
If :obj:`add_dataframe` is called when there is already a dataset of the same type in the
630+
staging area, when ``force=True``, the existing staged dataset will be overwritten by the new
631+
one. When ``force=False``, the user will be prompted to confirm the overwrite.
612632
613633
Notes
614634
-----
@@ -740,15 +760,20 @@ class probabilities. For example, for a binary classification
740760
categorical_feature_names=categorical_feature_names,
741761
project_id=project_id,
742762
dataset_config_file_path=dataset_config_file_path,
763+
force=force,
743764
)
744765

745-
def commit(self, message: str, project_id: int):
766+
def commit(self, message: str, project_id: int, force: bool = False):
746767
"""Adds a commit message to staged resources.
747768
748769
Parameters
749770
----------
750771
message : str
751772
The commit message, between 1 and 140 characters.
773+
force : bool
774+
If :obj:`commit` is called when there is already a commit message,
775+
when ``force=True``, the existing message will be overwritten by the new
776+
one. When ``force=False``, the user will be prompted to confirm the overwrite.
752777
753778
Notes
754779
-----
@@ -794,15 +819,18 @@ def commit(self, message: str, project_id: int):
794819

795820
if os.path.exists(f"{project_dir}/commit.yaml"):
796821
print("Found a previous commit that was not pushed to the platform.")
797-
with open(f"{project_dir}/commit.yaml", "r") as commit_file:
798-
commit = yaml.safe_load(commit_file)
799-
print(
800-
f"\t - Commit message: `{commit['message']}` \n \t - Date: {commit['date']}"
822+
overwrite = "n"
823+
824+
if not force:
825+
with open(f"{project_dir}/commit.yaml", "r") as commit_file:
826+
commit = yaml.safe_load(commit_file)
827+
print(
828+
f"\t - Commit message: `{commit['message']}` \n \t - Date: {commit['date']}"
829+
)
830+
overwrite = input(
831+
"Do you want to overwrite it with the current message? [y/n]: "
801832
)
802-
overwrite = input(
803-
"Do you want to overwrite it with the current message? [y/n]: "
804-
)
805-
if overwrite.lower() == "y":
833+
if overwrite.lower() == "y" or force:
806834
print("Overwriting commit message...")
807835
os.remove(f"{project_dir}/commit.yaml")
808836

@@ -990,7 +1018,9 @@ def restore(self, resource_name: str, project_id: int):
9901018
):
9911019
os.remove(f"{project_dir}/commit.yaml")
9921020

993-
def _stage_resource(self, resource_name: str, resource_dir: str, project_id: int):
1021+
def _stage_resource(
1022+
self, resource_name: str, resource_dir: str, project_id: int, force: bool
1023+
):
9941024
"""Adds the resource specified by `resource_name` to the project's staging directory.
9951025
9961026
Parameters
@@ -1001,6 +1031,8 @@ def _stage_resource(self, resource_name: str, resource_dir: str, project_id: int
10011031
The path from which to copy the resource.
10021032
project_id : int
10031033
The id of the project to which the resource should be added.
1034+
force : bool
1035+
Whether to overwrite the resource if it already exists in the staging area.
10041036
"""
10051037
if resource_name not in ["model", "training", "validation"]:
10061038
raise ValueError(
@@ -1016,9 +1048,11 @@ def _stage_resource(self, resource_name: str, resource_dir: str, project_id: int
10161048

10171049
if os.path.exists(staging_dir):
10181050
print(f"Found an existing {resource_name} staged.")
1019-
overwrite = input("Do you want to overwrite it? [y/n] ")
1051+
overwrite = "n"
10201052

1021-
if overwrite.lower() == "y":
1053+
if not force:
1054+
overwrite = input("Do you want to overwrite it? [y/n] ")
1055+
if overwrite.lower() == "y" or force:
10221056
print(f"Overwriting previously staged {resource_name}...")
10231057
shutil.rmtree(staging_dir)
10241058
else:

0 commit comments

Comments
 (0)