arcflow/utils/bulk_import.py: 154 additions & 22 deletions
@@ -8,6 +8,7 @@
from pathlib import Path
from datetime import datetime
from asnake.client import ASnakeClient
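# ThreadPool exposes the multiprocessing.Pool API backed by threads, which suits
# the I/O-bound HTTP delete calls issued in delete_children() below.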
from multiprocessing.pool import ThreadPool as Pool
import re
import yaml  # used by yaml.safe_load() in the retry path; assumed not already imported in the elided lines above this hunk


@@ -80,14 +81,82 @@ def check_for_children(repo_id, rid, asnake_client):
print(f'Error retrieving child count for resource ID: {e}')
return -1

def delete_archival_object(repo_id, ao_id, asnake_client):
"""
Function to delete an archival object by ID.
Returns True if successful, False otherwise.
"""
try:
delete_response = asnake_client.delete(
f"/repositories/{repo_id}/archival_objects/{ao_id}")
if delete_response.status_code == 200:
print(f"Deleted archival object {ao_id} successfully.")
return True
else:
print(f"Failed to delete archival object {ao_id}. Status code: {delete_response.status_code}")
return False
except Exception as e:
print(f'Error deleting archival object ID {ao_id}: {e}')
return False

def delete_children(repo_id, rid, asnake_client):
"""
Function to delete all top-level children of a resource.
    Returns the number of children deleted, or -1 if an error is encountered.
"""
try:
info = asnake_client.get(f"/repositories/{repo_id}/resources/{rid}/tree/root").json()
child_count = int(info.get('child_count', 0))
if child_count > 0:
with Pool(processes=10) as pool:
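                # up to 10 delete requests run concurrently in worker threads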
waypoints = int(info.get('waypoints', 0))
                # Fetch waypoints beyond the precomputed first one, in case the resource
                # has more children than a single waypoint holds. Work from the highest
                # offset backwards so deletions do not shift the offsets of waypoints still to be processed.
for i in range(waypoints, 1, -1):
waypoint = asnake_client.get(f"/repositories/{repo_id}/resources/{rid}/tree/waypoint",
params={
'offset': i-1,
}).json()
results = [pool.apply_async(
delete_archival_object,
args=(repo_id, child['uri'].split('/')[-1], asnake_client))
for child in waypoint]
                    # wait for this waypoint's deletions to complete
for r in results:
r.get()

                # then delete the remaining children in the first waypoint, which
                # the tree/root response returns inline under precomputed_waypoints
results = [pool.apply_async(
delete_archival_object,
args=(repo_id, child['uri'].split('/')[-1], asnake_client))
for child in info['precomputed_waypoints']['']['0']]
                # wait for the remaining deletions to complete
for r in results:
r.get()
return child_count
except Exception as e:
        print(f'Error deleting children for resource ID {rid}: {e}')
return -1

def report_csv_error(report_dict, error_string):
"""Function to print and log error messages (assumes only one error message)."""
report_dict["error"] = error_string
print(error_string)

def csv_bulk_import(
csv_directory=None,
load_type='ao',
only_validate='false',
save_output_files=False,
overwrite_children=False,
only_delete_children=False,
report_text_file=""):
"""Function to handle CSV bulk import."""
print("Starting CSV bulk import...")
if report_text_file:
print(f"Retrying CSV bulk import with report file {report_text_file}...")
else:
print("Starting CSV bulk import...")

if not csv_directory or not os.path.exists(csv_directory):
print(f'Directory {csv_directory} does not exist. Exiting.')
exit(0)
@@ -97,8 +166,26 @@ def csv_bulk_import(csv_directory=None, load_type='ao', only_validate='false', s

bulk_import_report = []

if report_text_file:
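        # The text report is the "# Import Job Info" header followed by the JSON
        # report; yaml.safe_load skips the "#" line as a comment and returns the
        # report as a list of per-file dicts.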
try:
with open(report_text_file, "r") as file:
entries = yaml.safe_load(file)
except FileNotFoundError:
print(f"File {report_text_file} not found.")
exit(0)
else:
entries = glob.iglob(f'{csv_directory}*.csv')

for f in entries:
if report_text_file:
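            # On a retry pass, entries are dicts from the previous report; only
            # files whose last run logged a Java/MySQL error are re-imported.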
if f.get("java_mysql_error", 0) > 0:
f = f"{csv_directory}{f['identifier']}.csv"
print(f'Retrying file {f}...')
else:
continue
else:
print(f'Processing file {f}...')

file_import_report = {}
file_import_report["identifier"] = Path(f).stem
file_import_report["type"] = load_type
@@ -136,15 +223,28 @@ def csv_bulk_import(csv_directory=None, load_type='ao', only_validate='false', s
file_import_report["rid"] = rid

if load_type == "ao":
if overwrite_children or only_delete_children:
deleted_children = delete_children(repo, rid, client)
file_import_report["deleted_children"] = deleted_children
if deleted_children == -1:
report_csv_error(file_import_report, f'Error deleting children for EAD ID {ead_id}. Not imported.')
bulk_import_report.append(file_import_report)
continue
if only_delete_children:
file_import_report["results_status"] = "Completed"
file_import_report["results_warnings"] = f"Deleted {deleted_children} children. No import performed."
bulk_import_report.append(file_import_report)
continue
else:
child_count = check_for_children(repo, rid, client)
if child_count > 0:
report_csv_error(file_import_report, f'EAD ID {ead_id} already has {child_count} top-level children in ASpace. Not imported.')
bulk_import_report.append(file_import_report)
continue
elif child_count == -1:
report_csv_error(file_import_report, f'Error checking children for EAD ID {ead_id}. Not imported.')
bulk_import_report.append(file_import_report)
continue

file_list = []
with open(f, 'rb') as file:
@@ -183,6 +283,10 @@ def csv_bulk_import(csv_directory=None, load_type='ao', only_validate='false', s

bulk_import_report.append(file_import_report)
print(json.dumps(import_job, indent=4))

if not bulk_import_report:
print("No more files to process. Exiting.")
exit(0)

if save_output_files:
try:
@@ -206,12 +310,12 @@ def save_report(path, report_list, validate_only):

txt_report_save_path = os.path.join(report_save_path, report_text_file_name)
with open(txt_report_save_path, 'w', encoding='utf-8') as report:
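        # the leading "#" keeps this header line a YAML comment, so the report
        # can be re-read with yaml.safe_load when retrying failed files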
print("Import Job Info", file=report)
print("# Import Job Info", file=report)
json.dump(report_list, report, indent=4)

report_csv_file_name = report_file_name_stem + ".csv"

fieldnames = ['identifier','ead_id','aspace_url','import_date','repo_id', 'rid', 'only_validate','type','resource_id','error','results_status','results_warnings','results_id','results_uri','deleted_children']
issue_assessment_fieldnames = get_issue_assessment_fieldnames()
fieldnames.extend(issue_assessment_fieldnames)

Expand All @@ -222,6 +326,8 @@ def save_report(path, report_list, validate_only):
for row in report_list:
writer.writerow(row)

return f"{report_save_path}/{report_text_file_name}"

def check_job_status(asnake_client, repo_id, job_id):
"""Function to check whether a job has completed (and thus output files are ready)."""
while True:
@@ -359,18 +465,44 @@ def main():
'--save-output-files',
action='store_true',
help='Download job output files',)
parser.add_argument(
'--overwrite-children',
action='store_true',
help='Overwrite/delete existing children during import/validation',)
parser.add_argument(
'--only-delete-children',
action='store_true',
help='Only delete existing children without performing import',)
parser.add_argument(
'--max-retries',
type=int,
default=0,
        help='Number of times to re-run files whose previous import failed with a Java/MySQL error (default: 0)',)
args = parser.parse_args()

if not args.dir.endswith('/'):
args.dir += '/'

report_text_file = ""
is_retrying = args.max_retries > 0
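    # Run the import once, then up to --max-retries additional passes; each retry
    # pass re-reads the previous run's report and re-imports only the files that
    # failed with a Java/MySQL error.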
while True:
if args.max_retries < 0:
if is_retrying:
print("Maximum retries reached. Exiting.")
break
else:
import_report = csv_bulk_import(
csv_directory=args.dir,
load_type=args.load_type,
only_validate='true' if args.only_validate else 'false',
save_output_files=args.save_output_files,
overwrite_children=args.overwrite_children,
only_delete_children=args.only_delete_children,
report_text_file=report_text_file)

report_text_file = save_report(args.dir, import_report, args.only_validate)

args.max_retries -= 1

if __name__ == '__main__':
main()