Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ services:
mle_net:

mlflow:
image: ghcr.io/mlflow/mlflow:v2.22.0
image: ghcr.io/mlflow/mlflow:v2.22.0
container_name: mlflow-server
command: >
/bin/sh -c "pip install --no-cache-dir psycopg2-binary 'mlflow[auth]' &&
Expand Down Expand Up @@ -167,16 +167,16 @@ services:
networks:
mle_net:

kvrocks:
image: docker.io/apache/kvrocks:latest
container_name: kvrocks # Added container name
volumes:
- ./conf/kvrocks/kvrocks.conf:/etc/kvrocks/kvrocks.conf
- ./persist/kvrocks/data:/data
kvrocks:
image: docker.io/apache/kvrocks:latest
container_name: kvrocks # Added container name
volumes:
- ./conf/kvrocks/kvrocks.conf:/etc/kvrocks/kvrocks.conf
- ./persist/kvrocks/data:/data
ports:
- "127.0.0.1:6666:6666" # Added port mapping
networks:
mle_net:
networks:
mle_net:

# arroyo, arroyo_vec_sim and arroyo_sim are made optional via profiles
# To run vector simulator, use
Expand Down Expand Up @@ -221,6 +221,10 @@ services:
- .:/app:Z
ports:
- 127.0.0.1:8765:8765
environment:
SIMULATION_TYPE: "default" # Set to "mnist" for MNIST simulation
RESULTS_TILED_URI: '${RESULTS_TILED_URI}'
RESULTS_TILED_API_KEY: '${RESULTS_TILED_API_KEY}'
networks:
mle_net:

Expand Down
2 changes: 1 addition & 1 deletion simulator/data_simulator_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_num_frames(tiled_uri, tiled_api_key=None):
}


async def stream():
async def mnist_stream():
"""
Connect to the existing WebSocket server elsewhere,
send messages, then close the connection.
Expand Down
107 changes: 69 additions & 38 deletions simulator/tiled_ingestor_mnist.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,12 @@
import os
import logging
import tempfile

import load_dotenv
import numpy as np
from tiled.client import from_uri
from tiled.structures.array import ArrayStructure
from torchvision import datasets, transforms

# Load .env for API key
load_dotenv.load_dotenv()
tiled_api_key = os.getenv("RESULTS_TILED_API_KEY")

# Connect to root of Tiled server
client = from_uri("http://tiled:8000/api/v1", api_key=tiled_api_key)

# Create a container at /mnist if it doesn't exist
if "mnist" not in client:
client.create_container("mnist", metadata={"purpose": "MNIST digit subsets"})

mnist_container = client["mnist"]

# Load MNIST dataset
mnist_dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transforms.ToTensor()
)

# Group images by label
images_by_label = {i: [] for i in range(10)}
for img, label in mnist_dataset:
images_by_label[label].append(img.squeeze().numpy()) # 28x28

# Define how many images per label to include
label_mapping = {
LABEL_COUNTS = {
1: 10,
2: 3,
3: 1,
Expand All @@ -43,7 +19,7 @@
0: 6,
}

label_names = {
LABEL_NAMES = {
1: "ones",
2: "twos",
3: "three",
Expand All @@ -56,16 +32,71 @@
0: "zero",
}

# Write digit subsets into the mnist container
for digit, count in label_mapping.items():
label_name = label_names[digit]
selected_images = images_by_label[digit][:count]
selected_array = np.stack(selected_images) if count > 1 else selected_images[0]
structure = ArrayStructure.from_array(selected_array)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger at import time with an INFO threshold and a
# "LEVEL: (logger name) message" line format. basicConfig() is a no-op when
# the root logger already has handlers.
logging.basicConfig(
level=logging.INFO, format="%(levelname)s: (%(name)s) %(message)s "
)
# Pin this module's logger to INFO; logger.debug() calls made through it are
# therefore not emitted unless a caller lowers the level afterwards.
logger.setLevel(logging.INFO)


def ingest_mnist_to_tiled(tiled_uri, api_key: str = None) -> bool:
    """Ingest small per-digit MNIST subsets into Tiled under ``/mnist``.

    For every digit in ``LABEL_COUNTS`` the first ``count`` training images of
    that digit (in dataset order) are written as one array under the key
    ``LABEL_NAMES[digit]`` inside a freshly created ``mnist`` container.

    Args:
        tiled_uri: URI of the Tiled server API to connect to.
        api_key: Optional Tiled API key; ``None`` connects without one.

    Returns:
        ``True`` when the subsets were written, or when the ``mnist``
        container already existed (nothing is written in that case);
        ``False`` when any step raised an exception.
    """
    try:
        # Connect to Tiled and create the container.
        client = from_uri(tiled_uri, api_key=api_key)

        if "mnist" in client:
            # NOTE(review): a container left behind by a run that failed
            # half-way also takes this branch, so a partial ingest is never
            # repaired automatically -- confirm whether that is acceptable.
            logger.warning(
                "⚠️ Container 'mnist' already exists, no data has been ingested."
            )
            return True

        client.create_container(
            "mnist", metadata={"purpose": "MNIST digit subsets"}
        )
        mnist_container = client["mnist"]

        # Only the first `count` images per digit are needed, so stop reading
        # the dataset as soon as every digit has enough instead of converting
        # all training images to NumPy.
        images_by_label = {digit: [] for digit in LABEL_COUNTS}
        remaining = sum(max(count, 0) for count in LABEL_COUNTS.values())

        # The download only needs to live until the images have been copied
        # out, hence the temporary directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            mnist_dataset = datasets.MNIST(
                root=tmp_dir,
                train=True,
                download=True,
                transform=transforms.ToTensor(),
            )
            logger.debug(
                "📊 Loaded MNIST dataset with %d images", len(mnist_dataset)
            )

            if remaining > 0:
                for img, label in mnist_dataset:
                    bucket = images_by_label.get(label)
                    if bucket is None or len(bucket) >= LABEL_COUNTS[label]:
                        continue
                    # Drop the channel axis (28x28 per the original comment).
                    bucket.append(img.squeeze().numpy())
                    remaining -= 1
                    if remaining == 0:
                        break

        # Write one array per digit to Tiled.
        for digit in LABEL_COUNTS:
            label_name = LABEL_NAMES[digit]
            selected_images = images_by_label[digit]
            if not selected_images:
                logger.warning(
                    "No images selected for label %s (digit %s); skipping",
                    label_name,
                    digit,
                )
                continue

            # A single image is stored as-is; several are stacked along a
            # new leading axis.
            selected_array = (
                np.stack(selected_images)
                if len(selected_images) > 1
                else selected_images[0]
            )
            logger.debug(
                "Writing %d images for label %s (digit %s)",
                len(selected_images),
                label_name,
                digit,
            )
            logger.debug("Shape of selected array: %s", selected_array.shape)

            mnist_container.write_array(
                key=label_name,
                array=selected_array,
                metadata={"label": digit, "count": len(selected_images)},
            )

        logger.info("✅ MNIST subset successfully uploaded to Tiled")
        return True

    except Exception as e:
        # Top-level boundary: callers only inspect the boolean result, so
        # report the failure at ERROR level and keep the traceback at DEBUG.
        logger.error("❌ Error: %s", e)
        logger.debug("Ingestion traceback", exc_info=True)
        return False
32 changes: 26 additions & 6 deletions simulator/websocket_simulator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import asyncio
import logging
import os
import sys

import numpy as np
import typer

from .data_simulator import stream
from src.arroyo_reduction.publisher import LSEWSResultPublisher
from src.arroyo_reduction.schemas import LatentSpaceEvent

from .data_simulator import stream
from .data_simulator_mnist import mnist_stream
from .tiled_ingestor_mnist import ingest_mnist_to_tiled

# Typer application; commands in this module are registered on it with
# @app.command().
app = typer.Typer()
# Uses the fixed "arroyo_reduction" logger name rather than this module's name.
logger = logging.getLogger("arroyo_reduction")

# Simulation mode, read once at import time from the SIMULATION_TYPE
# environment variable, lower-cased, falling back to "default" when unset.
SIMULATION_TYPE = os.getenv("SIMULATION_TYPE", "default").lower() # default or mnist


def setup_logger(logger: logging.Logger, log_level: str = "INFO"):
Expand All @@ -22,11 +27,14 @@ def setup_logger(logger: logging.Logger, log_level: str = "INFO"):
logger.setLevel(log_level.upper())
logger.debug("DEBUG LOGGING SET")


setup_logger(logger)


def get_feature_vectors(num_messages, num_features=2, scale=5):
    """Return random feature vectors for simulated messages.

    Args:
        num_messages: Number of vectors (rows) to generate.
        num_features: Length of each vector; defaults to 2.
        scale: Upper bound of the sampled range; defaults to 5.

    Returns:
        A ``(num_messages, num_features)`` float array with values drawn
        uniformly from ``[0, scale)`` using NumPy's global random state.
    """
    return scale * np.random.rand(num_messages, num_features)


class DummyWSPublisher(LSEWSResultPublisher):

def __init__(self, host="0.0.0.0", port=8765, path="/ws"):
Expand All @@ -37,16 +45,19 @@ async def start(self) -> None:
logger.info("DummyWSPublisher started, but does nothing.")
await super().start()




@app.command()
def start() -> None:
async def main():
ws_publisher = DummyWSPublisher()
asyncio.create_task(ws_publisher.start())
while True:
gen = stream()
if SIMULATION_TYPE == "mnist":
logger.info("Starting MNIST simulation...")
gen = mnist_stream()
else:
gen = stream()
logger.info("Starting default simulation...")
while True:
# Simulate receiving data
message = await anext(gen, None)
Expand All @@ -58,4 +69,13 @@ async def main():


if __name__ == "__main__":
    # In MNIST mode the Tiled server must hold the digit subsets before the
    # simulator starts, so ingest first and abort with a non-zero exit code
    # when that fails.
    if SIMULATION_TYPE == "mnist":
        logger.info("Running MNIST simulation...")
        # `or` rather than a getenv default: docker-compose passes an unset
        # `${VAR}` through as an empty string, which a getenv default would
        # not replace.
        ingestion = ingest_mnist_to_tiled(
            tiled_uri=os.getenv("RESULTS_TILED_URI") or "http://tiled:8000/api/v1",
            api_key=os.getenv("RESULTS_TILED_API_KEY") or None,
        )
        if not ingestion:
            sys.exit(1)

    # Hand control to the Typer CLI only after the optional ingestion step.
    app()