1 change: 0 additions & 1 deletion .env.example
@@ -33,7 +33,6 @@ PREFECT_API_URL=http://prefect:4200/api
FLOW_NAME="Parent flow/launch_parent_flow"
TIMEZONE="US/Pacific"
PREFECT_TAGS='["latent-space-explorer"]'
FLOW_TYPE="docker"

# MLFlow
MLFLOW_TRACKING_URI=http://mlflow:5000
3 changes: 1 addition & 2 deletions docker-compose.yml
@@ -1,6 +1,6 @@
services:
prefect:
-image: prefecthq/prefect:2.14-python3.11
+image: prefecthq/prefect:3.4.2-python3.11
command: prefect server start
container_name: prefect-server
environment:
@@ -136,7 +136,6 @@ services:
FLOW_NAME: '${FLOW_NAME}'
TIMEZONE: "${TIMEZONE}"
PREFECT_TAGS: "${PREFECT_TAGS}"
-FLOW_TYPE: "${FLOW_TYPE}"
CONTAINER_NETWORK: "${CONTAINER_NETWORK}"
# Slurm jobs
PARTITIONS_CPU: "${PARTITIONS_CPU}"
4 changes: 2 additions & 2 deletions live_operator_example/lse_operator.py
@@ -6,7 +6,7 @@
import torch

# Import the MLflowClient class
-from src.utils.mlflow_utils import MLflowClient
+from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient
from tiled.client import from_uri

from tiled_utils import write_results
@@ -36,7 +36,7 @@ def get_default_device():

if __name__ == "__main__":
# Create MLflow client
-mlflow_client = MLflowClient(
+mlflow_client = MLflowModelClient(
tracking_uri=MLFLOW_TRACKING_URI,
username=MLFLOW_TRACKING_USERNAME,
password=MLFLOW_TRACKING_PASSWORD,
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -32,7 +32,7 @@ lse = [
"kaleido<=0.2.1",
"humanhash3",
"mlex_file_manager@git+https://github.com/mlexchange/mlex_file_manager.git",
"mlex_utils[all]@git+https://github.com/mlexchange/mlex_utils.git",
"mlex_utils[all]@git+https://github.com/xiaoyachong/mlex_utils.git@xiaoya-update-prefect3",
"numpy<2.0.0",
"pandas",
"Pillow",
@@ -77,5 +77,8 @@ arroyo = [
"torchvision==0.17.2",
"transformers==4.47.1",
"umap-learn",
"joblib==1.4.2"
"joblib==1.4.2",
"mlflow==2.22.0",
"mlex_utils[all]@git+https://github.com/xiaoyachong/mlex_utils.git@xiaoya-update-prefect3",
"numpy<2.0.0"
]
4 changes: 2 additions & 2 deletions src/arroyo_reduction/reducer.py
@@ -12,7 +12,7 @@
from arroyosas.schemas import RawFrameEvent
from PIL import Image

-from src.utils.mlflow_utils import MLflowClient
+from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient

from .redis_model_store import RedisModelStore

@@ -88,7 +88,7 @@ def __init__(self):
self.device = device

# Load models from MLflow
-mlflow_client = MLflowClient()
+mlflow_client = MLflowModelClient()
self.mlflow_client = mlflow_client # Store for later use

# Set loading flags before loading models
14 changes: 4 additions & 10 deletions src/callbacks/execute.py
@@ -29,22 +29,21 @@
parse_job_params,
parse_model_params,
)
-from src.utils.mlflow_utils import MLflowClient
+from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient
from src.utils.plot_utils import generate_notification

MODE = os.getenv("MODE", "")
TIMEZONE = os.getenv("TIMEZONE", "US/Pacific")
FLOW_NAME = os.getenv("FLOW_NAME", "")
PREFECT_TAGS = json.loads(os.getenv("PREFECT_TAGS", '["latent-space-explorer"]'))
RESULTS_DIR = os.getenv("RESULTS_DIR", "")
-FLOW_TYPE = os.getenv("FLOW_TYPE", "conda")

# Initialize Redis model store instead of direct Redis client
REDIS_HOST = os.getenv("REDIS_HOST", "kvrocks")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6666))

logger = logging.getLogger(__name__)
-mlflow_client = MLflowClient()
+mlflow_client = MLflowModelClient()

@callback(
Output("mlflow-model-dropdown", "options", allow_duplicate=True),
@@ -189,7 +188,6 @@ def run_latent_space(
model_parameters,
USER,
project_name,
-FLOW_TYPE,
latent_space_params,
dim_reduction_params,
mlflow_model_id,
@@ -459,17 +457,13 @@ def run_clustering(
api_key=tiled_results.data_tiled_api_key,
)

-model_exec_params = clustering_models[model_name]
+clustering_params = clustering_models[model_name]
job_params = parse_clustering_job_params(
data_project_fvec,
model_parameters,
USER,
project_name,
-FLOW_TYPE,
-model_exec_params["image_name"],
-model_exec_params["image_tag"],
-model_exec_params["python_file_name"],
-model_exec_params["conda_env"],
+clustering_params
)

if MODE == "dev":
6 changes: 3 additions & 3 deletions src/callbacks/infrastructure_check.py
@@ -6,8 +6,8 @@

from src.components.infrastructure import create_infra_state_details
from src.utils.data_utils import tiled_results
-from src.utils.mlflow_utils import MLflowClient
-from src.utils.prefect import check_prefect_ready, check_prefect_worker_ready
+from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient
+from mlex_utils.prefect_utils.core import check_prefect_ready, check_prefect_worker_ready

TIMEZONE = os.getenv("TIMEZONE", "US/Pacific")
FLOW_NAME = os.getenv("FLOW_NAME", "")
@@ -33,7 +33,7 @@ def check_infra_state(n_intervals):

# MLFLOW: Check MLFlow is reachable
try:
-mlflow_client = MLflowClient()
+mlflow_client = MLflowModelClient()
infra_state["mlflow_ready"] = mlflow_client.check_mlflow_ready()
if not infra_state["mlflow_ready"]:
any_infra_down = True
4 changes: 2 additions & 2 deletions src/callbacks/live_mode.py
@@ -20,7 +20,7 @@
from src.arroyo_reduction.redis_model_store import (
RedisModelStore,
) # Import the RedisModelStore class
-from src.utils.mlflow_utils import MLflowClient
+from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient
from src.utils.plot_utils import (
generate_scatter_data,
plot_empty_heatmap,
@@ -33,7 +33,7 @@
redis_model_store = RedisModelStore(host=REDIS_HOST, port=REDIS_PORT)

logger = logging.getLogger("lse.live_mode")
-mlflow_client = MLflowClient()
+mlflow_client = MLflowModelClient()


@callback(
66 changes: 29 additions & 37 deletions src/test/test_mlflow_client.py
@@ -6,27 +6,27 @@
import pytest

from src.test.test_utils import mlflow_test_client, mock_mlflow_client, mock_os_makedirs
-from src.utils.mlflow_utils import MLflowClient
+from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient


class TestMLflowClient:

@pytest.fixture(autouse=True)
def setup_and_teardown(self):
"""Reset MLflowClient._model_cache before and after each test"""
"""Reset MLflowModelClient._model_cache before and after each test"""
# Save original cache
-original_cache = MLflowClient._model_cache.copy()
+original_cache = MLflowModelClient._model_cache.copy()

# Clear cache before test
-MLflowClient._model_cache = {}
+MLflowModelClient._model_cache = {}

yield

# Restore original cache after test
-MLflowClient._model_cache = original_cache
+MLflowModelClient._model_cache = original_cache

def test_init(self, mlflow_test_client, mock_os_makedirs):
"""Test initialization of MLflowClient"""
"""Test initialization of MLflowModelClient"""
client = mlflow_test_client
# Verify environment variables were set
assert os.environ["MLFLOW_TRACKING_USERNAME"] == "test-user"
@@ -89,9 +89,16 @@ def test_get_mlflow_params(self, mlflow_test_client, mock_mlflow_client):
# Verify the result contains the expected parameters
assert result == {"param1": "value1", "param2": "value2"}

-def test_get_mlflow_models(self, mlflow_test_client, mock_mlflow_client):
+@patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_name")
+@patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_parent_id")
+def test_get_mlflow_models(self, mock_get_parent_id, mock_get_flow_name, mlflow_test_client, mock_mlflow_client):
"""Test retrieving MLflow models"""
client = mlflow_test_client

+# Configure Prefect mocks
+mock_get_flow_name.return_value = "Flow Run 1"
+mock_get_parent_id.return_value = "parent-id"

# Create mock model versions
mock_version1 = MagicMock()
mock_version1.name = "model1"
@@ -119,18 +126,7 @@ def test_get_mlflow_models(self, mlflow_test_client, mock_mlflow_client):
# Configure get_run to return our mock runs
mock_mlflow_client.get_run.side_effect = [mock_run1, mock_run2]

-# Mock the get_flow_run_name and get_flow_run_parent_id functions
-with (
-patch(
-"src.utils.mlflow_utils.get_flow_run_name", return_value="Flow Run 1"
-),
-patch(
-"src.utils.mlflow_utils.get_flow_run_parent_id",
-return_value="parent-id",
-),
-):
-
-result = client.get_mlflow_models()
+result = client.get_mlflow_models()

# Verify search_model_versions was called
mock_mlflow_client.search_model_versions.assert_called_once()
@@ -184,11 +180,18 @@ def test_get_mlflow_models_with_livemode(
assert result[0]["label"] == "model1"
assert result[0]["value"] == "model1"

@patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_name")
@patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_parent_id")
def test_get_mlflow_models_with_model_type(
-self, mlflow_test_client, mock_mlflow_client
+self, mock_get_parent_id, mock_get_flow_name, mlflow_test_client, mock_mlflow_client
):
"""Test retrieving MLflow models with model_type filter"""
client = mlflow_test_client

+# Configure Prefect mocks
+mock_get_flow_name.return_value = "Flow Run 1"
+mock_get_parent_id.return_value = "parent-id"

# Create mock model versions
mock_version1 = MagicMock()
mock_version1.name = "model1"
@@ -216,18 +219,7 @@ def test_get_mlflow_models_with_model_type(
# Configure get_run to return our mock runs
mock_mlflow_client.get_run.side_effect = [mock_run1, mock_run2]

-# Mock the get_flow_run_name and get_flow_run_parent_id functions
-with (
-patch(
-"src.utils.mlflow_utils.get_flow_run_parent_id",
-return_value="parent-id",
-),
-patch(
-"src.utils.mlflow_utils.get_flow_run_name", return_value="Flow Run 1"
-),
-):
-
-result = client.get_mlflow_models(model_type="autoencoder")
+result = client.get_mlflow_models(model_type="autoencoder")

# Verify the result contains only models with model_type "autoencoder"
assert len(result) == 1
@@ -251,7 +243,7 @@ def test_load_model_from_memory_cache(self, mlflow_test_client):
client = mlflow_test_client
# Set up memory cache
mock_model = MagicMock(name="memory_model")
MLflowClient._model_cache = {"test-model": mock_model}
MLflowModelClient._model_cache = {"test-model": mock_model}

# Load model
result = client.load_model("test-model")
@@ -367,13 +359,13 @@ def test_load_model_error(self, mlflow_test_client, mock_mlflow_client):
def test_clear_memory_cache(self):
"""Test clearing the memory cache"""
# Set up memory cache
MLflowClient._model_cache = {"test-model": MagicMock()}
MLflowModelClient._model_cache = {"test-model": MagicMock()}

# Clear memory cache
-MLflowClient.clear_memory_cache()
+MLflowModelClient.clear_memory_cache()

# Verify memory cache is empty
-assert len(MLflowClient._model_cache) == 0
+assert len(MLflowModelClient._model_cache) == 0

def test_clear_disk_cache(self, mlflow_test_client):
"""Test clearing the disk cache"""
@@ -391,4 +383,4 @@ def test_clear_disk_cache(self, mlflow_test_client):
mock_rmtree.assert_called_once_with(client.cache_dir)

# Verify makedirs was called with the cache directory
-mock_makedirs.assert_called_once_with(client.cache_dir, exist_ok=True)
+mock_makedirs.assert_called_once_with(client.cache_dir, exist_ok=True)
10 changes: 5 additions & 5 deletions src/test/test_reducer.py
@@ -98,7 +98,7 @@ def test_init_loads_models_from_redis(self):
patch(
"src.arroyo_reduction.redis_model_store.RedisModelStore", autospec=True
) as redis_class_mock,
patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock,
patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock,
patch(
"src.arroyo_reduction.reducer.LatentSpaceReducer._subscribe_to_model_updates"
),
@@ -173,7 +173,7 @@ def test_handle_model_update(self):

# Create the patch for MLflowClient and redis before importing anything
with (
patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock,
patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock,
patch(
"src.arroyo_reduction.redis_model_store.RedisModelStore"
) as redis_mock,
@@ -264,7 +264,7 @@ def test_handle_duplicate_model_update(self):
"""Test handling duplicate model update notifications with version"""
# Create the patch for MLflowClient and redis before importing anything
with (
patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock,
patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock,
patch(
"src.arroyo_reduction.redis_model_store.RedisModelStore"
) as redis_mock,
@@ -325,7 +325,7 @@ def test_loading_flags_during_model_update(self):
"""Test that loading flags are set and reset correctly during model update"""
# Create the patch for MLflowClient and redis before importing anything
with (
patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock,
patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock,
patch(
"src.arroyo_reduction.redis_model_store.RedisModelStore"
) as redis_mock,
@@ -391,7 +391,7 @@ def test_subscribe_to_model_updates(self):
with (
patch("threading.Thread", return_value=mock_thread) as mock_thread_class,
patch("src.arroyo_reduction.redis_model_store.RedisModelStore"),
patch("src.utils.mlflow_utils.MLflowClient"),
patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient"),
):

# Import the real class but patch the __init__ to avoid complex initialization