diff --git a/api/import_export/export.py b/api/import_export/export.py index 30cb541680e9..6a1043d85cbd 100644 --- a/api/import_export/export.py +++ b/api/import_export/export.py @@ -2,10 +2,14 @@ import json import logging import typing +from collections.abc import Iterator from dataclasses import dataclass -from tempfile import TemporaryFile +from typing import TYPE_CHECKING import boto3 + +if TYPE_CHECKING: + from mypy_boto3_s3.client import S3Client from django.core import serializers from django.core.serializers.json import DjangoJSONEncoder from django.db.models import F, Model, Q @@ -25,6 +29,7 @@ MultivariateFeatureStateValue, ) from features.versioning.models import EnvironmentFeatureVersion +from import_export.utils import S3MultipartUploadWriter from integrations.datadog.models import DataDogConfiguration from integrations.heap.models import HeapConfiguration from integrations.mixpanel.models import MixpanelConfiguration @@ -53,8 +58,8 @@ class S3OrganisationExporter: - def __init__(self, s3_client=None): # type: ignore[no-untyped-def] - self.s3_client = s3_client or boto3.client("s3") + def __init__(self, s3_client: "S3Client | None" = None) -> None: + self.s3_client: "S3Client" = s3_client or boto3.client("s3") def export_to_s3( self, @@ -63,34 +68,40 @@ def export_to_s3( key: str, ) -> None: data = full_export(organisation_id) - logger.debug("Got data export for organisation.") + logger.debug("Starting streaming export for organisation.") - file = TemporaryFile() - file.write(json.dumps(data, cls=DjangoJSONEncoder).encode("utf-8")) - file.seek(0) - logger.debug("Wrote data export to temporary file.") + with S3MultipartUploadWriter(self.s3_client, bucket_name, key) as writer: + writer.write(b"[") + first = True + for item in data: + if not first: + writer.write(b",") + first = False + writer.write(json.dumps(item, cls=DjangoJSONEncoder).encode("utf-8")) + writer.write(b"]") - self.s3_client.upload_fileobj(file, bucket_name, key) - logger.info("Finished writing data export to s3.") + logger.info("Finished streaming data export to S3.") -def full_export(organisation_id: int) -> typing.List[dict]: # type: ignore[type-arg] - return [ - *export_organisation(organisation_id), - *export_projects(organisation_id), - *export_environments(organisation_id), - *export_identities(organisation_id), - *export_features(organisation_id), - *export_metadata(organisation_id), - *export_edge_identities(organisation_id), - ] +def full_export( + organisation_id: int, +) -> Iterator[dict[str, typing.Any]]: + yield from export_organisation(organisation_id) + yield from export_projects(organisation_id) + yield from export_environments(organisation_id) + yield from export_identities(organisation_id) + yield from export_features(organisation_id) + yield from export_metadata(organisation_id) + yield from export_edge_identities(organisation_id) -def export_organisation(organisation_id: int) -> typing.List[dict]: # type: ignore[type-arg] +def export_organisation( + organisation_id: int, +) -> Iterator[dict[str, typing.Any]]: """ Serialize an organisation and all its related objects. 
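# Editor's illustration (not part of the patch): a minimal sketch of the streaming
# pattern used in export_to_s3 above -- items from a generator are serialized one at
# a time into a JSON array so the full export is never held in memory. The helper
# name and the plain BinaryIO target are assumptions for illustration only; the real
# code writes through S3MultipartUploadWriter and uses DjangoJSONEncoder.
import json
from collections.abc import Iterator
from typing import Any, BinaryIO


def write_json_array(items: Iterator[dict[str, Any]], target: BinaryIO) -> None:
    target.write(b"[")
    first = True
    for item in items:
        if not first:
            target.write(b",")  # comma-separate every item after the first
        first = False
        target.write(json.dumps(item).encode("utf-8"))
    target.write(b"]")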
""" - return _export_entities( + yield from _export_entities( _EntityExportConfig(Organisation, Q(id=organisation_id)), _EntityExportConfig(InviteLink, Q(organisation__id=organisation_id)), _EntityExportConfig(OrganisationWebhook, Q(organisation__id=organisation_id)), @@ -98,8 +109,10 @@ def export_organisation(organisation_id: int) -> typing.List[dict]: # type: ign ) -def export_metadata(organisation_id: int) -> typing.List[dict]: # type: ignore[type-arg] - return _export_entities( +def export_metadata( + organisation_id: int, +) -> Iterator[dict[str, typing.Any]]: + yield from _export_entities( _EntityExportConfig(MetadataField, Q(organisation__id=organisation_id)), _EntityExportConfig( MetadataModelField, Q(field__organisation__id=organisation_id) @@ -116,56 +129,55 @@ def export_metadata(organisation_id: int) -> typing.List[dict]: # type: ignore[ def export_projects( organisation_id: int, -) -> typing.List[dict]: # type: ignore[type-arg] +) -> Iterator[dict[str, typing.Any]]: default_filter = Q(project__organisation__id=organisation_id) - exported_projects = _export_entities( + for project in _export_entities( _EntityExportConfig(Project, Q(organisation__id=organisation_id)), - ) - for project in exported_projects: + ): project["fields"]["enable_dynamo_db"] = False + yield project - return [ - *exported_projects, - *_export_entities( - _EntityExportConfig( - Segment, - Q(project__organisation__id=organisation_id, id=F("version_of")), - ), - _EntityExportConfig( - SegmentRule, - Q( - segment__project__organisation__id=organisation_id, - segment_id=F("segment__version_of"), - ) - | Q( - rule__segment__project__organisation__id=organisation_id, - rule__segment_id=F("rule__segment__version_of"), - ), + yield from _export_entities( + _EntityExportConfig( + Segment, + Q(project__organisation__id=organisation_id, id=F("version_of")), + ), + _EntityExportConfig( + SegmentRule, + Q( + segment__project__organisation__id=organisation_id, + segment_id=F("segment__version_of"), + ) + | Q( + rule__segment__project__organisation__id=organisation_id, + rule__segment_id=F("rule__segment__version_of"), ), - _EntityExportConfig( - Condition, - Q( - rule__segment__project__organisation__id=organisation_id, - rule__segment_id=F("rule__segment__version_of"), - ) - | Q( - rule__rule__segment__project__organisation__id=organisation_id, - rule__rule__segment_id=F("rule__rule__segment__version_of"), - ), + ), + _EntityExportConfig( + Condition, + Q( + rule__segment__project__organisation__id=organisation_id, + rule__segment_id=F("rule__segment__version_of"), + ) + | Q( + rule__rule__segment__project__organisation__id=organisation_id, + rule__rule__segment_id=F("rule__rule__segment__version_of"), ), - _EntityExportConfig(Tag, default_filter), - _EntityExportConfig(DataDogConfiguration, default_filter), - _EntityExportConfig(NewRelicConfiguration, default_filter), - _EntityExportConfig(SlackConfiguration, default_filter), ), - ] + _EntityExportConfig(Tag, default_filter), + _EntityExportConfig(DataDogConfiguration, default_filter), + _EntityExportConfig(NewRelicConfiguration, default_filter), + _EntityExportConfig(SlackConfiguration, default_filter), + ) -def export_environments(organisation_id: int) -> typing.List[dict]: # type: ignore[type-arg] +def export_environments( + organisation_id: int, +) -> Iterator[dict[str, typing.Any]]: default_filter = Q(environment__project__organisation__id=organisation_id) - return _export_entities( + yield from _export_entities( _EntityExportConfig(Environment, 
Q(project__organisation__id=organisation_id)), _EntityExportConfig(EnvironmentAPIKey, default_filter), _EntityExportConfig(Webhook, default_filter), @@ -178,18 +190,26 @@ def export_environments(organisation_id: int) -> typing.List[dict]: # type: ign ) -def export_identities(organisation_id: int) -> typing.List[dict]: # type: ignore[type-arg] - traits = _export_entities( - _EntityExportConfig( - Trait, - Q( - identity__environment__project__organisation__id=organisation_id, - identity__environment__project__enable_dynamo_db=False, +def export_identities( + organisation_id: int, +) -> Iterator[dict[str, typing.Any]]: + # We export the traits first so that we take a 'snapshot' before exporting the + # identities, otherwise we end up with issues where new traits are created for new + # identities during the export process and the identity doesn't exist in the import. + # We then need to reverse the order so that the identities are imported first. + traits = list( + _export_entities( + _EntityExportConfig( + Trait, + Q( + identity__environment__project__organisation__id=organisation_id, + identity__environment__project__enable_dynamo_db=False, + ), ), - ), + ) ) - identities = _export_entities( + yield from _export_entities( _EntityExportConfig( Identity, Q( @@ -199,36 +219,33 @@ def export_identities(organisation_id: int) -> typing.List[dict]: # type: ignor ), ) - # We export the traits first so that we take a 'snapshot' before exporting the - # identities, otherwise we end up with issues where new traits are created for new - # identities during the export process and the identity doesn't exist in the import. - # We then need to reverse the order so that the identities are imported first. - return [*identities, *traits] + yield from traits -def export_edge_identities(organisation_id: int) -> typing.List[dict]: # type: ignore[type-arg] - identities = [] - traits = [] - identity_overrides = [] +def export_edge_identities( + organisation_id: int, +) -> Iterator[dict[str, typing.Any]]: for environment in Environment.objects.filter( project__organisation__id=organisation_id, project__enable_dynamo_db=True ): exported_identities, exported_traits, exported_overrides = ( export_edge_identity_and_overrides(environment.api_key) ) - identities.extend(exported_identities) - traits.extend(exported_traits) - identity_overrides.extend(exported_overrides) - - return [*identities, *traits, *identity_overrides] + yield from exported_identities + yield from exported_traits + yield from exported_overrides -def export_features(organisation_id: int) -> typing.List[dict]: # type: ignore[type-arg] +def export_features( + organisation_id: int, +) -> Iterator[dict[str, typing.Any]]: """ Export all features and related entities, except ChangeRequests. """ - feature_states = [] + # Buffer feature states because we need to modify them (remove change_request FK) + # and they need to be imported after Feature, EnvironmentFeatureVersion, etc. 
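# Editor's illustration (not part of the patch): the ordering pattern used by
# export_identities above and export_features below, reduced to its essentials --
# rows that must appear *later* in the output are materialised into a list up front
# (a snapshot taken before the primary rows are streamed), while everything else
# stays lazy. The generic function and parameter names are assumptions for
# illustration only.
from collections.abc import Iterable, Iterator
from typing import Any


def stream_with_deferred(
    deferred: Iterable[dict[str, Any]],
    primary: Iterable[dict[str, Any]],
) -> Iterator[dict[str, Any]]:
    snapshot = list(deferred)  # snapshot first, before the primary rows are consumed
    yield from primary  # streamed lazily; imported first
    yield from snapshot  # emitted last so the import order stays correct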
+ feature_states: list[dict[str, typing.Any]] = [] for feature_state in _export_entities( _EntityExportConfig( FeatureState, Q(feature__project__organisation__id=organisation_id) @@ -240,38 +257,39 @@ def export_features(organisation_id: int) -> typing.List[dict]: # type: ignore[ feature_state["fields"]["change_request"] = None feature_states.append(feature_state) - return ( - _export_entities( - _EntityExportConfig( - Feature, - Q(project__organisation__id=organisation_id), - exclude_fields=["owners", "group_owners"], - ), - _EntityExportConfig( - EnvironmentFeatureVersion, - Q(feature__project__organisation__id=organisation_id), - exclude_fields=["created_by", "published_by"], - ), - _EntityExportConfig( - MultivariateFeatureOption, - Q(feature__project__organisation__id=organisation_id), - ), - _EntityExportConfig( - FeatureSegment, - Q(feature__project__organisation__id=organisation_id), - ), - ) - + feature_states # feature states need to be imported in correct order - + _export_entities( - _EntityExportConfig( - FeatureStateValue, - Q(feature_state__feature__project__organisation__id=organisation_id), - ), - _EntityExportConfig( - MultivariateFeatureStateValue, - Q(feature_state__feature__project__organisation__id=organisation_id), - ), - ) + yield from _export_entities( + _EntityExportConfig( + Feature, + Q(project__organisation__id=organisation_id), + exclude_fields=["owners", "group_owners"], + ), + _EntityExportConfig( + EnvironmentFeatureVersion, + Q(feature__project__organisation__id=organisation_id), + exclude_fields=["created_by", "published_by"], + ), + _EntityExportConfig( + MultivariateFeatureOption, + Q(feature__project__organisation__id=organisation_id), + ), + _EntityExportConfig( + FeatureSegment, + Q(feature__project__organisation__id=organisation_id), + ), + ) + + # Feature states need to be imported in correct order (after features) + yield from feature_states + + yield from _export_entities( + _EntityExportConfig( + FeatureStateValue, + Q(feature_state__feature__project__organisation__id=organisation_id), + ), + _EntityExportConfig( + MultivariateFeatureStateValue, + Q(feature_state__feature__project__organisation__id=organisation_id), + ), ) @@ -284,19 +302,17 @@ class _EntityExportConfig: def _export_entities( *export_configs: _EntityExportConfig, -) -> typing.List[dict]: # type: ignore[type-arg] - entities = [] +) -> Iterator[dict[str, typing.Any]]: for config in export_configs: args = ("python", config.model_class.objects.filter(config.qs_filter)) - kwargs = {} + kwargs: dict[str, typing.Any] = {} if config.exclude_fields: kwargs["fields"] = [ f.name for f in config.model_class._meta.get_fields() if f.name not in config.exclude_fields ] - entities.extend(_serialize_natural(*args, **kwargs)) # type: ignore[arg-type] - return entities + yield from _serialize_natural(*args, **kwargs) _serialize_natural = functools.partial( diff --git a/api/import_export/management/commands/dumporganisationtos3.py b/api/import_export/management/commands/dumporganisationtos3.py index ebd1a59c5016..d0341c430963 100644 --- a/api/import_export/management/commands/dumporganisationtos3.py +++ b/api/import_export/management/commands/dumporganisationtos3.py @@ -10,7 +10,7 @@ class Command(BaseCommand): def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] super().__init__(*args, **kwargs) - self.exporter = S3OrganisationExporter() # type: ignore[no-untyped-call] + self.exporter = S3OrganisationExporter() def add_arguments(self, parser: CommandParser): # type: 
ignore[no-untyped-def] parser.add_argument( diff --git a/api/import_export/utils.py b/api/import_export/utils.py new file mode 100644 index 000000000000..3496d8b6a37b --- /dev/null +++ b/api/import_export/utils.py @@ -0,0 +1,102 @@ +import io +import logging +import typing + +if typing.TYPE_CHECKING: + from mypy_boto3_s3.client import S3Client + from mypy_boto3_s3.type_defs import CompletedPartTypeDef + +logger = logging.getLogger(__name__) + + +class S3MultipartUploadWriter: + """ + A file-like writer that streams data to S3 using multipart upload. + + Buffers data until the minimum part size (5MB) is reached, then uploads + each part. This allows streaming large exports without holding the entire + payload in memory. + """ + + # S3 multipart upload minimum part size is 5MB (except for the last part) + MIN_PART_SIZE = 5 * 1024 * 1024 + + def __init__( + self, + s3_client: "S3Client", + bucket_name: str, + key: str, + ) -> None: + self._s3_client = s3_client + self._bucket_name = bucket_name + self._key = key + self._buffer = io.BytesIO() + self._parts: list[CompletedPartTypeDef] = [] + self._part_number = 1 + self._upload_id: str | None = None + + def __enter__(self) -> "S3MultipartUploadWriter": + response = self._s3_client.create_multipart_upload( + Bucket=self._bucket_name, + Key=self._key, + ) + self._upload_id = response["UploadId"] + logger.debug("Started multipart upload with ID: %s", self._upload_id) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: typing.Any, + ) -> None: + if exc_type is not None: + # Abort the upload on error + if self._upload_id: + self._s3_client.abort_multipart_upload( + Bucket=self._bucket_name, + Key=self._key, + UploadId=self._upload_id, + ) + logger.warning("Aborted multipart upload due to error: %s", exc_val) + return + + # Upload any remaining data in the buffer (or an empty part if no data) + # S3 requires at least one part to complete a multipart upload + if self._buffer.tell() > 0 or not self._parts: + self._upload_part() + + assert self._upload_id + # Complete the multipart upload + self._s3_client.complete_multipart_upload( + Bucket=self._bucket_name, + Key=self._key, + UploadId=self._upload_id, + MultipartUpload={"Parts": self._parts}, + ) + logger.debug("Completed multipart upload with %d parts", len(self._parts)) + + def write(self, data: bytes) -> None: + self._buffer.write(data) + if self._buffer.tell() >= self.MIN_PART_SIZE: + self._upload_part() + + def _upload_part(self) -> None: + assert self._upload_id + self._buffer.seek(0) + response = self._s3_client.upload_part( + Bucket=self._bucket_name, + Key=self._key, + PartNumber=self._part_number, + UploadId=self._upload_id, + Body=self._buffer.read(), + ) + self._parts.append( + { + "PartNumber": self._part_number, + "ETag": response["ETag"], + } + ) + logger.debug("Uploaded part %d", self._part_number) + self._part_number += 1 + self._buffer = io.BytesIO() diff --git a/api/poetry.lock b/api/poetry.lock index e9a03e7680e2..3a1a66c80be0 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -3204,6 +3204,21 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} +[[package]] +name = "mypy-boto3-s3" +version = "1.42.37" +description = "Type annotations for boto3 S3 1.42.37 service generated with mypy-boto3-builder 8.12.0" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mypy_boto3_s3-1.42.37-py3-none-any.whl", hash = 
"sha256:7c118665f3f583dbfde1013ce47908749f9d2a760f28f59ec65732306ee9cec9"}, + {file = "mypy_boto3_s3-1.42.37.tar.gz", hash = "sha256:628a4652f727870a07e1c3854d6f30dc545a7dd5a4b719a2c59c32a95d92e4c1"}, +] + +[package.dependencies] +typing-extensions = {version = "*", markers = "python_version < \"3.12\""} + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -4324,7 +4339,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5599,4 +5613,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">3.11,<3.13" -content-hash = "4c7176bf5bb304a1b1c514fef9ffb8c0ee6bddbbb81d7c00f82dab29ef43ca8b" +content-hash = "a1b3cf0629ecd28b7090b1f945f80730595887c59ef257b9dd7b39c852f67cbd" diff --git a/api/pyproject.toml b/api/pyproject.toml index f06d3cc411ea..3bc5cacacb5c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -228,6 +228,7 @@ requests-mock = "^1.11.0" django-extensions = "^3.2.3" pdbpp = "^0.10.3" mypy-boto3-dynamodb = "^1.33.0" +mypy-boto3-s3 = "^1.36.0" pytest-structlog = "^1.1" pyfakefs = "^5.7.4" mypy = "^1.15.0" diff --git a/api/sse/sse_service.py b/api/sse/sse_service.py index 9fac789e1e86..d684333fcc47 100644 --- a/api/sse/sse_service.py +++ b/api/sse/sse_service.py @@ -61,6 +61,8 @@ def send_environment_update_message_for_environment(environment): # type: ignor def stream_access_logs( timeout_seconds: int = 300, ) -> Generator[SSEAccessLogs, None, None]: + assert settings.AWS_SSE_LOGS_BUCKET_NAME + gpg = gnupg.GPG(gnupghome=GNUPG_HOME) bucket = boto3.resource("s3").Bucket(settings.AWS_SSE_LOGS_BUCKET_NAME) diff --git a/api/tests/unit/import_export/test_unit_import_export_commands.py b/api/tests/unit/import_export/test_unit_import_export_commands.py new file mode 100644 index 000000000000..6d7d850a7565 --- /dev/null +++ b/api/tests/unit/import_export/test_unit_import_export_commands.py @@ -0,0 +1,17 @@ +from django.core.management import call_command +from pytest_mock import MockerFixture + + +def test_dumporganisationtos3_command__calls_exporter(mocker: MockerFixture) -> None: + # Given + mock_exporter = mocker.MagicMock() + mocker.patch( + "import_export.management.commands.dumporganisationtos3.S3OrganisationExporter", + return_value=mock_exporter, + ) + + # When + call_command("dumporganisationtos3", "1", "test-bucket", "test-key") + + # Then + mock_exporter.export_to_s3.assert_called_once_with(1, "test-bucket", "test-key") diff --git a/api/tests/unit/import_export/test_unit_import_export_export.py b/api/tests/unit/import_export/test_unit_import_export_export.py index 
8bb9f035b503..0ce80670b8ad 100644 --- a/api/tests/unit/import_export/test_unit_import_export_export.py +++ b/api/tests/unit/import_export/test_unit_import_export_export.py @@ -70,9 +70,7 @@ def test_export_organisation(db): # type: ignore[no-untyped-def] export = export_organisation(organisation.id) # Then - assert export - - # TODO: test whether the export is importable + assert list(export) def test_export_project(organisation): # type: ignore[no-untyped-def] @@ -123,7 +121,7 @@ def test_export_project__only_live_segments_are_exported( # type: ignore[no-unt rule=segment_rule2, operator=EQUAL, property="foo", value="bar" ) # When - export = export_projects(organisation.id) + export = list(export_projects(organisation.id)) # Then # only the project and the live segment should be exported @@ -194,8 +192,10 @@ def test_export_metadata(environment, organisation, settings): # type: ignore[n field_value="some_data", ) # When - exported_environment = export_environments(environment.project.organisation_id) - exported_metadata = export_metadata(organisation.id) + exported_environment = list( + export_environments(environment.project.organisation_id) + ) + exported_metadata = list(export_metadata(organisation.id)) data = exported_environment + exported_metadata @@ -271,7 +271,7 @@ def test_export_features(project, environment, segment, admin_user): # type: ig ) # When - export = export_features(organisation_id=project.organisation_id) + export = list(export_features(organisation_id=project.organisation_id)) # Then assert export @@ -502,7 +502,7 @@ def test_export_edge_identities( # When mocker.patch("edge_api.identities.export.EXPORT_EDGE_IDENTITY_PAGINATION_LIMIT", 1) - export_json = export_edge_identities(project.organisation_id) + export_json = list(export_edge_identities(project.organisation_id)) # Let's load the data file_path = f"/tmp/{uuid.uuid4()}.json" @@ -590,7 +590,7 @@ def test_organisation_exporter_export_to_s3(organisation): # type: ignore[no-un s3_client = boto3.client("s3") - exporter = S3OrganisationExporter(s3_client=s3_client) # type: ignore[no-untyped-call] + exporter = S3OrganisationExporter(s3_client=s3_client) # When exporter.export_to_s3(organisation.id, bucket_name, file_key) @@ -610,7 +610,7 @@ def test_export_dynamo_project( ) # When - we export the data - data = export_projects(organisation.id) + data = list(export_projects(organisation.id)) # and delete the project project.hard_delete() diff --git a/api/tests/unit/import_export/test_unit_import_export_import.py b/api/tests/unit/import_export/test_unit_import_export_import.py index e5331d865505..a64d59b9e1db 100644 --- a/api/tests/unit/import_export/test_unit_import_export_import.py +++ b/api/tests/unit/import_export/test_unit_import_export_import.py @@ -9,8 +9,8 @@ from organisations.models import Organisation -@mock_s3 -def test_import_organisation(organisation): # type: ignore[no-untyped-def] +@mock_s3 # type: ignore[misc] +def test_import_organisation(organisation: Organisation) -> None: # Given bucket_name = "test-bucket" file_key = "organisation-exports/org-1.json" @@ -22,7 +22,7 @@ def test_import_organisation(organisation): # type: ignore[no-untyped-def] ) body = json.dumps( - export_organisation(organisation.id), cls=DjangoJSONEncoder + list(export_organisation(organisation.id)), cls=DjangoJSONEncoder ).encode("utf-8") s3_client = boto3.client("s3") diff --git a/api/tests/unit/import_export/test_unit_import_export_utils.py b/api/tests/unit/import_export/test_unit_import_export_utils.py new file mode 100644 
index 000000000000..67c0cf6382ab --- /dev/null +++ b/api/tests/unit/import_export/test_unit_import_export_utils.py @@ -0,0 +1,178 @@ +import boto3 +from moto import mock_s3 # type: ignore[import-untyped] +from pytest_mock import MockerFixture + +from import_export.utils import S3MultipartUploadWriter + + +@mock_s3 # type: ignore[misc] +def test_s3_multipart_upload_writer__single_part__completes_upload() -> None: + # Given + bucket_name = "test-bucket" + key = "test-key" + data = b"small data" + + s3_resource = boto3.resource("s3", region_name="eu-west-2") + s3_resource.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + s3_client = boto3.client("s3", region_name="eu-west-2") + + # When + with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer: + writer.write(data) + + # Then + result = s3_client.get_object(Bucket=bucket_name, Key=key) + assert result["Body"].read() == data + + +@mock_s3 # type: ignore[misc] +def test_s3_multipart_upload_writer__multiple_parts__uploads_each_part( + mocker: MockerFixture, +) -> None: + # Given + bucket_name = "test-bucket" + key = "test-key" + # Create data larger than MIN_PART_SIZE (5MB) + chunk_size = S3MultipartUploadWriter.MIN_PART_SIZE + first_chunk = b"a" * chunk_size + second_chunk = b"b" * chunk_size + final_chunk = b"final" + + s3_resource = boto3.resource("s3", region_name="eu-west-2") + s3_resource.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + s3_client = boto3.client("s3", region_name="eu-west-2") + upload_part_spy = mocker.spy(s3_client, "upload_part") + + # When + with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer: + writer.write(first_chunk) + writer.write(second_chunk) + writer.write(final_chunk) + + # Then + result = s3_client.get_object(Bucket=bucket_name, Key=key) + assert result["Body"].read() == first_chunk + second_chunk + final_chunk + + # Verify exactly 3 parts were uploaded + assert upload_part_spy.call_count == 3 + # Verify part numbers are sequential + part_numbers = [ + call.kwargs["PartNumber"] for call in upload_part_spy.call_args_list + ] + assert part_numbers == [1, 2, 3] + + +@mock_s3 # type: ignore[misc] +def test_s3_multipart_upload_writer__accumulates_small_writes__uploads_correctly( + mocker: MockerFixture, +) -> None: + # Given + bucket_name = "test-bucket" + key = "test-key" + small_chunk = b"x" * 1000 # 1KB + writes_to_reach_threshold = (S3MultipartUploadWriter.MIN_PART_SIZE // 1000) + 1 + + s3_resource = boto3.resource("s3", region_name="eu-west-2") + s3_resource.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + s3_client = boto3.client("s3", region_name="eu-west-2") + upload_part_spy = mocker.spy(s3_client, "upload_part") + + # When + with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer: + for _ in range(writes_to_reach_threshold): + writer.write(small_chunk) + writer.write(b"final") + + # Then + result = s3_client.get_object(Bucket=bucket_name, Key=key) + expected_data = (small_chunk * writes_to_reach_threshold) + b"final" + assert result["Body"].read() == expected_data + + # Verify buffering: one part when threshold reached, one final part on exit + assert upload_part_spy.call_count == 2 + + +@mock_s3 # type: ignore[misc] +def test_s3_multipart_upload_writer__error_during_write__aborts_upload() -> None: + # Given + bucket_name = "test-bucket" + key = "test-key" + + s3_resource = boto3.resource("s3", 
region_name="eu-west-2") + s3_resource.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + s3_client = boto3.client("s3", region_name="eu-west-2") + + # When + try: + with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer: + writer.write(b"some data") + raise ValueError("test error") + except ValueError: + pass + + # Then - the object should not exist (upload was aborted) + objects = s3_client.list_objects_v2(Bucket=bucket_name) + assert objects.get("KeyCount", 0) == 0 + + +@mock_s3 # type: ignore[misc] +def test_s3_multipart_upload_writer__no_data__completes_with_empty_object() -> None: + # Given + bucket_name = "test-bucket" + key = "test-key" + + s3_resource = boto3.resource("s3", region_name="eu-west-2") + s3_resource.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + s3_client = boto3.client("s3", region_name="eu-west-2") + + # When + with S3MultipartUploadWriter(s3_client, bucket_name, key): + pass # No writes + + # Then + result = s3_client.get_object(Bucket=bucket_name, Key=key) + assert result["Body"].read() == b"" + + +@mock_s3 # type: ignore[misc] +def test_s3_multipart_upload_writer__multiple_small_writes__buffers_correctly( + mocker: MockerFixture, +) -> None: + # Given + bucket_name = "test-bucket" + key = "test-key" + + s3_resource = boto3.resource("s3", region_name="eu-west-2") + s3_resource.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + s3_client = boto3.client("s3", region_name="eu-west-2") + upload_part_spy = mocker.spy(s3_client, "upload_part") + + # When + with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer: + writer.write(b"hello ") + writer.write(b"world") + + # Then + result = s3_client.get_object(Bucket=bucket_name, Key=key) + assert result["Body"].read() == b"hello world" + + # Verify both writes were buffered and uploaded as a single part + assert upload_part_spy.call_count == 1 diff --git a/api/tests/unit/sse/test_sse_service.py b/api/tests/unit/sse/test_sse_service.py index 11d99ca5e159..ef7946ae9451 100644 --- a/api/tests/unit/sse/test_sse_service.py +++ b/api/tests/unit/sse/test_sse_service.py @@ -125,6 +125,7 @@ def test_stream_access_logs(mocker: MockerFixture, aws_credentials: None) -> Non # Next, let's create a bucket bucket_name = settings.AWS_SSE_LOGS_BUCKET_NAME + assert bucket_name s3_client = boto3.client("s3", region_name="eu-west-2") s3_client.create_bucket( Bucket=bucket_name,
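# Editor's illustration (not part of the patch): an end-to-end sketch of how the
# pieces above fit together under moto -- the exporter streams full_export() through
# S3MultipartUploadWriter into a bucket, and the object reads back as one valid JSON
# array. The bucket/key names and the demo function are assumptions for illustration
# only; it also assumes a configured Django settings module and a test database that
# already contains the organisation.
import json

import boto3
from moto import mock_s3

from import_export.export import S3OrganisationExporter


@mock_s3
def demo_streaming_export(organisation_id: int) -> list[dict]:
    s3_client = boto3.client("s3", region_name="eu-west-2")
    s3_client.create_bucket(
        Bucket="demo-bucket",
        CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
    )

    exporter = S3OrganisationExporter(s3_client=s3_client)
    exporter.export_to_s3(organisation_id, "demo-bucket", "exports/demo.json")

    # The multipart-uploaded parts reassemble into a single JSON document.
    body = s3_client.get_object(Bucket="demo-bucket", Key="exports/demo.json")["Body"]
    return json.loads(body.read())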