From be10c480986794299dc5ec58c42747256e140f22 Mon Sep 17 00:00:00 2001 From: sudhakartag1 Date: Tue, 16 Dec 2025 14:42:23 +1100 Subject: [PATCH 1/2] Add initContainer support (#426) --- .../core/integrations/kubernetes/__init__.py | 156 +++--- tests/test_initcontainers.py | 487 ++++++++++++++++++ 2 files changed, 580 insertions(+), 63 deletions(-) create mode 100644 tests/test_initcontainers.py diff --git a/robusta_krr/core/integrations/kubernetes/__init__.py b/robusta_krr/core/integrations/kubernetes/__init__.py index 1398c62b..dc796262 100644 --- a/robusta_krr/core/integrations/kubernetes/__init__.py +++ b/robusta_krr/core/integrations/kubernetes/__init__.py @@ -24,8 +24,26 @@ from robusta_krr.utils.object_like_dict import ObjectLikeDict +def _get_init_containers(spec: Any) -> list: + """Get init containers from a pod spec, handling both snake_case and camelCase. + + Standard K8s Python client uses snake_case (init_containers), but CustomObjectsApi + returns raw JSON with camelCase (initContainers). This helper checks both. 
+ """ + # Try snake_case first (standard K8s Python client) + init_containers = getattr(spec, "init_containers", None) + if init_containers is not None: + return init_containers + # Fall back to camelCase (CustomObjectsApi / ObjectLikeDict) + init_containers = getattr(spec, "initContainers", None) + if init_containers is not None: + return init_containers + return [] + + class LightweightJobInfo: """Lightweight job object containing only the fields needed for GroupedJob processing.""" + def __init__(self, name: str, namespace: str): self.name = name self.namespace = namespace @@ -40,7 +58,7 @@ def __init__(self, name: str, namespace: str): class ClusterLoader: - def __init__(self, cluster: Optional[str]=None): + def __init__(self, cluster: Optional[str] = None): self.cluster = cluster # This executor will be running requests to Kubernetes API self.executor = ThreadPoolExecutor(settings.max_workers) @@ -86,7 +104,7 @@ def namespaces(self) -> Union[list[str], Literal["*"]]: if expand_list: logger.info("found regex pattern in provided namespace argument, expanding namespace list") - all_ns = [ ns.metadata.name for ns in self.core.list_namespace().items ] + all_ns = [ns.metadata.name for ns in self.core.list_namespace().items] for expand_ns in expand_list: for ns in all_ns: if expand_ns.fullmatch(ns) and ns not in self.__namespaces: @@ -156,9 +174,9 @@ async def list_pods(self, object: K8sObjectData) -> list[PodData]: selector = f"batch.kubernetes.io/controller-uid in ({','.join(ownered_jobs_uids)})" elif object.kind == "GroupedJob": - if not hasattr(object._api_resource, '_label_filter') or not object._api_resource._label_filter: + if not hasattr(object._api_resource, "_label_filter") or not object._api_resource._label_filter: return [] - + # Use the label+value filter to get pods ret: V1PodList = await loop.run_in_executor( self.executor, @@ -166,9 +184,9 @@ async def list_pods(self, object: K8sObjectData) -> list[PodData]: namespace=object.namespace, 
label_selector=object._api_resource._label_filter ), ) - + # Apply the job grouping limit to pod results - limited_pods = ret.items[:settings.job_grouping_limit] + limited_pods = ret.items[: settings.job_grouping_limit] return [PodData(name=pod.metadata.name, deleted=False) for pod in limited_pods] else: @@ -209,14 +227,15 @@ def _build_selector_query(selector: Any) -> Union[str, None]: label_filters += [ ClusterLoader._get_match_expression_filter(expression) for expression in selector.match_expressions ] - + # normally the kubernetes API client renames matchLabels to match_labels in python # but for CRDs like ArgoRollouts that renaming doesn't happen and we have selector={'matchLabels': {'app': 'test-app'}} if getattr(selector, "matchLabels", None): label_filters += [f"{label[0]}={label[1]}" for label in getattr(selector, "matchLabels").items()] if getattr(selector, "matchExpressions", None): label_filters += [ - ClusterLoader._get_match_expression_filter(expression) for expression in getattr(selector, "matchExpressions").items() + ClusterLoader._get_match_expression_filter(expression) + for expression in getattr(selector, "matchExpressions").items() ] if label_filters == []: @@ -241,7 +260,7 @@ def __build_scannable_object( if item.metadata.labels: if type(item.metadata.labels) is ObjectLikeDict: labels = item.metadata.labels.__dict__ - else: + else: labels = item.metadata.labels if item.metadata.annotations: @@ -259,7 +278,7 @@ def __build_scannable_object( allocations=ResourceAllocations.from_container(container), hpa=self.__hpa_list.get((namespace, kind, name)), labels=labels, - annotations= annotations + annotations=annotations, ) obj._api_resource = item return obj @@ -315,19 +334,16 @@ async def _list_namespaced_or_global_objects_batched( limit=limit, _continue=continue_ref, ), - ) ] + ) + ] gathered_results = await asyncio.gather(*requests) - - result = [ - item - for request_result in gathered_results - for item in request_result.items - ] + + result = 
[item for request_result in gathered_results for item in request_result.items] next_continue_ref = None if gathered_results: - next_continue_ref = getattr(gathered_results[0].metadata, '_continue', None) + next_continue_ref = getattr(gathered_results[0].metadata, "_continue", None) return result, next_continue_ref @@ -335,6 +351,7 @@ async def _list_namespaced_or_global_objects_batched( if e.status == 410 and e.body: # Continue token expired import json + try: error_body = json.loads(e.body) new_continue_token = error_body.get("metadata", {}).get("continue") @@ -346,10 +363,7 @@ async def _list_namespaced_or_global_objects_batched( raise async def _list_namespaced_or_global_objects( - self, - kind: KindLiteral, - all_namespaces_request: Callable, - namespaced_request: Callable + self, kind: KindLiteral, all_namespaces_request: Callable, namespaced_request: Callable ) -> list[Any]: logger.debug(f"Listing {kind}s in {self.cluster}") loop = asyncio.get_running_loop() @@ -377,11 +391,7 @@ async def _list_namespaced_or_global_objects( for namespace in self.namespaces ] - result = [ - item - for request_result in await asyncio.gather(*requests) - for item in request_result.items - ] + result = [item for request_result in await asyncio.gather(*requests) for item in request_result.items] logger.debug(f"Found {len(result)} {kind} in {self.cluster}") return result @@ -400,7 +410,7 @@ async def _list_scannable_objects( if not self.__kind_available[kind]: return [] - + result = [] try: for item in await self._list_namespaced_or_global_objects(kind, all_namespaces_request, namespaced_request): @@ -428,13 +438,15 @@ def _list_deployments(self) -> list[K8sObjectData]: kind="Deployment", all_namespaces_request=self.apps.list_deployment_for_all_namespaces, namespaced_request=self.apps.list_namespaced_deployment, - extract_containers=lambda item: item.spec.template.spec.containers, + extract_containers=lambda item: item.spec.template.spec.containers + + 
(item.spec.template.spec.init_containers or []), ) def _list_rollouts(self) -> list[K8sObjectData]: async def _extract_containers(item: Any) -> list[V1Container]: if item.spec.template is not None: - return item.spec.template.spec.containers + # CustomObjectsApi returns camelCase, use helper for dual lookup + return item.spec.template.spec.containers + _get_init_containers(item.spec.template.spec) loop = asyncio.get_running_loop() @@ -451,7 +463,7 @@ async def _extract_containers(item: Any) -> list[V1Container]: namespace=item.metadata.namespace, name=workloadRef.name ), ) - return ret.spec.template.spec.containers + return ret.spec.template.spec.containers + (ret.spec.template.spec.init_containers or []) return [] @@ -499,7 +511,9 @@ def _list_strimzipodsets(self) -> list[K8sObjectData]: **kwargs, ) ), - extract_containers=lambda item: item.spec.pods[0].spec.containers, + # CustomObjectsApi returns camelCase, use helper for dual lookup + extract_containers=lambda item: item.spec.pods[0].spec.containers + + _get_init_containers(item.spec.pods[0].spec), ) def _list_deploymentconfig(self) -> list[K8sObjectData]: @@ -523,7 +537,9 @@ def _list_deploymentconfig(self) -> list[K8sObjectData]: **kwargs, ) ), - extract_containers=lambda item: item.spec.template.spec.containers, + # CustomObjectsApi returns camelCase, use helper for dual lookup + extract_containers=lambda item: item.spec.template.spec.containers + + _get_init_containers(item.spec.template.spec), ) def _list_all_statefulsets(self) -> list[K8sObjectData]: @@ -531,7 +547,8 @@ def _list_all_statefulsets(self) -> list[K8sObjectData]: kind="StatefulSet", all_namespaces_request=self.apps.list_stateful_set_for_all_namespaces, namespaced_request=self.apps.list_namespaced_stateful_set, - extract_containers=lambda item: item.spec.template.spec.containers, + extract_containers=lambda item: item.spec.template.spec.containers + + (item.spec.template.spec.init_containers or []), ) def _list_all_daemon_set(self) -> 
list[K8sObjectData]: @@ -539,15 +556,15 @@ def _list_all_daemon_set(self) -> list[K8sObjectData]: kind="DaemonSet", all_namespaces_request=self.apps.list_daemon_set_for_all_namespaces, namespaced_request=self.apps.list_namespaced_daemon_set, - extract_containers=lambda item: item.spec.template.spec.containers, + extract_containers=lambda item: item.spec.template.spec.containers + + (item.spec.template.spec.init_containers or []), ) - async def _list_all_jobs(self) -> list[K8sObjectData]: """List all jobs using batched loading with 500 batch size.""" if not self._should_list_resource("Job"): return [] - + namespaces = self.namespaces if self.namespaces != "*" else ["*"] all_jobs = [] try: @@ -572,7 +589,7 @@ async def _list_all_jobs(self) -> list[K8sObjectData]: if not jobs_batch: # no more jobs to batch do not count empty batches break - + batch_count += 1 for job in jobs_batch: if self._is_job_owned_by_cronjob(job): @@ -581,12 +598,14 @@ async def _list_all_jobs(self) -> list[K8sObjectData]: continue for container in job.spec.template.spec.containers: all_jobs.append(self.__build_scannable_object(job, container, "Job")) + for container in job.spec.template.spec.init_containers or []: + all_jobs.append(self.__build_scannable_object(job, container, "Job")) if not continue_ref: break - + logger.debug("Found %d regular jobs", len(all_jobs)) return all_jobs - + except Exception as e: logger.error( "Failed to run jobs discovery", @@ -599,7 +618,8 @@ def _list_all_cronjobs(self) -> list[K8sObjectData]: kind="CronJob", all_namespaces_request=self.batch.list_cron_job_for_all_namespaces, namespaced_request=self.batch.list_namespaced_cron_job, - extract_containers=lambda item: item.spec.job_template.spec.template.spec.containers, + extract_containers=lambda item: item.spec.job_template.spec.template.spec.containers + + (item.spec.job_template.spec.template.spec.init_containers or []), ) async def _list_all_groupedjobs(self) -> list[K8sObjectData]: @@ -607,13 +627,13 @@ async 
def _list_all_groupedjobs(self) -> list[K8sObjectData]: if not settings.job_grouping_labels: logger.debug("No job grouping labels configured, skipping GroupedJob listing") return [] - + if not self._should_list_resource("GroupedJob"): logger.debug("Skipping GroupedJob in cluster") return [] - + logger.debug("Listing GroupedJobs with grouping labels: %s", settings.job_grouping_labels) - + grouped_jobs = defaultdict(list) grouped_jobs_template = {} # Store only ONE full job as template per group - needed for class K8sObjectData continue_ref: Optional[str] = None @@ -632,7 +652,7 @@ async def _list_all_groupedjobs(self) -> list[K8sObjectData]: limit=settings.discovery_job_batch_size, continue_ref=continue_ref, ) - + continue_ref = next_continue_ref if not jobs_batch and continue_ref: @@ -645,7 +665,11 @@ async def _list_all_groupedjobs(self) -> list[K8sObjectData]: batch_count += 1 for job in jobs_batch: - if not job.metadata.labels or self._is_job_owned_by_cronjob(job) or not self._is_job_grouped(job): + if ( + not job.metadata.labels + or self._is_job_owned_by_cronjob(job) + or not self._is_job_grouped(job) + ): continue for label_name in settings.job_grouping_labels: if label_name not in job.metadata.labels: @@ -654,8 +678,7 @@ async def _list_all_groupedjobs(self) -> list[K8sObjectData]: label_value = job.metadata.labels[label_name] group_key = f"{label_name}={label_value}" lightweight_job = LightweightJobInfo( - name=job.metadata.name, - namespace=job.metadata.namespace + name=job.metadata.name, namespace=job.metadata.namespace ) # Store lightweight job info only for grouped jobs grouped_jobs[group_key].append(lightweight_job) @@ -664,50 +687,55 @@ async def _list_all_groupedjobs(self) -> list[K8sObjectData]: grouped_jobs_template[group_key] = job if not continue_ref: break - + except Exception as e: logger.error( "Failed to run grouped jobs discovery", exc_info=True, ) raise - + result = [] for group_name, jobs in grouped_jobs.items(): template_job = 
grouped_jobs_template[group_name] - + jobs_by_namespace = defaultdict(list) for job in jobs: jobs_by_namespace[job.namespace].append(job) - + for namespace, namespace_jobs in jobs_by_namespace.items(): - limited_jobs = namespace_jobs[:settings.job_grouping_limit] - + limited_jobs = namespace_jobs[: settings.job_grouping_limit] + container_names = set() for container in template_job.spec.template.spec.containers: container_names.add(container.name) - + for container in template_job.spec.template.spec.init_containers or []: + container_names.add(container.name) + for container_name in container_names: template_container = None for container in template_job.spec.template.spec.containers: if container.name == container_name: template_container = container break - + if template_container is None: + for container in template_job.spec.template.spec.init_containers or []: + if container.name == container_name: + template_container = container + break + if template_container: grouped_job = self.__build_scannable_object( - item=template_job, - container=template_container, - kind="GroupedJob" + item=template_job, container=template_container, kind="GroupedJob" ) - + grouped_job.name = group_name grouped_job.namespace = namespace grouped_job._api_resource._grouped_jobs = limited_jobs grouped_job._api_resource._label_filter = group_name - + result.append(grouped_job) - + logger.debug("Found %d GroupedJob groups", len(result)) return result @@ -743,6 +771,7 @@ async def __list_hpa_v2(self) -> dict[HPAKey, HPAData]: all_namespaces_request=self.autoscaling_v2.list_horizontal_pod_autoscaler_for_all_namespaces, namespaced_request=self.autoscaling_v2.list_namespaced_horizontal_pod_autoscaler, ) + def __get_metric(hpa: V2HorizontalPodAutoscaler, metric_name: str) -> Optional[float]: return next( ( @@ -752,6 +781,7 @@ def __get_metric(hpa: V2HorizontalPodAutoscaler, metric_name: str) -> Optional[f ), None, ) + return { ( hpa.metadata.namespace, @@ -865,7 +895,7 @@ async def 
list_scannable_objects(self, clusters: Optional[list[str]]) -> list[K8 if self.cluster_loaders == {}: logger.error("Could not load any cluster.") return - + return [ object for cluster_loader in self.cluster_loaders.values() diff --git a/tests/test_initcontainers.py b/tests/test_initcontainers.py new file mode 100644 index 00000000..826c6e48 --- /dev/null +++ b/tests/test_initcontainers.py @@ -0,0 +1,487 @@ +""" +Tests for initContainer support in KRR. + +Verifies that init_containers are properly extracted alongside regular containers +for all 9 supported workload types. + +Also tests the dual snake_case/camelCase handling for CustomObjectsApi workloads. +""" + +import pytest +from unittest.mock import MagicMock, patch +from robusta_krr.core.integrations.kubernetes import ClusterLoader, _get_init_containers +from robusta_krr.core.models.config import Config +from robusta_krr.utils.object_like_dict import ObjectLikeDict + + +@pytest.fixture +def mock_config(): + """Mock config for testing""" + config = MagicMock(spec=Config) + config.job_grouping_labels = [] + config.job_grouping_limit = 10 + config.discovery_job_batch_size = 1000 + config.discovery_job_max_batches = 50 + config.max_workers = 4 + config.get_kube_client = MagicMock() + config.resources = "*" + config.selector = None + config.namespaces = "*" + return config + + +@pytest.fixture +def mock_cluster_loader(mock_config): + """Create a ClusterLoader instance with mocked dependencies""" + with patch("robusta_krr.core.integrations.kubernetes.settings", mock_config): + loader = ClusterLoader() + loader.apps = MagicMock() + loader.batch = MagicMock() + loader.core = MagicMock() + loader.custom_objects = MagicMock() + loader._ClusterLoader__hpa_list = {} + + # Mock executor + from concurrent.futures import Future + + mock_future = Future() + mock_future.set_result(None) + loader.executor = MagicMock() + loader.executor.submit.return_value = mock_future + + return loader + + +def create_mock_container(name: 
str): + """Create a mock container object""" + container = MagicMock() + container.name = name + container.resources = MagicMock() + container.resources.requests = {"cpu": "100m", "memory": "128Mi"} + container.resources.limits = {"cpu": "200m", "memory": "256Mi"} + return container + + +def create_mock_deployment(name: str, namespace: str, containers: list, init_containers: list = None): + """Create a mock Deployment with containers and optional init_containers""" + deployment = MagicMock() + deployment.metadata.name = name + deployment.metadata.namespace = namespace + deployment.metadata.labels = {"app": name} + deployment.metadata.annotations = {} + deployment.spec.template.spec.containers = containers + deployment.spec.template.spec.init_containers = init_containers + deployment.spec.selector = MagicMock() + deployment.spec.selector.match_labels = {"app": name} + deployment.spec.selector.match_expressions = None + return deployment + + +def create_mock_job(name: str, namespace: str, containers: list, init_containers: list = None): + """Create a mock Job with containers and optional init_containers""" + job = MagicMock() + job.metadata.name = name + job.metadata.namespace = namespace + job.metadata.labels = {"app": name} + job.metadata.annotations = {} + job.metadata.owner_references = [] + job.spec.template.spec.containers = containers + job.spec.template.spec.init_containers = init_containers + return job + + +def create_mock_cronjob(name: str, namespace: str, containers: list, init_containers: list = None): + """Create a mock CronJob with containers and optional init_containers""" + cronjob = MagicMock() + cronjob.metadata.name = name + cronjob.metadata.namespace = namespace + cronjob.metadata.labels = {"app": name} + cronjob.metadata.annotations = {} + cronjob.spec.job_template.spec.template.spec.containers = containers + cronjob.spec.job_template.spec.template.spec.init_containers = init_containers + cronjob.spec.selector = MagicMock() + 
cronjob.spec.selector.match_labels = {"app": name} + cronjob.spec.selector.match_expressions = None + return cronjob + + +# Helper functions to match the extract_containers patterns used in the main code +def extract_deployment_containers(item): + """Extract containers from Deployment/StatefulSet/DaemonSet (standard pattern)""" + return item.spec.template.spec.containers + (item.spec.template.spec.init_containers or []) + + +def extract_cronjob_containers(item): + """Extract containers from CronJob (job_template path)""" + return item.spec.job_template.spec.template.spec.containers + ( + item.spec.job_template.spec.template.spec.init_containers or [] + ) + + +def extract_strimzipodset_containers(item): + """Extract containers from StrimziPodSet (pods[0] path)""" + return item.spec.pods[0].spec.containers + (item.spec.pods[0].spec.init_containers or []) + + +class TestInitContainerExtraction: + """Test that init_containers are properly extracted from workloads using direct lambda tests""" + + def test_deployment_extract_with_initcontainers(self): + """Test that Deployment extract lambda includes both containers and init_containers""" + main_container = create_mock_container("main-app") + init_container = create_mock_container("init-db") + + deployment = create_mock_deployment( + "test-deployment", "default", containers=[main_container], init_containers=[init_container] + ) + + result = extract_deployment_containers(deployment) + assert len(result) == 2 + container_names = {c.name for c in result} + assert "main-app" in container_names + assert "init-db" in container_names + + def test_deployment_extract_without_initcontainers(self): + """Test that Deployment extract lambda works when init_containers is None""" + main_container = create_mock_container("main-app") + + deployment = create_mock_deployment( + "test-deployment", "default", containers=[main_container], init_containers=None + ) + + result = extract_deployment_containers(deployment) + assert len(result) == 1 
+ assert result[0].name == "main-app" + + def test_deployment_extract_with_empty_initcontainers(self): + """Test that Deployment extract lambda works when init_containers is empty list""" + main_container = create_mock_container("main-app") + + deployment = create_mock_deployment( + "test-deployment", "default", containers=[main_container], init_containers=[] + ) + + result = extract_deployment_containers(deployment) + assert len(result) == 1 + assert result[0].name == "main-app" + + def test_deployment_extract_with_multiple_initcontainers(self): + """Test Deployment extract lambda with multiple init_containers""" + main_container = create_mock_container("main-app") + init_container1 = create_mock_container("init-db") + init_container2 = create_mock_container("init-cache") + init_container3 = create_mock_container("init-config") + + deployment = create_mock_deployment( + "test-deployment", + "default", + containers=[main_container], + init_containers=[init_container1, init_container2, init_container3], + ) + + result = extract_deployment_containers(deployment) + assert len(result) == 4 + container_names = {c.name for c in result} + assert container_names == {"main-app", "init-db", "init-cache", "init-config"} + + +class TestCronJobInitContainers: + """Test initContainer extraction for CronJobs (different spec path)""" + + def test_cronjob_extract_with_initcontainers(self): + """Test that CronJob extract lambda includes init_containers from job_template path""" + main_container = create_mock_container("cron-task") + init_container = create_mock_container("init-setup") + + cronjob = create_mock_cronjob( + "test-cronjob", "default", containers=[main_container], init_containers=[init_container] + ) + + result = extract_cronjob_containers(cronjob) + assert len(result) == 2 + container_names = {c.name for c in result} + assert "cron-task" in container_names + assert "init-setup" in container_names + + def test_cronjob_extract_without_initcontainers(self): + """Test 
CronJob extract lambda when init_containers is None""" + main_container = create_mock_container("cron-task") + + cronjob = create_mock_cronjob("test-cronjob", "default", containers=[main_container], init_containers=None) + + result = extract_cronjob_containers(cronjob) + assert len(result) == 1 + assert result[0].name == "cron-task" + + +class TestJobInitContainers: + """Test initContainer extraction for Jobs (uses inline loop)""" + + @pytest.mark.asyncio + async def test_job_with_initcontainers(self, mock_cluster_loader, mock_config): + """Test that Job extracts both containers and init_containers via inline loop""" + main_container = create_mock_container("job-worker") + init_container = create_mock_container("init-data") + + job = create_mock_job("test-job", "default", containers=[main_container], init_containers=[init_container]) + + # Mock the batched method + async def mock_batched_method(*args, **kwargs): + return ([job], None) + + mock_cluster_loader._list_namespaced_or_global_objects_batched = mock_batched_method + + # Mock _is_job_owned_by_cronjob and _is_job_grouped + mock_cluster_loader._is_job_owned_by_cronjob = MagicMock(return_value=False) + mock_cluster_loader._is_job_grouped = MagicMock(return_value=False) + + with patch("robusta_krr.core.integrations.kubernetes.settings", mock_config): + result = await mock_cluster_loader._list_all_jobs() + + # Should find 2 containers (1 main + 1 init) + assert len(result) == 2 + container_names = {r.container for r in result} + assert "job-worker" in container_names + assert "init-data" in container_names + + @pytest.mark.asyncio + async def test_job_without_initcontainers(self, mock_cluster_loader, mock_config): + """Test that Job works when init_containers is None""" + main_container = create_mock_container("job-worker") + + job = create_mock_job("test-job", "default", containers=[main_container], init_containers=None) + + async def mock_batched_method(*args, **kwargs): + return ([job], None) + + 
mock_cluster_loader._list_namespaced_or_global_objects_batched = mock_batched_method + mock_cluster_loader._is_job_owned_by_cronjob = MagicMock(return_value=False) + mock_cluster_loader._is_job_grouped = MagicMock(return_value=False) + + with patch("robusta_krr.core.integrations.kubernetes.settings", mock_config): + result = await mock_cluster_loader._list_all_jobs() + + assert len(result) == 1 + assert result[0].container == "job-worker" + + +class TestGroupedJobInitContainers: + """Test initContainer extraction for GroupedJobs""" + + @pytest.mark.asyncio + async def test_groupedjob_with_initcontainers(self, mock_cluster_loader, mock_config): + """Test that GroupedJob extracts init_containers""" + mock_config.job_grouping_labels = ["app"] + + main_container = create_mock_container("worker") + init_container = create_mock_container("init-setup") + + job = create_mock_job("test-job-1", "default", containers=[main_container], init_containers=[init_container]) + job.metadata.labels = {"app": "test-app"} + + async def mock_batched_method(*args, **kwargs): + return ([job], None) + + mock_cluster_loader._list_namespaced_or_global_objects_batched = mock_batched_method + mock_cluster_loader._is_job_owned_by_cronjob = MagicMock(return_value=False) + mock_cluster_loader._is_job_grouped = MagicMock(return_value=True) + + with patch("robusta_krr.core.integrations.kubernetes.settings", mock_config): + result = await mock_cluster_loader._list_all_groupedjobs() + + # Should find 2 GroupedJob objects (one per container name) + assert len(result) == 2 + container_names = {r.container for r in result} + assert "worker" in container_names + assert "init-setup" in container_names + + +class TestExtractContainersLambda: + """Unit tests for the extract_containers lambda patterns""" + + def test_standard_pattern_with_initcontainers(self): + """Test the standard lambda pattern includes init_containers""" + mock_item = MagicMock() + mock_item.spec.template.spec.containers = 
[create_mock_container("main")] + mock_item.spec.template.spec.init_containers = [create_mock_container("init")] + + result = extract_deployment_containers(mock_item) + assert len(result) == 2 + + def test_standard_pattern_with_none_initcontainers(self): + """Test the standard lambda pattern handles None init_containers""" + mock_item = MagicMock() + mock_item.spec.template.spec.containers = [create_mock_container("main")] + mock_item.spec.template.spec.init_containers = None + + result = extract_deployment_containers(mock_item) + assert len(result) == 1 + + def test_cronjob_pattern_with_initcontainers(self): + """Test the CronJob lambda pattern includes init_containers""" + mock_item = MagicMock() + mock_item.spec.job_template.spec.template.spec.containers = [create_mock_container("cron-main")] + mock_item.spec.job_template.spec.template.spec.init_containers = [create_mock_container("cron-init")] + + result = extract_cronjob_containers(mock_item) + assert len(result) == 2 + + def test_strimzipodset_pattern_with_initcontainers(self): + """Test the StrimziPodSet lambda pattern includes init_containers""" + mock_item = MagicMock() + mock_item.spec.pods = [MagicMock()] + mock_item.spec.pods[0].spec.containers = [create_mock_container("kafka")] + mock_item.spec.pods[0].spec.init_containers = [create_mock_container("init-kafka")] + + result = extract_strimzipodset_containers(mock_item) + assert len(result) == 2 + + +class TestGetInitContainersHelper: + """Tests for the _get_init_containers helper function that handles snake_case/camelCase""" + + def test_snake_case_init_containers(self): + """Test that snake_case init_containers is found (standard K8s Python client)""" + spec = MagicMock() + spec.init_containers = [create_mock_container("init-db")] + + result = _get_init_containers(spec) + assert len(result) == 1 + assert result[0].name == "init-db" + + def test_camel_case_initContainers(self): + """Test that camelCase initContainers is found (CustomObjectsApi)""" + # 
Simulate ObjectLikeDict behavior where snake_case returns None + spec = MagicMock() + spec.init_containers = None # snake_case not present + spec.initContainers = [create_mock_container("init-setup")] + + result = _get_init_containers(spec) + assert len(result) == 1 + assert result[0].name == "init-setup" + + def test_neither_present_returns_empty(self): + """Test that empty list is returned when neither attribute exists""" + spec = MagicMock() + spec.init_containers = None + spec.initContainers = None + + result = _get_init_containers(spec) + assert result == [] + + def test_snake_case_takes_precedence(self): + """Test that snake_case is checked first when both exist""" + spec = MagicMock() + spec.init_containers = [create_mock_container("snake-init")] + spec.initContainers = [create_mock_container("camel-init")] + + result = _get_init_containers(spec) + assert len(result) == 1 + assert result[0].name == "snake-init" + + def test_empty_snake_case_returns_empty(self): + """Test that empty list from snake_case is returned (not fallback to camelCase)""" + spec = MagicMock() + spec.init_containers = [] # Empty but not None + spec.initContainers = [create_mock_container("camel-init")] + + result = _get_init_containers(spec) + assert result == [] + + +class TestCustomObjectsApiCamelCase: + """Tests for CustomObjectsApi workloads that use camelCase (initContainers)""" + + def test_rollout_with_camelcase_initcontainers(self): + """Test Rollout extraction with camelCase initContainers (CustomObjectsApi)""" + # Simulate ObjectLikeDict from CustomObjectsApi - uses camelCase + rollout_dict = { + "metadata": {"name": "test-rollout", "namespace": "default"}, + "spec": { + "template": { + "spec": { + "containers": [{"name": "main-app", "resources": {}}], + "initContainers": [{"name": "init-setup", "resources": {}}], + } + } + }, + } + item = ObjectLikeDict(rollout_dict) + + # Use the helper function as the implementation does + result = item.spec.template.spec.containers + 
_get_init_containers(item.spec.template.spec) + assert len(result) == 2 + container_names = {c.name if hasattr(c, "name") else c["name"] for c in result} + assert "main-app" in container_names + assert "init-setup" in container_names + + def test_rollout_without_initcontainers(self): + """Test Rollout extraction when initContainers is not present""" + rollout_dict = { + "metadata": {"name": "test-rollout", "namespace": "default"}, + "spec": { + "template": { + "spec": { + "containers": [{"name": "main-app", "resources": {}}], + } + } + }, + } + item = ObjectLikeDict(rollout_dict) + + result = item.spec.template.spec.containers + _get_init_containers(item.spec.template.spec) + assert len(result) == 1 + + def test_strimzipodset_with_camelcase_initcontainers(self): + """Test StrimziPodSet extraction with camelCase initContainers""" + strimzi_dict = { + "metadata": {"name": "kafka-cluster", "namespace": "kafka"}, + "spec": { + "pods": [ + { + "spec": { + "containers": [{"name": "kafka", "resources": {}}], + "initContainers": [{"name": "init-kafka", "resources": {}}], + } + } + ] + }, + } + item = ObjectLikeDict(strimzi_dict) + + result = item.spec.pods[0].spec.containers + _get_init_containers(item.spec.pods[0].spec) + assert len(result) == 2 + container_names = {c.name if hasattr(c, "name") else c["name"] for c in result} + assert "kafka" in container_names + assert "init-kafka" in container_names + + def test_deploymentconfig_with_camelcase_initcontainers(self): + """Test DeploymentConfig extraction with camelCase initContainers""" + dc_dict = { + "metadata": {"name": "myapp-dc", "namespace": "openshift-project"}, + "spec": { + "template": { + "spec": { + "containers": [{"name": "app", "resources": {}}], + "initContainers": [{"name": "init-config", "resources": {}}], + } + } + }, + } + item = ObjectLikeDict(dc_dict) + + result = item.spec.template.spec.containers + _get_init_containers(item.spec.template.spec) + assert len(result) == 2 + container_names = {c.name if 
hasattr(c, "name") else c["name"] for c in result} + assert "app" in container_names + assert "init-config" in container_names + + def test_objectlikedict_returns_none_for_missing_attr(self): + """Verify ObjectLikeDict returns None for missing attributes (the root cause)""" + obj_dict = {"containers": [{"name": "main"}]} + obj = ObjectLikeDict(obj_dict) + + # This is the behavior that caused the original bug + assert obj.init_containers is None # snake_case - not in dict + assert obj.initContainers is None # camelCase - also not in dict + assert obj.containers is not None # This key exists From ccb960734cc39fe487b1df2232dc20138706ebe8 Mon Sep 17 00:00:00 2001 From: sudhakartag1 Date: Tue, 16 Dec 2025 14:54:26 +1100 Subject: [PATCH 2/2] fix: use PEP 604 optional type hints --- tests/test_initcontainers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_initcontainers.py b/tests/test_initcontainers.py index 826c6e48..ed391ae9 100644 --- a/tests/test_initcontainers.py +++ b/tests/test_initcontainers.py @@ -62,7 +62,7 @@ def create_mock_container(name: str): return container -def create_mock_deployment(name: str, namespace: str, containers: list, init_containers: list = None): +def create_mock_deployment(name: str, namespace: str, containers: list, init_containers: list | None = None): """Create a mock Deployment with containers and optional init_containers""" deployment = MagicMock() deployment.metadata.name = name @@ -77,7 +77,7 @@ def create_mock_deployment(name: str, namespace: str, containers: list, init_con return deployment -def create_mock_job(name: str, namespace: str, containers: list, init_containers: list = None): +def create_mock_job(name: str, namespace: str, containers: list, init_containers: list | None = None): """Create a mock Job with containers and optional init_containers""" job = MagicMock() job.metadata.name = name @@ -90,7 +90,7 @@ def create_mock_job(name: str, namespace: str, containers: list, init_containers
return job -def create_mock_cronjob(name: str, namespace: str, containers: list, init_containers: list = None): +def create_mock_cronjob(name: str, namespace: str, containers: list, init_containers: list | None = None): """Create a mock CronJob with containers and optional init_containers""" cronjob = MagicMock() cronjob.metadata.name = name