diff --git a/arcflow/utils/bulk_import.py b/arcflow/utils/bulk_import.py
index b579577..be59b58 100644
--- a/arcflow/utils/bulk_import.py
+++ b/arcflow/utils/bulk_import.py
@@ -8,6 +8,7 @@ from pathlib import Path
 from datetime import datetime
 from asnake.client import ASnakeClient
+from multiprocessing.pool import ThreadPool as Pool
 import re
 
@@ -80,14 +81,82 @@ def check_for_children(repo_id, rid, asnake_client):
         print(f'Error retrieving child count for resource ID: {e}')
         return -1
 
+def delete_archival_object(repo_id, ao_id, asnake_client):
+    """
+    Function to delete an archival object by ID.
+    Returns True if successful, False otherwise.
+    """
+    try:
+        delete_response = asnake_client.delete(
+            f"/repositories/{repo_id}/archival_objects/{ao_id}")
+        if delete_response.status_code == 200:
+            print(f"Deleted archival object {ao_id} successfully.")
+            return True
+        else:
+            print(f"Failed to delete archival object {ao_id}. Status code: {delete_response.status_code}")
+            return False
+    except Exception as e:
+        print(f'Error deleting archival object ID {ao_id}: {e}')
+        return False
+
+def delete_children(repo_id, rid, asnake_client):
+    """
+    Function to delete all top-level children of a resource.
+    Returns the number of children deleted, or -1 if an error is encountered.
+    """
+    try:
+        info = asnake_client.get(f"/repositories/{repo_id}/resources/{rid}/tree/root").json()
+        child_count = int(info.get('child_count', 0))
+        if child_count > 0:
+            with Pool(processes=10) as pool:
+                waypoints = int(info.get('waypoints', 0))
+                # in case there are more children than fit in the precomputed waypoints,
+                # work backwards from the highest waypoint so deletions do not shrink the tree and shift the offsets of the remaining waypoints
+                for i in range(waypoints, 1, -1):
+                    waypoint = asnake_client.get(f"/repositories/{repo_id}/resources/{rid}/tree/waypoint",
+                        params={
+                            'offset': i-1,
+                        }).json()
+                    results = [pool.apply_async(
+                        delete_archival_object,
+                        args=(repo_id, child['uri'].split('/')[-1], asnake_client))
+                        for child in waypoint]
+                    # wait for the deletion tasks to complete
+                    for r in results:
+                        r.get()
+
+                # then delete the children listed in the precomputed waypoints (offset 0)
+                results = [pool.apply_async(
+                    delete_archival_object,
+                    args=(repo_id, child['uri'].split('/')[-1], asnake_client))
+                    for child in info['precomputed_waypoints']['']['0']]
+                # wait for the deletion tasks to complete
+                for r in results:
+                    r.get()
+        return child_count
+    except Exception as e:
+        print(f'Error deleting children for resource ID {rid}: {e}')
+        return -1
+
 def report_csv_error(report_dict, error_string):
     """Function to print and log error messages (assumes only one error message)."""
     report_dict["error"] = error_string
     print(error_string)
 
-def csv_bulk_import(csv_directory=None, load_type='ao', only_validate='false', save_output_files=False):
+def csv_bulk_import(
+        csv_directory=None,
+        load_type='ao',
+        only_validate='false',
+        save_output_files=False,
+        overwrite_children=False,
+        only_delete_children=False,
+        report_text_file=""):
     """Function to handle CSV bulk import."""
-    print("Starting CSV bulk import...")
+    if report_text_file:
+        print(f"Retrying CSV bulk import with report file {report_text_file}...")
+    else:
+        print("Starting CSV bulk import...")
+
     if not csv_directory or not os.path.exists(csv_directory):
         print(f'Directory {csv_directory} does not exist. Exiting.')
         exit(0)
@@ -97,8 +166,26 @@ def csv_bulk_import(csv_directory=None, load_type='ao', only_validate='false', s
 
     bulk_import_report = []
 
-    for f in glob.iglob(f'{csv_directory}*.csv'):
-        print(f'Processing file {f}...')
+    if report_text_file:
+        try:
+            with open(report_text_file, "r") as file:
+                entries = yaml.safe_load(file)
+        except FileNotFoundError:
+            print(f"File {report_text_file} not found.")
+            exit(0)
+    else:
+        entries = glob.iglob(f'{csv_directory}*.csv')
+
+    for f in entries:
+        if report_text_file:
+            if f.get("java_mysql_error", 0) > 0:
+                f = f"{csv_directory}{f['identifier']}.csv"
+                print(f'Retrying file {f}...')
+            else:
+                continue
+        else:
+            print(f'Processing file {f}...')
+
         file_import_report = {}
         file_import_report["identifier"] = Path(f).stem
         file_import_report["type"] = load_type
@@ -136,15 +223,28 @@ def csv_bulk_import(csv_directory=None, load_type='ao', only_validate='false', s
             file_import_report["rid"] = rid
 
         if load_type == "ao":
-            child_count = check_for_children(repo, rid, client)
-            if child_count > 0:
-                report_csv_error(file_import_report, f'EAD ID {ead_id} already has {child_count} top-level children in ASpace. Not imported.')
-                bulk_import_report.append(file_import_report)
-                continue
-            elif child_count == -1:
-                report_csv_error(file_import_report, f'Error checking children for EAD ID {ead_id}. Not imported.')
-                bulk_import_report.append(file_import_report)
-                continue
+            if overwrite_children or only_delete_children:
+                deleted_children = delete_children(repo, rid, client)
+                file_import_report["deleted_children"] = deleted_children
+                if deleted_children == -1:
+                    report_csv_error(file_import_report, f'Error deleting children for EAD ID {ead_id}. Not imported.')
+                    bulk_import_report.append(file_import_report)
+                    continue
+                if only_delete_children:
+                    file_import_report["results_status"] = "Completed"
+                    file_import_report["results_warnings"] = f"Deleted {deleted_children} children. No import performed."
+                    bulk_import_report.append(file_import_report)
+                    continue
+            else:
+                child_count = check_for_children(repo, rid, client)
+                if child_count > 0:
+                    report_csv_error(file_import_report, f'EAD ID {ead_id} already has {child_count} top-level children in ASpace. Not imported.')
+                    bulk_import_report.append(file_import_report)
+                    continue
+                elif child_count == -1:
+                    report_csv_error(file_import_report, f'Error checking children for EAD ID {ead_id}. Not imported.')
+                    bulk_import_report.append(file_import_report)
+                    continue
 
         file_list = []
         with open(f, 'rb') as file:
@@ -183,6 +283,10 @@ def csv_bulk_import(csv_directory=None, load_type='ao', only_validate='false', s
             bulk_import_report.append(file_import_report)
 
         print(json.dumps(import_job, indent=4))
+
+    if not bulk_import_report:
+        print("No more files to process. Exiting.")
+        exit(0)
 
     if save_output_files:
         try:
@@ -206,12 +310,12 @@ def save_report(path, report_list, validate_only):
     txt_report_save_path = os.path.join(report_save_path, report_text_file_name)
 
     with open(txt_report_save_path, 'w', encoding='utf-8') as report:
-        print("Import Job Info", file=report)
+        print("# Import Job Info", file=report)
        json.dump(report_list, report, indent=4)
 
     report_csv_file_name = report_file_name_stem + ".csv"
 
-    fieldnames = ['identifier','ead_id','aspace_url','import_date','repo_id', 'rid', 'only_validate','type','resource_id','error','results_status','results_warnings','results_id','results_uri']
+    fieldnames = ['identifier','ead_id','aspace_url','import_date','repo_id', 'rid', 'only_validate','type','resource_id','error','results_status','results_warnings','results_id','results_uri','deleted_children']
 
     issue_assessment_fieldnames = get_issue_assessment_fieldnames()
     fieldnames.extend(issue_assessment_fieldnames)
@@ -222,6 +326,8 @@ def save_report(path, report_list, validate_only):
         for row in report_list:
             writer.writerow(row)
 
+    return f"{report_save_path}/{report_text_file_name}"
+
 def check_job_status(asnake_client, repo_id, job_id):
     """Function to check whether a job has completed (and thus output files are ready)."""
     while True:
@@ -359,18 +465,44 @@ def main():
         '--save-output-files',
         action='store_true',
         help='Download job output files',)
+    parser.add_argument(
+        '--overwrite-children',
+        action='store_true',
+        help='Delete existing top-level children before import/validation',)
+    parser.add_argument(
+        '--only-delete-children',
+        action='store_true',
+        help='Only delete existing children without performing import',)
+    parser.add_argument(
+        '--max-retries',
+        type=int,
+        default=0,
+        help='Number of times to retry a failed job (default: 0)',)
     args = parser.parse_args()
 
     if not args.dir.endswith('/'):
         args.dir += '/'
 
-    import_report = csv_bulk_import(
-        csv_directory=args.dir,
-        load_type=args.load_type,
-        only_validate='true' if args.only_validate else 'false',
-        save_output_files=args.save_output_files)
-
-    save_report(args.dir, import_report, args.only_validate)
+    report_text_file = ""
+    is_retrying = args.max_retries > 0
+    while True:
+        if args.max_retries < 0:
+            if is_retrying:
+                print("Maximum retries reached. Exiting.")
+            break
+        else:
+            import_report = csv_bulk_import(
+                csv_directory=args.dir,
+                load_type=args.load_type,
+                only_validate='true' if args.only_validate else 'false',
+                save_output_files=args.save_output_files,
+                overwrite_children=args.overwrite_children,
+                only_delete_children=args.only_delete_children,
+                report_text_file=report_text_file)
+
+            report_text_file = save_report(args.dir, import_report, args.only_validate)
+
+        args.max_retries -= 1
 
 if __name__ == '__main__':
     main()
\ No newline at end of file