
Refactor k8s workloads streaming (#256)
* Refactor k8s workloads streaming

* Fix tests
LeaveMyYard authored Apr 4, 2024
1 parent d0e531b commit 76ed553
Showing 4 changed files with 51 additions and 106 deletions.
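In short, the commit replaces the async-generator streaming pipeline (merged via `async_gen_merge`) with plain coroutines that return lists and are joined with `asyncio.gather`. A minimal sketch of the before/after pattern, using hypothetical `fetch_deployments`/`fetch_jobs` stand-ins rather than the real loader methods:

```python
import asyncio

# Hypothetical stand-ins for the loader's _list_deployments(), _list_all_jobs(), etc.
async def fetch_deployments() -> list[str]:
    await asyncio.sleep(0.1)
    return ["default/web"]

async def fetch_jobs() -> list[str]:
    await asyncio.sleep(0.1)
    return ["default/backup"]

async def list_workloads() -> list[str]:
    # After this commit: run every lister concurrently, then flatten the lists,
    # instead of merging async generators into a single stream.
    workload_lists = await asyncio.gather(fetch_deployments(), fetch_jobs())
    return [obj for sub in workload_lists for obj in sub]

print(asyncio.run(list_workloads()))  # ['default/web', 'default/backup']
```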
97 changes: 47 additions & 50 deletions robusta_krr/core/integrations/kubernetes/__init__.py
@@ -2,7 +2,7 @@
 import logging
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, Iterable, Optional, Union
+from typing import Any, Awaitable, Callable, Iterable, Optional, Union
 
 from kubernetes import client, config  # type: ignore
 from kubernetes.client import ApiException
@@ -20,7 +20,6 @@
 from robusta_krr.core.models.config import settings
 from robusta_krr.core.models.objects import HPAData, K8sObjectData, KindLiteral, PodData
 from robusta_krr.core.models.result import ResourceAllocations
-from robusta_krr.utils.async_gen_merge import async_gen_merge
 from robusta_krr.utils.object_like_dict import ObjectLikeDict
 
 from . import config_patch as _
@@ -49,7 +48,7 @@ def __init__(self, cluster: Optional[str]=None):
         self.__jobs_for_cronjobs: dict[str, list[V1Job]] = {}
         self.__jobs_loading_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
 
-    async def list_scannable_objects(self) -> AsyncGenerator[K8sObjectData, None]:
+    async def list_scannable_objects(self) -> list[K8sObjectData]:
         """List all scannable objects.
 
         Returns:
@@ -61,22 +60,23 @@ async def list_scannable_objects(self) -> AsyncGenerator[K8sObjectData, None]:
         logger.debug(f"Resources: {settings.resources}")
 
         self.__hpa_list = await self._try_list_hpa()
-
-        # https://stackoverflow.com/questions/55299564/join-multiple-async-generators-in-python
-        # This will merge all the streams from all the cluster loaders into a single stream
-        async for object in async_gen_merge(
+        workload_object_lists = await asyncio.gather(
             self._list_deployments(),
             self._list_rollouts(),
             self._list_deploymentconfig(),
             self._list_all_statefulsets(),
             self._list_all_daemon_set(),
             self._list_all_jobs(),
             self._list_all_cronjobs(),
-        ):
+        )
+
+        return [
+            object
+            for workload_objects in workload_object_lists
+            for object in workload_objects
             # NOTE: By default we will filter out kube-system namespace
-            if settings.namespaces == "*" and object.namespace == "kube-system":
-                continue
-            yield object
+            if not (settings.namespaces == "*" and object.namespace == "kube-system")
+        ]
 
     async def _list_jobs_for_cronjobs(self, namespace: str) -> list[V1Job]:
         if namespace not in self.__jobs_for_cronjobs:
@@ -185,12 +185,12 @@ async def _list_namespaced_or_global_objects(
         kind: KindLiteral,
         all_namespaces_request: Callable,
         namespaced_request: Callable
-    ) -> AsyncIterable[Any]:
+    ) -> list[Any]:
         logger.debug(f"Listing {kind}s in {self.cluster}")
         loop = asyncio.get_running_loop()
 
         if settings.namespaces == "*":
-            tasks = [
+            requests = [
                 loop.run_in_executor(
                     self.executor,
                     lambda: all_namespaces_request(
@@ -200,7 +200,7 @@
                 )
             ]
         else:
-            tasks = [
+            requests = [
                 loop.run_in_executor(
                     self.executor,
                     lambda ns=namespace: namespaced_request(
@@ -212,14 +212,14 @@
                 for namespace in settings.namespaces
             ]
 
-        total_items = 0
-        for task in asyncio.as_completed(tasks):
-            ret_single = await task
-            total_items += len(ret_single.items)
-            for item in ret_single.items:
-                yield item
+        result = [
+            item
+            for request_result in await asyncio.gather(*requests)
+            for item in request_result.items
+        ]
 
-        logger.debug(f"Found {total_items} {kind} in {self.cluster}")
+        logger.debug(f"Found {len(result)} {kind} in {self.cluster}")
+        return result
 
     async def _list_scannable_objects(
         self,
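The kubernetes Python client is synchronous, so `_list_namespaced_or_global_objects` keeps pushing each request onto a thread pool and now simply gathers the resulting futures. A self-contained sketch of that executor-plus-gather pattern, with a hypothetical `blocking_list` standing in for the client call:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def blocking_list(namespace: str) -> list[str]:
    # Stand-in for a synchronous kubernetes client call such as list_namespaced_deployment.
    return [f"{namespace}/pod-1", f"{namespace}/pod-2"]

async def list_in_namespaces(namespaces: list[str]) -> list[str]:
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=4) as executor:
        # ns=ns binds the loop variable at definition time,
        # just like the diff's `lambda ns=namespace:` trick.
        requests = [
            loop.run_in_executor(executor, lambda ns=ns: blocking_list(ns))
            for ns in namespaces
        ]
        # Run all blocking calls concurrently, then flatten the per-namespace lists.
        return [item for result in await asyncio.gather(*requests) for item in result]

print(asyncio.run(list_in_namespaces(["default", "kube-system"])))
```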
@@ -228,25 +228,25 @@ async def _list_scannable_objects(
         namespaced_request: Callable,
         extract_containers: Callable[[Any], Union[Iterable[V1Container], Awaitable[Iterable[V1Container]]]],
         filter_workflows: Optional[Callable[[Any], bool]] = None,
-    ) -> AsyncIterable[K8sObjectData]:
+    ) -> list[K8sObjectData]:
         if not self._should_list_resource(kind):
             logger.debug(f"Skipping {kind}s in {self.cluster}")
             return
 
         if not self.__kind_available[kind]:
             return
 
+        result = []
         try:
-            async for item in self._list_namespaced_or_global_objects(kind, all_namespaces_request, namespaced_request):
+            for item in await self._list_namespaced_or_global_objects(kind, all_namespaces_request, namespaced_request):
                 if filter_workflows is not None and not filter_workflows(item):
                     continue
 
                 containers = extract_containers(item)
                 if asyncio.iscoroutine(containers):
                     containers = await containers
 
-                for container in containers:
-                    yield self.__build_scannable_object(item, container, kind)
+                result.extend(self.__build_scannable_object(item, container, kind) for container in containers)
         except ApiException as e:
             if kind in ("Rollout", "DeploymentConfig") and e.status in [400, 401, 403, 404]:
                 if self.__kind_available[kind]:
@@ -256,15 +256,17 @@
                 logger.exception(f"Error {e.status} listing {kind} in cluster {self.cluster}: {e.reason}")
                 logger.error("Will skip this object type and continue.")
 
-    def _list_deployments(self) -> AsyncIterable[K8sObjectData]:
+        return result
+
+    def _list_deployments(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="Deployment",
             all_namespaces_request=self.apps.list_deployment_for_all_namespaces,
             namespaced_request=self.apps.list_namespaced_deployment,
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_rollouts(self) -> AsyncIterable[K8sObjectData]:
+    def _list_rollouts(self) -> list[K8sObjectData]:
         async def _extract_containers(item: Any) -> list[V1Container]:
             if item.spec.template is not None:
                 return item.spec.template.spec.containers
@@ -311,7 +313,7 @@ async def _extract_containers(item: Any) -> list[V1Container]:
             extract_containers=_extract_containers,
         )
 
-    def _list_deploymentconfig(self) -> AsyncIterable[K8sObjectData]:
+    def _list_deploymentconfig(self) -> list[K8sObjectData]:
         # NOTE: Using custom objects API returns dicts, but all other APIs return objects
         # We need to handle this difference using a small wrapper
         return self._list_scannable_objects(
@@ -335,23 +337,23 @@ def _list_deploymentconfig(self) -> list[K8sObjectData]:
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_all_statefulsets(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_statefulsets(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="StatefulSet",
             all_namespaces_request=self.apps.list_stateful_set_for_all_namespaces,
             namespaced_request=self.apps.list_namespaced_stateful_set,
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_all_daemon_set(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_daemon_set(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="DaemonSet",
             all_namespaces_request=self.apps.list_daemon_set_for_all_namespaces,
             namespaced_request=self.apps.list_namespaced_daemon_set,
             extract_containers=lambda item: item.spec.template.spec.containers,
         )
 
-    def _list_all_jobs(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_jobs(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="Job",
             all_namespaces_request=self.batch.list_job_for_all_namespaces,
@@ -363,7 +365,7 @@ def _list_all_jobs(self) -> AsyncIterable[K8sObjectData]:
             ),
         )
 
-    def _list_all_cronjobs(self) -> AsyncIterable[K8sObjectData]:
+    def _list_all_cronjobs(self) -> list[K8sObjectData]:
         return self._list_scannable_objects(
             kind="CronJob",
             all_namespaces_request=self.batch.list_cron_job_for_all_namespaces,
@@ -398,14 +400,10 @@ async def __list_hpa_v1(self) -> dict[HPAKey, HPAData]:
         }
 
     async def __list_hpa_v2(self) -> dict[HPAKey, HPAData]:
-        loop = asyncio.get_running_loop()
-        res = await loop.run_in_executor(
-            self.executor,
-            lambda: self._list_namespaced_or_global_objects(
-                kind="HPA-v2",
-                all_namespaces_request=self.autoscaling_v2.list_horizontal_pod_autoscaler_for_all_namespaces,
-                namespaced_request=self.autoscaling_v2.list_namespaced_horizontal_pod_autoscaler,
-            ),
+        res = await self._list_namespaced_or_global_objects(
+            kind="HPA-v2",
+            all_namespaces_request=self.autoscaling_v2.list_horizontal_pod_autoscaler_for_all_namespaces,
+            namespaced_request=self.autoscaling_v2.list_namespaced_horizontal_pod_autoscaler,
         )
 
         def __get_metric(hpa: V2HorizontalPodAutoscaler, metric_name: str) -> Optional[float]:
             return next(
@@ -429,7 +427,7 @@ def __get_metric(hpa: V2HorizontalPodAutoscaler, metric_name: str) -> Optional[float]:
                 target_cpu_utilization_percentage=__get_metric(hpa, "cpu"),
                 target_memory_utilization_percentage=__get_metric(hpa, "memory"),
             )
-            async for hpa in res
+            for hpa in res
         }
 
         # TODO: What should we do in case of other metrics bound to the HPA?
@@ -514,7 +512,7 @@ def _try_create_cluster_loader(self, cluster: Optional[str]) -> Optional[ClusterLoader]:
             logger.error(f"Could not load cluster {cluster} and will skip it: {e}")
             return None
 
-    async def list_scannable_objects(self, clusters: Optional[list[str]]) -> AsyncIterable[K8sObjectData]:
+    async def list_scannable_objects(self, clusters: Optional[list[str]]) -> list[K8sObjectData]:
         """List all scannable objects.
 
         Yields:
@@ -529,13 +527,12 @@ async def list_scannable_objects(self, clusters: Optional[list[str]]) -> AsyncIterable[K8sObjectData]:
         if self.cluster_loaders == {}:
             logger.error("Could not load any cluster.")
             return
 
-        # https://stackoverflow.com/questions/55299564/join-multiple-async-generators-in-python
-        # This will merge all the streams from all the cluster loaders into a single stream
-        async for object in async_gen_merge(
-            *[cluster_loader.list_scannable_objects() for cluster_loader in self.cluster_loaders.values()]
-        ):
-            yield object
+        return [
+            object
+            for cluster_loader in self.cluster_loaders.values()
+            for object in await cluster_loader.list_scannable_objects()
+        ]
 
     async def load_pods(self, object: K8sObjectData) -> list[PodData]:
         try:
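One nuance of the new multi-cluster listing above: the comprehension awaits each cluster loader one at a time, while the per-kind requests inside a single cluster still run concurrently. A sketch of the difference, with a hypothetical `load_cluster` coroutine:

```python
import asyncio

async def load_cluster(cluster: str) -> list[str]:
    await asyncio.sleep(0.1)
    return [f"{cluster}/web"]

async def sequential(clusters: list[str]) -> list[str]:
    # Mirrors the new comprehension: clusters are awaited one at a time.
    return [obj for c in clusters for obj in await load_cluster(c)]

async def concurrent(clusters: list[str]) -> list[str]:
    # Alternative that fetches all clusters at once.
    lists = await asyncio.gather(*(load_cluster(c) for c in clusters))
    return [obj for objs in lists for obj in objs]

print(asyncio.run(sequential(["a", "b"])))  # ['a/web', 'b/web'] in ~0.2 s
print(asyncio.run(concurrent(["a", "b"])))  # same result in ~0.1 s
```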
8 changes: 2 additions & 6 deletions robusta_krr/core/runner.py
@@ -275,12 +275,8 @@ async def _collect_result(self) -> Result:
         await asyncio.gather(*[self._check_data_availability(cluster) for cluster in clusters])
 
         with ProgressBar(title="Calculating Recommendation") as self.__progressbar:
-            scans_tasks = [
-                asyncio.create_task(self._gather_object_allocations(k8s_object))
-                async for k8s_object in self._k8s_loader.list_scannable_objects(clusters)
-            ]
-
-            scans = await asyncio.gather(*scans_tasks)
+            workloads = await self._k8s_loader.list_scannable_objects(clusters)
+            scans = await asyncio.gather(*[self._gather_object_allocations(k8s_object) for k8s_object in workloads])
 
             successful_scans = [scan for scan in scans if scan is not None]
 
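The runner change relies on `asyncio.gather` preserving input order, so each entry in `scans` lines up with its workload regardless of completion order. A small demonstration with a hypothetical `scan` coroutine:

```python
import asyncio

async def scan(name: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return f"scanned:{name}"

async def main() -> None:
    workloads = [("web", 0.2), ("db", 0.1)]
    scans = await asyncio.gather(*[scan(n, d) for n, d in workloads])
    print(scans)  # ['scanned:web', 'scanned:db'] — input order, not completion order

asyncio.run(main())
```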
39 changes: 0 additions & 39 deletions robusta_krr/utils/async_gen_merge.py

This file was deleted.

13 changes: 2 additions & 11 deletions tests/conftest.py
@@ -1,6 +1,6 @@
 import random
 from datetime import datetime, timedelta
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, patch
 
 import numpy as np
 import pytest
@@ -26,15 +26,6 @@
 )
 
 
-class AsyncIter:
-    def __init__(self, items):
-        self.items = items
-
-    async def __aiter__(self):
-        for item in self.items:
-            yield item
-
-
 @pytest.fixture(autouse=True, scope="session")
 def mock_list_clusters():
     with patch(
@@ -48,7 +39,7 @@ def mock_list_clusters():
 def mock_list_scannable_objects():
     with patch(
         "robusta_krr.core.integrations.kubernetes.KubernetesLoader.list_scannable_objects",
-        new=MagicMock(return_value=AsyncIter([TEST_OBJECT])),
+        new=AsyncMock(return_value=[TEST_OBJECT]),
     ):
         yield

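With the loader now a coroutine returning a list, the custom `AsyncIter` helper becomes unnecessary: `AsyncMock(return_value=[TEST_OBJECT])` already produces an awaitable whose result is the list. A minimal illustration:

```python
import asyncio
from unittest.mock import AsyncMock

loader = AsyncMock(return_value=["TEST_OBJECT"])
result = asyncio.run(loader())  # awaiting the mock resolves to the plain list
print(result)  # ['TEST_OBJECT']
```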
