Skip to content

Commit

Permalink
Migrate EventBranchCreate to InfrahubEvent
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasG0 committed Nov 13, 2024
1 parent 39c3833 commit 267ed1f
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 228 deletions.
32 changes: 24 additions & 8 deletions backend/infrahub/core/branch/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@
from infrahub.core.validators.models.validate_migration import SchemaValidateMigrationData
from infrahub.core.validators.tasks import schema_validate_migrations
from infrahub.dependencies.registry import get_component_registry
from infrahub.events.branch_action import BranchDeleteEvent
from infrahub.events.branch_action import BranchCreateEvent, BranchDeleteEvent
from infrahub.exceptions import BranchNotFoundError, MergeFailedError, ValidationError
from infrahub.graphql.mutations.models import BranchCreateModel # noqa: TCH001
from infrahub.log import get_log_data
from infrahub.message_bus import Meta, messages
from infrahub.services import services
from infrahub.worker import WORKER_IDENTITY
from infrahub.workflows.catalogue import BRANCH_CANCEL_PROPOSED_CHANGES, IPAM_RECONCILIATION
from infrahub.workflows.catalogue import (
BRANCH_CANCEL_PROPOSED_CHANGES,
GIT_REPOSITORIES_CREATE_BRANCH,
IPAM_RECONCILIATION,
)
from infrahub.workflows.utils import add_branch_tag


Expand Down Expand Up @@ -257,6 +261,11 @@ async def validate_branch(branch: str) -> State:
@flow(name="create-branch", flow_run_name="Create branch {model.name}")
async def create_branch(model: BranchCreateModel) -> None:
service = services.service
print("create_branch: printing with print")

print(f"{id(service)=}")
print(f"{id(service.event)=}")
print(f"{id(service.event._service)=}")

await add_branch_tag(model.name)

Expand Down Expand Up @@ -287,9 +296,16 @@ async def create_branch(model: BranchCreateModel) -> None:
# Add Branch to registry
registry.branch[obj.name] = obj

message = messages.EventBranchCreate(
branch=obj.name,
branch_id=str(obj.id),
sync_with_git=obj.sync_with_git,
)
await service.send(message=message)
print("before sending event")

event = BranchCreateEvent(branch=obj.name, branch_id=str(obj.id), sync_with_git=obj.sync_with_git)
print(f"{id(service)=}")
print(f"{id(service.message_bus)=}")
await service.event.send(event=event)

if obj.sync_with_git:
print("before sending GIT_REPOSITORIES_CREATE_BRANCH")
await service.workflow.submit_workflow(
workflow=GIT_REPOSITORIES_CREATE_BRANCH,
parameters={"branch": obj.name, "branch_id": str(obj.id)},
)
30 changes: 30 additions & 0 deletions backend/infrahub/events/branch_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def get_resource(self) -> dict[str, str]:

def get_messages(self) -> list[InfrahubMessage]:
events = [
# TODO: Sending EventBranchDelete currently has no effect.
# We should either consider handle it or remove it.
EventBranchDelete(
branch=self.branch,
branch_id=self.branch_id,
Expand All @@ -33,3 +35,31 @@ def get_messages(self) -> list[InfrahubMessage]:
RefreshRegistryBranches(),
]
return events


class BranchCreateEvent(InfrahubBranchEvent):
"""Event generated when a branch has been created"""

branch_id: str = Field(..., description="The ID of the mutated node")
sync_with_git: bool = Field(..., description="Indicates if the branch was extended to Git")

def get_name(self) -> str:
return f"{self.get_event_namespace()}.branch.created"

def get_resource(self) -> dict[str, str]:
return {
"prefect.resource.id": f"infrahub.branch.{self.branch}",
"infrahub.branch.id": self.branch_id,
}

def get_messages(self) -> list[InfrahubMessage]:
events = [
# EventBranchCreate(
# branch=self.branch,
# branch_id=self.branch_id,
# sync_with_git=self.sync_with_git,
# meta=self.get_message_meta(),
# ),
RefreshRegistryBranches(),
]
return events # type: ignore
8 changes: 7 additions & 1 deletion backend/infrahub/git/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def add_git_repository_read_only(model: GitRepositoryAddReadOnly) -> None:
@flow(name="git_repositories_create_branch")
async def create_branch(branch: str, branch_id: str) -> None:
"""Request to the creation of git branches in available repositories."""
print("start git_repositories_create_branch")
service = services.service
await add_branch_tag(branch_name=branch)

Expand All @@ -118,6 +119,7 @@ async def create_branch(branch: str, branch_id: str) -> None:
batch = await service.client.create_batch()

for repository in repositories:
print(f"adding a batch for repo {repository.name.value=}")
batch.add(
task=git_branch_create,
client=service.client.client,
Expand Down Expand Up @@ -214,10 +216,13 @@ async def git_branch_create(
client: InfrahubClient, branch: str, branch_id: str, repository_id: str, repository_name: str
) -> None:
service = services.service

print(f"in git_branch_create and {repository_name=}")
repo = await InfrahubRepository.init(id=repository_id, name=repository_name, client=client)
print(" after InfrahubRepository.init")
async with lock.registry.get(name=repository_name, namespace="repository"):
print("after lock.registry.get")
await repo.create_branch_in_git(branch_name=branch, branch_id=branch_id)
print("after repo.create_branch_in_git")
if repo.location:
# New branch has been pushed remotely, tell workers to fetch it
message = messages.RefreshGitFetch(
Expand All @@ -228,6 +233,7 @@ async def git_branch_create(
repository_kind=InfrahubKind.REPOSITORY,
infrahub_branch_name=branch,
)
print("before service send")
await service.send(message=message)


Expand Down
38 changes: 18 additions & 20 deletions backend/infrahub/message_bus/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ujson
from prefect import Flow

from infrahub.message_bus import RPCErrorResponse, messages
from infrahub.message_bus import messages
from infrahub.message_bus.operations import (
check,
event,
Expand All @@ -17,15 +17,13 @@
)
from infrahub.message_bus.types import MessageTTL
from infrahub.services import InfrahubServices
from infrahub.tasks.check import set_check_status

COMMAND_MAP = {
"check.artifact.create": check.artifact.create,
"check.generator.run": check.generator.run,
"check.repository.check_definition": check.repository.check_definition,
"check.repository.merge_conflicts": check.repository.merge_conflicts,
"check.repository.user_check": check.repository.user_check,
"event.branch.create": event.branch.create,
"event.branch.merge": event.branch.merge,
"event.branch.rebased": event.branch.rebased,
"event.node.mutated": event.node.mutated,
Expand Down Expand Up @@ -63,20 +61,20 @@ async def execute_message(
message_data = ujson.loads(message_body)
message = messages.MESSAGE_MAP[routing_key](**message_data)
message.set_log_data(routing_key=routing_key)
try:
func = COMMAND_MAP[routing_key]
if skip_flow and isinstance(func, Flow):
func = func.fn
await func(message=message, service=service)
except Exception as exc: # pylint: disable=broad-except
if message.reply_requested:
response = RPCErrorResponse(errors=[str(exc)], initial_message=message.model_dump())
await service.reply(message=response, initiator=message)
return None
if message.reached_max_retries:
service.log.exception("Message failed after maximum number of retries", error=exc)
await set_check_status(message, conclusion="failure", service=service)
return None
message.increase_retry_count()
await service.send(message, delay=MessageTTL.FIVE, is_retry=True)
return MessageTTL.FIVE
# try:
func = COMMAND_MAP[routing_key]
if skip_flow and isinstance(func, Flow):
func = func.fn
await func(message=message, service=service)
# except Exception as exc: # pylint: disable=broad-except
# if message.reply_requested:
# response = RPCErrorResponse(errors=[str(exc)], initial_message=message.model_dump())
# await service.reply(message=response, initiator=message)
# return None
# if message.reached_max_retries:
# service.log.exception("Message failed after maximum number of retries", error=exc)
# await set_check_status(message, conclusion="failure", service=service)
# return None
# message.increase_retry_count()
# await service.send(message, delay=MessageTTL.FIVE, is_retry=True)
# return MessageTTL.FIVE
17 changes: 0 additions & 17 deletions backend/infrahub/message_bus/operations/event/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from infrahub.message_bus import InfrahubMessage, messages
from infrahub.services import InfrahubServices
from infrahub.workflows.catalogue import (
GIT_REPOSITORIES_CREATE_BRANCH,
REQUEST_DIFF_REFRESH,
REQUEST_DIFF_UPDATE,
TRIGGER_ARTIFACT_DEFINITION_GENERATE,
Expand All @@ -21,22 +20,6 @@
log = get_logger()


@flow(name="event-branch-create")
async def create(message: messages.EventBranchCreate, service: InfrahubServices) -> None:
log.info("run_message", branch=message.branch)

events: List[InfrahubMessage] = [messages.RefreshRegistryBranches()]
if message.sync_with_git:
await service.workflow.submit_workflow(
workflow=GIT_REPOSITORIES_CREATE_BRANCH,
parameters={"branch": message.branch, "branch_id": message.branch_id},
)

for event in events:
event.assign_meta(parent=message)
await service.send(message=event)


@flow(name="branch-event-merge")
async def merge(message: messages.EventBranchMerge, service: InfrahubServices) -> None:
log.info("Branch merged", source_branch=message.source_branch, target_branch=message.target_branch)
Expand Down
6 changes: 6 additions & 0 deletions backend/infrahub/message_bus/operations/refresh/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@


async def branches(message: messages.RefreshRegistryBranches, service: InfrahubServices) -> None:
print("start branches")
if message.meta and message.meta.initiator_id == WORKER_IDENTITY:
service.log.info("Ignoring refresh registry refresh request originating from self", worker=WORKER_IDENTITY)
return

print("before service.database")

async with service.database.start_session() as db:
await refresh_branches(db=db)

print("before service.component")
await service.component.refresh_schema_hash()

print("ending branches")


async def rebased_branch(message: messages.RefreshRegistryRebasedBranch, service: InfrahubServices) -> None:
if message.meta and message.meta.initiator_id == WORKER_IDENTITY:
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
@property
def client(self) -> InfrahubClient:
if not self._client:
raise InitializationError()
raise InitializationError("Service is not initialized with a client")

return self._client

Expand Down
4 changes: 3 additions & 1 deletion backend/infrahub/services/adapters/event/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ async def send(self, event: InfrahubEvent) -> None:
await asyncio.gather(*tasks)

async def _send_bus(self, event: InfrahubEvent) -> None:
for message in event.get_messages():
print(self.service.message_bus)
for i, message in enumerate(event.get_messages()):
print(f"here {i=} and {message=}")
await self.service.send(message=message)

async def _send_prefect(self, event: InfrahubEvent) -> None:
Expand Down
4 changes: 2 additions & 2 deletions backend/infrahub/services/adapters/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ async def execute_workflow(
parameters: dict[str, Any] | None = None,
tags: list[str] | None = None,
) -> Any:
raise NotImplementedError()
raise NotImplementedError("InfrahubWorkflow.execute_workflow is an abstract method")

async def submit_workflow(
self,
workflow: WorkflowDefinition,
parameters: dict[str, Any] | None = None,
tags: list[str] | None = None,
) -> WorkflowInfo:
raise NotImplementedError()
raise NotImplementedError("InfrahubWorkflow.submit_workflow is an abstract method")
4 changes: 4 additions & 0 deletions backend/tests/adapters/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class BusSimulator(InfrahubMessageBus):
def __init__(self, database: InfrahubDatabase | None = None, workflow: InfrahubWorkflow | None = None) -> None:
self.messages: list[InfrahubMessage] = []
self.messages_per_routing_key: dict[str, list[InfrahubMessage]] = {}

self.service: InfrahubServices = InfrahubServices(database=database, message_bus=self, workflow=workflow)
self.replies: dict[str, list[InfrahubMessage]] = defaultdict(list)
build_component_registry()
Expand Down Expand Up @@ -73,3 +74,6 @@ async def rpc(self, message: InfrahubMessage, response_class: type[ResponseClass
@property
def seen_routing_keys(self) -> list[str]:
return list(self.messages_per_routing_key.keys())

async def initialize(self, service: InfrahubServices) -> None:
self.service = service
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
from uuid import uuid4

import httpx
import pytest
import ujson
from aio_pika import Message

from infrahub import config
from infrahub.components import ComponentType
from infrahub.message_bus import messages
from infrahub.message_bus.messages.send_echo_request import SendEchoRequestResponse
from infrahub.message_bus.operations import execute_message
from infrahub.message_bus.types import MessageTTL
from infrahub.services import InfrahubServices
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.worker import WORKER_IDENTITY
Expand Down Expand Up @@ -338,41 +335,6 @@ async def test_rabbitmq_initial_setup(rabbitmq_api: RabbitMQManager) -> None:
)


async def test_rabbitmq_publish(rabbitmq_api: RabbitMQManager) -> None:
"""Validate that the adapter publishes messages to the correct queue"""

bus = RabbitMQMessageBus(settings=rabbitmq_api.settings)
service = InfrahubServices(message_bus=bus, component_type=ComponentType.API_SERVER)

normal_message = messages.EventBranchCreate(branch="normal", branch_id=str(uuid4()), sync_with_git=False)
delayed_message = messages.EventBranchCreate(branch="delayed", branch_id=str(uuid4()), sync_with_git=False)

await bus.initialize(service=service)
await service.send(message=normal_message)
await service.send(message=delayed_message, delay=MessageTTL.FIVE)

queue = await bus.channel.get_queue(name=f"{bus.settings.namespace}.rpcs")
delayed_queue = await bus.channel.get_queue(name=f"{bus.settings.namespace}.delay.five_seconds")
message_from_queue = await queue.get()
delayed_message_from_queue = await delayed_queue.get()
parsed_message = ujson.loads(message_from_queue.body)
parsed_delayed_message = ujson.loads(delayed_message_from_queue.body)

await bus.shutdown()

parsed_message = messages.EventBranchCreate(**parsed_message)
parsed_delayed_message = messages.EventBranchCreate(**parsed_delayed_message)

# The priority isn't currently included in the header, reset it to show expected priority
normal_message.meta.priority = 3
delayed_message.meta.priority = 3
parsed_delayed_message.meta.headers = {"delay": 5000}
assert message_from_queue.priority == 5
assert delayed_message_from_queue.priority == 5
assert parsed_message == normal_message
assert parsed_delayed_message == delayed_message


async def test_rabbitmq_callback(rabbitmq_api: RabbitMQManager, fake_log: FakeLogger) -> None:
"""Validates that incoming messages gets parsed by the callback method."""

Expand Down
Loading

0 comments on commit 267ed1f

Please sign in to comment.