Skip to content

Commit

Permalink
Migrate EventBranchCreate to InfrahubEvent
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasG0 committed Nov 13, 2024
1 parent 319c9fc commit 86faba6
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 207 deletions.
30 changes: 22 additions & 8 deletions backend/infrahub/core/branch/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import traceback
from typing import Any

import pydantic
Expand All @@ -22,13 +23,17 @@
from infrahub.core.validators.models.validate_migration import SchemaValidateMigrationData
from infrahub.core.validators.tasks import schema_validate_migrations
from infrahub.dependencies.registry import get_component_registry
from infrahub.events.branch_action import BranchDeleteEvent
from infrahub.events.branch_action import BranchCreateEvent, BranchDeleteEvent
from infrahub.exceptions import BranchNotFoundError, MergeFailedError, ValidationError
from infrahub.log import get_log_data
from infrahub.message_bus import Meta, messages
from infrahub.services import services
from infrahub.worker import WORKER_IDENTITY
from infrahub.workflows.catalogue import BRANCH_CANCEL_PROPOSED_CHANGES, IPAM_RECONCILIATION
from infrahub.workflows.catalogue import (
BRANCH_CANCEL_PROPOSED_CHANGES,
GIT_REPOSITORIES_CREATE_BRANCH,
IPAM_RECONCILIATION,
)
from infrahub.workflows.utils import add_branch_tag


Expand Down Expand Up @@ -261,6 +266,13 @@ async def validate_branch(branch: str) -> State:
@flow(name="create-branch", flow_run_name="Create branch {branch_name}")
async def create_branch(branch_name: str, data: dict[str, Any]) -> None:
service = services.service
log = get_run_logger()
log.error("create_branch: logging through run logger")
print("create_branch: printing with print")

print(f"{id(service)=}")
print(f"{id(service.event)=}")
print(f"{id(service.event._service)=}")

await add_branch_tag(branch_name)

Expand Down Expand Up @@ -291,9 +303,11 @@ async def create_branch(branch_name: str, data: dict[str, Any]) -> None:
# Add Branch to registry
registry.branch[obj.name] = obj

message = messages.EventBranchCreate(
branch=obj.name,
branch_id=str(obj.id),
sync_with_git=obj.sync_with_git,
)
await service.send(message=message)
event = BranchCreateEvent(branch=obj.name, branch_id=str(obj.id), sync_with_git=obj.sync_with_git)
await service.event.send(event=event)

if obj.sync_with_git:
await service.workflow.submit_workflow(
workflow=GIT_REPOSITORIES_CREATE_BRANCH,
parameters={"branch": obj.name, "branch_id": str(obj.id)},
)
33 changes: 33 additions & 0 deletions backend/infrahub/events/branch_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from infrahub.message_bus.messages.event_branch_delete import EventBranchDelete
from infrahub.message_bus.messages.refresh_registry_branches import RefreshRegistryBranches

from ..message_bus.messages import EventBranchCreate
from .models import InfrahubBranchEvent


Expand All @@ -24,6 +25,8 @@ def get_resource(self) -> dict[str, str]:

def get_messages(self) -> list[InfrahubMessage]:
events = [
# TODO: Sending EventBranchDelete currently has no effect.
# We should either consider handle it or remove it.
EventBranchDelete(
branch=self.branch,
branch_id=self.branch_id,
Expand All @@ -33,3 +36,33 @@ def get_messages(self) -> list[InfrahubMessage]:
RefreshRegistryBranches(),
]
return events


class BranchCreateEvent(InfrahubBranchEvent):
"""Event generated when a branch has been created"""

branch_id: str = Field(..., description="The ID of the mutated node")
sync_with_git: bool = Field(..., description="Indicates if the branch was extended to Git")

def get_name(self) -> str:
return f"{self.get_event_namespace()}.branch.created"

def get_resource(self) -> dict[str, str]:
return {
"prefect.resource.id": f"infrahub.branch.{self.branch}",
"infrahub.branch.id": self.branch_id,
}

def get_messages(self) -> list[InfrahubMessage]:
events = [
# TODO: Sending EventBranchCreate currently has no effect.
# We should either consider handle it or remove it.
EventBranchCreate(
branch=self.branch,
branch_id=self.branch_id,
sync_with_git=self.sync_with_git,
meta=self.get_message_meta(),
),
RefreshRegistryBranches(),
]
return events
7 changes: 7 additions & 0 deletions backend/infrahub/graphql/mutations/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ async def mutate(
context: GraphqlContext = info.context
task: dict | None = None

log.error("mutate: logging through logger")
print("mutate: printing with print")

print(f"{id(context.active_service)=}")
print(f"{id(context.active_service.event)=}")
print(f"{id(context.active_service.event._service)=}")

await context.active_service.workflow.execute_workflow(
workflow=BRANCH_CREATE, parameters={"branch_name": data.name, "data": dict(data)}
)
Expand Down
1 change: 0 additions & 1 deletion backend/infrahub/message_bus/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
"check.repository.check_definition": check.repository.check_definition,
"check.repository.merge_conflicts": check.repository.merge_conflicts,
"check.repository.user_check": check.repository.user_check,
"event.branch.create": event.branch.create,
"event.branch.merge": event.branch.merge,
"event.branch.rebased": event.branch.rebased,
"event.node.mutated": event.node.mutated,
Expand Down
17 changes: 0 additions & 17 deletions backend/infrahub/message_bus/operations/event/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from infrahub.message_bus import InfrahubMessage, messages
from infrahub.services import InfrahubServices
from infrahub.workflows.catalogue import (
GIT_REPOSITORIES_CREATE_BRANCH,
REQUEST_DIFF_REFRESH,
REQUEST_DIFF_UPDATE,
TRIGGER_ARTIFACT_DEFINITION_GENERATE,
Expand All @@ -21,22 +20,6 @@
log = get_logger()


@flow(name="event-branch-create")
async def create(message: messages.EventBranchCreate, service: InfrahubServices) -> None:
log.info("run_message", branch=message.branch)

events: List[InfrahubMessage] = [messages.RefreshRegistryBranches()]
if message.sync_with_git:
await service.workflow.submit_workflow(
workflow=GIT_REPOSITORIES_CREATE_BRANCH,
parameters={"branch": message.branch, "branch_id": message.branch_id},
)

for event in events:
event.assign_meta(parent=message)
await service.send(message=event)


@flow(name="branch-event-merge")
async def merge(message: messages.EventBranchMerge, service: InfrahubServices) -> None:
log.info("Branch merged", source_branch=message.source_branch, target_branch=message.target_branch)
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
@property
def client(self) -> InfrahubClient:
if not self._client:
raise InitializationError()
raise InitializationError("Service is not initialized with a client")

return self._client

Expand Down
4 changes: 2 additions & 2 deletions backend/infrahub/services/adapters/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ async def execute_workflow(
parameters: dict[str, Any] | None = None,
tags: list[str] | None = None,
) -> Any:
raise NotImplementedError()
raise NotImplementedError("InfrahubWorkflow.execute_workflow is an abstract method")

async def submit_workflow(
self,
workflow: WorkflowDefinition,
parameters: dict[str, Any] | None = None,
tags: list[str] | None = None,
) -> WorkflowInfo:
raise NotImplementedError()
raise NotImplementedError("InfrahubWorkflow.submit_workflow is an abstract method")
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
from uuid import uuid4

import httpx
import pytest
import ujson
from aio_pika import Message

from infrahub import config
from infrahub.components import ComponentType
from infrahub.message_bus import messages
from infrahub.message_bus.messages.send_echo_request import SendEchoRequestResponse
from infrahub.message_bus.operations import execute_message
from infrahub.message_bus.types import MessageTTL
from infrahub.services import InfrahubServices
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.worker import WORKER_IDENTITY
Expand Down Expand Up @@ -338,41 +335,6 @@ async def test_rabbitmq_initial_setup(rabbitmq_api: RabbitMQManager) -> None:
)


async def test_rabbitmq_publish(rabbitmq_api: RabbitMQManager) -> None:
"""Validate that the adapter publishes messages to the correct queue"""

bus = RabbitMQMessageBus(settings=rabbitmq_api.settings)
service = InfrahubServices(message_bus=bus, component_type=ComponentType.API_SERVER)

normal_message = messages.EventBranchCreate(branch="normal", branch_id=str(uuid4()), sync_with_git=False)
delayed_message = messages.EventBranchCreate(branch="delayed", branch_id=str(uuid4()), sync_with_git=False)

await bus.initialize(service=service)
await service.send(message=normal_message)
await service.send(message=delayed_message, delay=MessageTTL.FIVE)

queue = await bus.channel.get_queue(name=f"{bus.settings.namespace}.rpcs")
delayed_queue = await bus.channel.get_queue(name=f"{bus.settings.namespace}.delay.five_seconds")
message_from_queue = await queue.get()
delayed_message_from_queue = await delayed_queue.get()
parsed_message = ujson.loads(message_from_queue.body)
parsed_delayed_message = ujson.loads(delayed_message_from_queue.body)

await bus.shutdown()

parsed_message = messages.EventBranchCreate(**parsed_message)
parsed_delayed_message = messages.EventBranchCreate(**parsed_delayed_message)

# The priority isn't currently included in the header, reset it to show expected priority
normal_message.meta.priority = 3
delayed_message.meta.priority = 3
parsed_delayed_message.meta.headers = {"delay": 5000}
assert message_from_queue.priority == 5
assert delayed_message_from_queue.priority == 5
assert parsed_message == normal_message
assert parsed_delayed_message == delayed_message


async def test_rabbitmq_callback(rabbitmq_api: RabbitMQManager, fake_log: FakeLogger) -> None:
"""Validates that incoming messages gets parsed by the callback method."""

Expand Down
Loading

0 comments on commit 86faba6

Please sign in to comment.