diff --git a/.env.example b/.env.example index c67b8c2..fcfa257 100644 --- a/.env.example +++ b/.env.example @@ -33,7 +33,6 @@ PREFECT_API_URL=http://prefect:4200/api FLOW_NAME="Parent flow/launch_parent_flow" TIMEZONE="US/Pacific" PREFECT_TAGS='["latent-space-explorer"]' -FLOW_TYPE="docker" # MLFlow MLFLOW_TRACKING_URI=http://mlflow:5000 diff --git a/docker-compose.yml b/docker-compose.yml index 176d004..2c18815 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ services: prefect: - image: prefecthq/prefect:2.14-python3.11 + image: prefecthq/prefect:3.4.2-python3.11 command: prefect server start container_name: prefect-server environment: @@ -136,7 +136,6 @@ services: FLOW_NAME: '${FLOW_NAME}' TIMEZONE: "${TIMEZONE}" PREFECT_TAGS: "${PREFECT_TAGS}" - FLOW_TYPE: "${FLOW_TYPE}" CONTAINER_NETWORK: "${CONTAINER_NETWORK}" # Slurm jobs PARTITIONS_CPU: "${PARTITIONS_CPU}" diff --git a/live_operator_example/lse_operator.py b/live_operator_example/lse_operator.py index 0932366..05847b9 100644 --- a/live_operator_example/lse_operator.py +++ b/live_operator_example/lse_operator.py @@ -6,7 +6,7 @@ import torch # Import the MLflowClient class -from src.utils.mlflow_utils import MLflowClient +from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient from tiled.client import from_uri from tiled_utils import write_results @@ -36,7 +36,7 @@ def get_default_device(): if __name__ == "__main__": # Create MLflow client - mlflow_client = MLflowClient( + mlflow_client = MLflowModelClient( tracking_uri=MLFLOW_TRACKING_URI, username=MLFLOW_TRACKING_USERNAME, password=MLFLOW_TRACKING_PASSWORD, diff --git a/pyproject.toml b/pyproject.toml index fcf363e..e53be9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ lse = [ "kaleido<=0.2.1", "humanhash3", "mlex_file_manager@git+https://github.com/mlexchange/mlex_file_manager.git", - "mlex_utils[all]@git+https://github.com/mlexchange/mlex_utils.git", + "mlex_utils[all]@git+https://github.com/xiaoyachong/mlex_utils.git@xiaoya-update-prefect3", "numpy<2.0.0", "pandas", "Pillow", @@ -77,5 +77,8 @@ arroyo = [ "torchvision==0.17.2", "transformers==4.47.1", "umap-learn", - "joblib==1.4.2" + "joblib==1.4.2", + "mlflow==2.22.0", + "mlex_utils[all]@git+https://github.com/xiaoyachong/mlex_utils.git@xiaoya-update-prefect3", + "numpy<2.0.0" ] diff --git a/src/arroyo_reduction/reducer.py b/src/arroyo_reduction/reducer.py index cc5b5e1..ae053a1 100644 --- a/src/arroyo_reduction/reducer.py +++ b/src/arroyo_reduction/reducer.py @@ -12,7 +12,7 @@ from arroyosas.schemas import RawFrameEvent from PIL import Image -from src.utils.mlflow_utils import MLflowClient +from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient from .redis_model_store import RedisModelStore @@ -88,7 +88,7 @@ def __init__(self): self.device = device # Load models from MLflow - mlflow_client = MLflowClient() + mlflow_client = MLflowModelClient() self.mlflow_client = mlflow_client # Store for later use # Set loading flags before loading models diff --git a/src/callbacks/execute.py b/src/callbacks/execute.py index 087eef6..70839e9 100644 --- a/src/callbacks/execute.py +++ b/src/callbacks/execute.py @@ -29,7 +29,7 @@ parse_job_params, parse_model_params, ) -from src.utils.mlflow_utils import MLflowClient +from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient from src.utils.plot_utils import generate_notification MODE = os.getenv("MODE", "") @@ -37,14 +37,13 @@ FLOW_NAME = os.getenv("FLOW_NAME", "") PREFECT_TAGS = 
json.loads(os.getenv("PREFECT_TAGS", '["latent-space-explorer"]')) RESULTS_DIR = os.getenv("RESULTS_DIR", "") -FLOW_TYPE = os.getenv("FLOW_TYPE", "conda") # Initialize Redis model store instead of direct Redis client REDIS_HOST = os.getenv("REDIS_HOST", "kvrocks") REDIS_PORT = int(os.getenv("REDIS_PORT", 6666)) logger = logging.getLogger(__name__) -mlflow_client = MLflowClient() +mlflow_client = MLflowModelClient() @callback( Output("mlflow-model-dropdown", "options", allow_duplicate=True), @@ -189,7 +188,6 @@ def run_latent_space( model_parameters, USER, project_name, - FLOW_TYPE, latent_space_params, dim_reduction_params, mlflow_model_id, @@ -459,17 +457,13 @@ def run_clustering( api_key=tiled_results.data_tiled_api_key, ) - model_exec_params = clustering_models[model_name] + clustering_params = clustering_models[model_name] job_params = parse_clustering_job_params( data_project_fvec, model_parameters, USER, project_name, - FLOW_TYPE, - model_exec_params["image_name"], - model_exec_params["image_tag"], - model_exec_params["python_file_name"], - model_exec_params["conda_env"], + clustering_params ) if MODE == "dev": diff --git a/src/callbacks/infrastructure_check.py b/src/callbacks/infrastructure_check.py index 3f0c9cb..b8461d4 100644 --- a/src/callbacks/infrastructure_check.py +++ b/src/callbacks/infrastructure_check.py @@ -6,8 +6,8 @@ from src.components.infrastructure import create_infra_state_details from src.utils.data_utils import tiled_results -from src.utils.mlflow_utils import MLflowClient -from src.utils.prefect import check_prefect_ready, check_prefect_worker_ready +from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient +from mlex_utils.prefect_utils.core import check_prefect_ready, check_prefect_worker_ready TIMEZONE = os.getenv("TIMEZONE", "US/Pacific") FLOW_NAME = os.getenv("FLOW_NAME", "") @@ -33,7 +33,7 @@ def check_infra_state(n_intervals): # MLFLOW: Check MLFlow is reachable try: - mlflow_client = MLflowClient() + mlflow_client = MLflowModelClient() infra_state["mlflow_ready"] = mlflow_client.check_mlflow_ready() if not infra_state["mlflow_ready"]: any_infra_down = True diff --git a/src/callbacks/live_mode.py b/src/callbacks/live_mode.py index f05f334..5151df4 100644 --- a/src/callbacks/live_mode.py +++ b/src/callbacks/live_mode.py @@ -20,7 +20,7 @@ from src.arroyo_reduction.redis_model_store import ( RedisModelStore, ) # Import the RedisModelStore class -from src.utils.mlflow_utils import MLflowClient +from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient from src.utils.plot_utils import ( generate_scatter_data, plot_empty_heatmap, @@ -33,7 +33,7 @@ redis_model_store = RedisModelStore(host=REDIS_HOST, port=REDIS_PORT) logger = logging.getLogger("lse.live_mode") -mlflow_client = MLflowClient() +mlflow_client = MLflowModelClient() @callback( diff --git a/src/test/test_mlflow_client.py b/src/test/test_mlflow_client.py index 32c4ed2..7a9dab6 100644 --- a/src/test/test_mlflow_client.py +++ b/src/test/test_mlflow_client.py @@ -6,27 +6,27 @@ import pytest from src.test.test_utils import mlflow_test_client, mock_mlflow_client, mock_os_makedirs -from src.utils.mlflow_utils import MLflowClient +from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient class TestMLflowClient: @pytest.fixture(autouse=True) def setup_and_teardown(self): - """Reset MLflowClient._model_cache before and after each test""" + """Reset MLflowModelClient._model_cache before and after each test""" # Save original cache - original_cache = 
MLflowClient._model_cache.copy() + original_cache = MLflowModelClient._model_cache.copy() # Clear cache before test - MLflowClient._model_cache = {} + MLflowModelClient._model_cache = {} yield # Restore original cache after test - MLflowClient._model_cache = original_cache + MLflowModelClient._model_cache = original_cache def test_init(self, mlflow_test_client, mock_os_makedirs): - """Test initialization of MLflowClient""" + """Test initialization of MLflowModelClient""" client = mlflow_test_client # Verify environment variables were set assert os.environ["MLFLOW_TRACKING_USERNAME"] == "test-user" @@ -89,9 +89,16 @@ def test_get_mlflow_params(self, mlflow_test_client, mock_mlflow_client): # Verify the result contains the expected parameters assert result == {"param1": "value1", "param2": "value2"} - def test_get_mlflow_models(self, mlflow_test_client, mock_mlflow_client): + @patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_name") + @patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_parent_id") + def test_get_mlflow_models(self, mock_get_parent_id, mock_get_flow_name, mlflow_test_client, mock_mlflow_client): """Test retrieving MLflow models""" client = mlflow_test_client + + # Configure Prefect mocks + mock_get_flow_name.return_value = "Flow Run 1" + mock_get_parent_id.return_value = "parent-id" + # Create mock model versions mock_version1 = MagicMock() mock_version1.name = "model1" @@ -119,18 +126,7 @@ def test_get_mlflow_models(self, mlflow_test_client, mock_mlflow_client): # Configure get_run to return our mock runs mock_mlflow_client.get_run.side_effect = [mock_run1, mock_run2] - # Mock the get_flow_run_name and get_flow_run_parent_id functions - with ( - patch( - "src.utils.mlflow_utils.get_flow_run_name", return_value="Flow Run 1" - ), - patch( - "src.utils.mlflow_utils.get_flow_run_parent_id", - return_value="parent-id", - ), - ): - - result = client.get_mlflow_models() + result = client.get_mlflow_models() # Verify search_model_versions was called mock_mlflow_client.search_model_versions.assert_called_once() @@ -184,11 +180,18 @@ def test_get_mlflow_models_with_livemode( assert result[0]["label"] == "model1" assert result[0]["value"] == "model1" + @patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_name") + @patch("mlex_utils.mlflow_utils.mlflow_model_client.get_flow_run_parent_id") def test_get_mlflow_models_with_model_type( - self, mlflow_test_client, mock_mlflow_client + self, mock_get_parent_id, mock_get_flow_name, mlflow_test_client, mock_mlflow_client ): """Test retrieving MLflow models with model_type filter""" client = mlflow_test_client + + # Configure Prefect mocks + mock_get_flow_name.return_value = "Flow Run 1" + mock_get_parent_id.return_value = "parent-id" + # Create mock model versions mock_version1 = MagicMock() mock_version1.name = "model1" @@ -216,18 +219,7 @@ def test_get_mlflow_models_with_model_type( # Configure get_run to return our mock runs mock_mlflow_client.get_run.side_effect = [mock_run1, mock_run2] - # Mock the get_flow_run_name and get_flow_run_parent_id functions - with ( - patch( - "src.utils.mlflow_utils.get_flow_run_parent_id", - return_value="parent-id", - ), - patch( - "src.utils.mlflow_utils.get_flow_run_name", return_value="Flow Run 1" - ), - ): - - result = client.get_mlflow_models(model_type="autoencoder") + result = client.get_mlflow_models(model_type="autoencoder") # Verify the result contains only models with model_type "autoencoder" assert len(result) == 1 @@ -251,7 +243,7 @@ def 
test_load_model_from_memory_cache(self, mlflow_test_client): client = mlflow_test_client # Set up memory cache mock_model = MagicMock(name="memory_model") - MLflowClient._model_cache = {"test-model": mock_model} + MLflowModelClient._model_cache = {"test-model": mock_model} # Load model result = client.load_model("test-model") @@ -367,13 +359,13 @@ def test_load_model_error(self, mlflow_test_client, mock_mlflow_client): def test_clear_memory_cache(self): """Test clearing the memory cache""" # Set up memory cache - MLflowClient._model_cache = {"test-model": MagicMock()} + MLflowModelClient._model_cache = {"test-model": MagicMock()} # Clear memory cache - MLflowClient.clear_memory_cache() + MLflowModelClient.clear_memory_cache() # Verify memory cache is empty - assert len(MLflowClient._model_cache) == 0 + assert len(MLflowModelClient._model_cache) == 0 def test_clear_disk_cache(self, mlflow_test_client): """Test clearing the disk cache""" @@ -391,4 +383,4 @@ def test_clear_disk_cache(self, mlflow_test_client): mock_rmtree.assert_called_once_with(client.cache_dir) # Verify makedirs was called with the cache directory - mock_makedirs.assert_called_once_with(client.cache_dir, exist_ok=True) + mock_makedirs.assert_called_once_with(client.cache_dir, exist_ok=True) \ No newline at end of file diff --git a/src/test/test_reducer.py b/src/test/test_reducer.py index cd1041a..755021b 100644 --- a/src/test/test_reducer.py +++ b/src/test/test_reducer.py @@ -98,7 +98,7 @@ def test_init_loads_models_from_redis(self): patch( "src.arroyo_reduction.redis_model_store.RedisModelStore", autospec=True ) as redis_class_mock, - patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock, + patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock, patch( "src.arroyo_reduction.reducer.LatentSpaceReducer._subscribe_to_model_updates" ), @@ -173,7 +173,7 @@ def test_handle_model_update(self): # Create the patch for MLflowClient and redis before importing anything with ( - patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock, + patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock, patch( "src.arroyo_reduction.redis_model_store.RedisModelStore" ) as redis_mock, @@ -264,7 +264,7 @@ def test_handle_duplicate_model_update(self): """Test handling duplicate model update notifications with version""" # Create the patch for MLflowClient and redis before importing anything with ( - patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock, + patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock, patch( "src.arroyo_reduction.redis_model_store.RedisModelStore" ) as redis_mock, @@ -325,7 +325,7 @@ def test_loading_flags_during_model_update(self): """Test that loading flags are set and reset correctly during model update""" # Create the patch for MLflowClient and redis before importing anything with ( - patch("src.utils.mlflow_utils.MLflowClient") as mlflow_client_mock, + patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") as mlflow_client_mock, patch( "src.arroyo_reduction.redis_model_store.RedisModelStore" ) as redis_mock, @@ -391,7 +391,7 @@ def test_subscribe_to_model_updates(self): with ( patch("threading.Thread", return_value=mock_thread) as mock_thread_class, patch("src.arroyo_reduction.redis_model_store.RedisModelStore"), - patch("src.utils.mlflow_utils.MLflowClient"), + patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient"), ): # Import the real 
class but patch the __init__ to avoid complex initialization diff --git a/src/test/test_utils.py b/src/test/test_utils.py index 41c7946..3634018 100644 --- a/src/test/test_utils.py +++ b/src/test/test_utils.py @@ -5,15 +5,6 @@ # Common fixtures for MLflow testing -@pytest.fixture -def mock_mlflow_client(): - """Mock MlflowClient class""" - with patch("src.utils.mlflow_utils.MlflowClient") as mock_client_class: - mock_client = MagicMock() - mock_client_class.return_value = mock_client - yield mock_client - - @pytest.fixture def mock_os_makedirs(): """Mock os.makedirs to avoid file system errors""" @@ -22,18 +13,29 @@ def mock_os_makedirs(): @pytest.fixture -def mlflow_test_client(mock_mlflow_client, mock_os_makedirs): - """Create a MLflowClient instance with mocked dependencies""" - with patch("mlflow.set_tracking_uri"): # Avoid actually setting tracking URI - from src.utils.mlflow_utils import MLflowClient - - client = MLflowClient( - tracking_uri="http://mock-mlflow:5000", - username="test-user", - password="test-password", - cache_dir="/tmp/test_mlflow_cache", - ) - return client +def mock_mlflow_client(): + """Create a mock MLflow client""" + return MagicMock() + + +@pytest.fixture +def mlflow_test_client(mock_os_makedirs, mock_mlflow_client): + """Create a MLflowModelClient instance with mocked dependencies""" + from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient + + # Patch the MLflowClient class to return our mock + with patch("mlex_utils.mlflow_utils.mlflow_model_client.MlflowClient") as mock_mlflow_class: + mock_mlflow_class.return_value = mock_mlflow_client + + with patch("mlex_utils.mlflow_utils.mlflow_model_client.mlflow.set_tracking_uri"): + client = MLflowModelClient( + tracking_uri="http://mock-mlflow:5000", + username="test-user", + password="test-password", + cache_dir="/tmp/test_mlflow_cache", + ) + + yield client # Common fixtures for Redis testing @@ -87,7 +89,7 @@ def redis_mlflow_mocks(): """Set up and start Redis and MLflow mocks""" # Create the patches redis_mock_patch = patch("src.arroyo_reduction.redis_model_store.RedisModelStore") - mlflow_client_mock_patch = patch("src.utils.mlflow_utils.MLflowClient") + mlflow_client_mock_patch = patch("mlex_utils.mlflow_utils.mlflow_model_client.MLflowModelClient") # Start all the patches redis_mock = redis_mock_patch.start() @@ -116,7 +118,7 @@ def redis_mlflow_mocks(): mock_dimred.predict.return_value = {"umap_coords": umap_coords} # Configure the load_model method to return appropriate models - mock_mlflow_client.load_model.side_effect = lambda model_name: ( + mock_mlflow_client.load_model.side_effect = lambda model_name, version=None: ( mock_autoencoder if model_name == "test_autoencoder" else mock_dimred ) @@ -149,10 +151,10 @@ def mock_logger(): # Add a specific mock for the live_mode MLflow client @pytest.fixture def mock_live_mode_mlflow_client(): - """Mock MLflowClient for live_mode callbacks""" + """Mock MLflowModelClient for live_mode callbacks""" with patch("src.callbacks.live_mode.mlflow_client") as mock_client: # Configure check_model_compatibility for testing mock_client.check_model_compatibility.side_effect = ( lambda auto, dimred: auto and dimred and auto != "incompatible" ) - yield mock_client + yield mock_client \ No newline at end of file diff --git a/src/utils/job_utils.py b/src/utils/job_utils.py index 33f1bc7..f11d09b 100644 --- a/src/utils/job_utils.py +++ b/src/utils/job_utils.py @@ -3,32 +3,15 @@ import os from urllib.parse import urljoin -from src.utils.mlflow_utils import 
MLflowClient +from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient # I/O parameters for job execution -READ_DIR_MOUNT = os.getenv("READ_DIR_MOUNT", None) -WRITE_DIR_MOUNT = os.getenv("WRITE_DIR_MOUNT", None) WRITE_DIR = os.getenv("WRITE_DIR", "") RESULTS_TILED_URI = os.getenv("RESULTS_TILED_URI", "") -RESULTS_TILED_API_KEY = os.getenv("RESULTS_TILED_API_KEY", "") MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "http://mlflow:5000") -MLFLOW_TRACKING_USERNAME = os.getenv("MLFLOW_TRACKING_USERNAME", "") -MLFLOW_TRACKING_PASSWORD = os.getenv("MLFLOW_TRACKING_PASSWORD", "") - -# Flow parameters -PARTITIONS_CPU = json.loads(os.getenv("PARTITIONS_CPU", "[]")) -RESERVATIONS_CPU = json.loads(os.getenv("RESERVATIONS_CPU", "[]")) -MAX_TIME_CPU = os.getenv("MAX_TIME_CPU", "1:00:00") -PARTITIONS_GPU = json.loads(os.getenv("PARTITIONS_CPU", "[]")) -RESERVATIONS_GPU = json.loads(os.getenv("RESERVATIONS_CPU", "[]")) -MAX_TIME_GPU = os.getenv("MAX_TIME_CPU", "1:00:00") -SUBMISSION_SSH_KEY = os.getenv("SUBMISSION_SSH_KEY", "") -FORWARD_PORTS = json.loads(os.getenv("FORWARD_PORTS", "[]")) -CONTAINER_NETWORK = os.getenv("CONTAINER_NETWORK", "") -FLOW_TYPE = os.getenv("FLOW_TYPE", "conda") logger = logging.getLogger(__name__) -mlflow_client = MLflowClient() +mlflow_client = MLflowModelClient() def parse_tiled_url(url, user, project_name, tiled_base_path="/api/v1/metadata"): """ @@ -47,7 +30,6 @@ def parse_job_params( model_parameters, user, project_name, - flow_type, latent_space_params, dim_reduction_params, mlflow_model_id=None, @@ -55,7 +37,6 @@ def parse_job_params( """ Parse training job parameters """ - # TODO: Use model_name to define the conda_env/algorithm to be executed data_uris = [dataset.uri for dataset in data_project.datasets] results_dir = f"{WRITE_DIR}/{user}" @@ -63,117 +44,42 @@ def parse_job_params( io_parameters = { "uid_retrieve": "", "data_uris": data_uris, - "data_tiled_api_key": data_project.api_key, "data_type": data_project.data_type, "root_uri": data_project.root_uri, "models_dir": f"{results_dir}/models", "results_tiled_uri": parse_tiled_url(RESULTS_TILED_URI, user, project_name), - "results_tiled_api_key": RESULTS_TILED_API_KEY, "results_dir": f"{results_dir}", "mlflow_uri": MLFLOW_TRACKING_URI, - "mlflow_tracking_username": MLFLOW_TRACKING_USERNAME, - "mlflow_tracking_password": MLFLOW_TRACKING_PASSWORD, "mlflow_model": mlflow_model_id, } auto_params = mlflow_client.get_mlflow_params(mlflow_model_id) logger.info(f"Autoencoder parameters: {auto_params}") - ls_python_file_name_inference = latent_space_params["python_file_name"]["inference"] - dm_python_file_name = dim_reduction_params["python_file_name"] - - if flow_type == "podman" or "docker": - job_params = { - "flow_type": flow_type, - "params_list": [ - { - "image_name": latent_space_params["image_name"], - "image_tag": latent_space_params["image_tag"], - "command": f"python {ls_python_file_name_inference}", - "params": { - "io_parameters": io_parameters, - "model_parameters": auto_params, - }, - "volumes": [ - f"{READ_DIR_MOUNT}:/tiled_storage", - ], - "network": CONTAINER_NETWORK, - }, - { - "image_name": dim_reduction_params["image_name"], - "image_tag": dim_reduction_params["image_tag"], - "command": f"python {dm_python_file_name}", - "params": { - "io_parameters": io_parameters, - "model_parameters": model_parameters, - }, - "volumes": [ - f"{READ_DIR_MOUNT}:/tiled_storage", - ], - "network": CONTAINER_NETWORK, - }, - ], - } - - elif flow_type == "conda": - job_params = { - "flow_type": 
"conda", - "params_list": [ - { - "conda_env_name": latent_space_params["conda_env"], - "python_file_name": ls_python_file_name_inference, - "params": { - "io_parameters": io_parameters, - "model_parameters": auto_params, - }, - }, - { - "conda_env_name": dim_reduction_params["conda_env"], - "python_file_name": dm_python_file_name, - "params": { - "io_parameters": io_parameters, - "model_parameters": model_parameters, - }, - }, - ], - } - - else: - job_params = { - "flow_type": "slurm", - "params_list": [ - { - "job_name": "latent_space_explorer", - "num_nodes": 1, - "partitions": PARTITIONS_CPU, - "reservations": RESERVATIONS_CPU, - "max_time": MAX_TIME_CPU, - "conda_env_name": latent_space_params["conda_env"], - "python_file_name": ls_python_file_name_inference, - "submission_ssh_key": SUBMISSION_SSH_KEY, - "forward_ports": FORWARD_PORTS, - "params": { - "io_parameters": io_parameters, - "model_parameters": auto_params, - }, - }, - { - "job_name": "latent_space_explorer", - "num_nodes": 1, - "partitions": PARTITIONS_CPU, - "reservations": RESERVATIONS_CPU, - "max_time": MAX_TIME_CPU, - "conda_env_name": dim_reduction_params["conda_env"], - "python_file_name": dm_python_file_name, - "submission_ssh_key": SUBMISSION_SSH_KEY, - "forward_ports": FORWARD_PORTS, - "params": { - "io_parameters": io_parameters, - "model_parameters": model_parameters, - }, - }, - ], - } + # Create a simpler params_list structure with model_name and task_name + params_list = [ + { + "model_name": latent_space_params["model_name"], + "task_name": "inference", + "params": { + "io_parameters": io_parameters, + "model_parameters": auto_params, + }, + }, + { + "model_name": dim_reduction_params["model_name"], + "task_name": "excute", + "params": { + "io_parameters": io_parameters, + "model_parameters": model_parameters, + }, + }, + ] + + # Keep the job params simplified + job_params = { + "params_list": params_list, + } return job_params @@ -183,16 +89,11 @@ def parse_clustering_job_params( model_parameters, user, project_name, - flow_type, - image_name=None, - image_tag=None, - python_file_name=None, - conda_env=None, + clustering_params ): """ - Parse job parameters + Parse job parameters for clustering """ - # TODO: Use model_name to define the conda_env/algorithm to be executed data_uris = [dataset.uri for dataset in data_project.datasets] results_dir = f"{WRITE_DIR}/{user}" @@ -200,70 +101,29 @@ def parse_clustering_job_params( io_parameters = { "uid_retrieve": "", "data_uris": data_uris, - "data_tiled_api_key": data_project.api_key, "data_type": data_project.data_type, "root_uri": data_project.root_uri, "save_model_path": f"{results_dir}/models", "results_tiled_uri": parse_tiled_url(RESULTS_TILED_URI, user, project_name), - "results_tiled_api_key": RESULTS_TILED_API_KEY, "results_dir": f"{results_dir}", } - if flow_type == "podman" or flow_type == "docker": - job_params = { - "flow_type": flow_type, - "params_list": [ - { - "image_name": image_name, - "image_tag": image_tag, - "command": f"python {python_file_name}", - "params": { - "io_parameters": io_parameters, - "model_parameters": model_parameters, - }, - "volumes": [ - f"{READ_DIR_MOUNT}:/tiled_storage", - ], - "network": CONTAINER_NETWORK, - } - ], - } - - elif flow_type == "conda": - job_params = { - "flow_type": "conda", - "params_list": [ - { - "conda_env_name": conda_env, - "python_file_name": python_file_name, - "params": { - "io_parameters": io_parameters, - "model_parameters": model_parameters, - }, - }, - ], + # Create a simpler params_list 
structure with model_name and task_name + params_list = [ + { + "model_name": clustering_params["model_name"], + "task_name": "excute", + "params": { + "io_parameters": io_parameters, + "model_parameters": model_parameters, + }, } + ] - else: - job_params = { - "flow_type": "slurm", - "params_list": [ - { - "job_name": "latent_space_explorer", - "num_nodes": 1, - "partitions": PARTITIONS_CPU, - "reservations": RESERVATIONS_CPU, - "max_time": MAX_TIME_CPU, - "conda_env_name": "mlex_dimension_reduction_pca", - "submission_ssh_key": SUBMISSION_SSH_KEY, - "forward_ports": FORWARD_PORTS, - "params": { - "io_parameters": io_parameters, - "model_parameters": model_parameters, - }, - } - ], - } + # Keep the job params simplified + job_params = { + "params_list": params_list, + } return job_params @@ -279,7 +139,7 @@ def parse_model_params(model_parameters_html, log, percentiles, mask): # param["props"]["children"][0] is the label # param["props"]["children"][1] is the input parameter_container = param["props"]["children"][1] - # The achtual parameter item is the first and only child of the parameter container + # The actual parameter item is the first and only child of the parameter container parameter_item = parameter_container["props"]["children"]["props"] key = parameter_item["id"]["param_key"] if "value" in parameter_item: @@ -295,4 +155,4 @@ def parse_model_params(model_parameters_html, log, percentiles, mask): input_params["log"] = log input_params["percentiles"] = percentiles input_params["mask"] = mask if mask != "None" else None - return input_params, errors + return input_params, errors \ No newline at end of file diff --git a/src/utils/mlflow_utils.py b/src/utils/mlflow_utils.py deleted file mode 100644 index b3758bb..0000000 --- a/src/utils/mlflow_utils.py +++ /dev/null @@ -1,348 +0,0 @@ -import logging -import os -import shutil -import hashlib -import tempfile - -import mlflow -from mlex_utils.prefect_utils.core import get_flow_run_name -from mlflow.tracking import MlflowClient - -from src.utils.prefect import get_flow_run_parent_id - -MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI") -MLFLOW_TRACKING_USERNAME = os.getenv("MLFLOW_TRACKING_USERNAME", "") -MLFLOW_TRACKING_PASSWORD = os.getenv("MLFLOW_TRACKING_PASSWORD", "") -# Define a cache directory that will be mounted as a volume -MLFLOW_CACHE_DIR = os.getenv("MLFLOW_CACHE_DIR", os.path.join(tempfile.gettempdir(), "mlflow_cache")) - -logger = logging.getLogger("lse.mlflow_utils") - - -class MLflowClient: - """A wrapper class for MLflow client operations.""" - - # In-memory model cache (for quick access) - _model_cache = {} - - def __init__( - self, - tracking_uri=None, - username=None, - password=None, - cache_dir=None - ): - """ - Initialize the MLflow client with connection parameters. 
- - Args: - tracking_uri: MLflow tracking server URI - username: MLflow authentication username - password: MLflow authentication password - cache_dir: Directory to store cached models - """ - self.tracking_uri = tracking_uri or os.getenv("MLFLOW_TRACKING_URI") - self.username = username or os.getenv("MLFLOW_TRACKING_USERNAME", "") - self.password = password or os.getenv("MLFLOW_TRACKING_PASSWORD", "") - self.cache_dir = cache_dir or MLFLOW_CACHE_DIR - - # Create cache directory if it doesn't exist - os.makedirs(self.cache_dir, exist_ok=True) - - # Set environment variables - os.environ['MLFLOW_TRACKING_USERNAME'] = self.username - os.environ['MLFLOW_TRACKING_PASSWORD'] = self.password - - # Set tracking URI - mlflow.set_tracking_uri(self.tracking_uri) - - # Create client - self.client = MlflowClient() - - def check_model_compatibility(self, autoencoder_model, dim_reduction_model): - """ - Check if autoencoder and dimension reduction models are compatible. - - Models are compatible if autoencoder latent_dim matches dimension reduction input_dim. - - Args: - autoencoder_model (str): Autoencoder model name (or "name:version" format) - dim_reduction_model (str): Dimension reduction model name (or "name:version" format) - - Returns: - bool: True if models are compatible, False otherwise - """ - if not autoencoder_model or not dim_reduction_model: - return False - - # Check dimension compatibility - try: - # get_mlflow_params now handles "name:version" format automatically - auto_params = self.get_mlflow_params(autoencoder_model) - dimred_params = self.get_mlflow_params(dim_reduction_model) - - auto_dim = int(auto_params.get("latent_dim", 0)) - dimred_dim = int(dimred_params.get("input_dim", 0)) - - return auto_dim > 0 and auto_dim == dimred_dim - except Exception as e: - logger.warning(f"Error checking dimensions: {e}") - return False - - def check_mlflow_ready(self): - """ - Check if MLflow server is reachable by performing a lightweight API call. - - Returns: - bool: True if MLflow server is reachable, False otherwise - """ - try: - # Perform a lightweight API call to verify connectivity - # search_experiments() is a simple call that requires minimal server resources - self.client.search_experiments(max_results=1) - logger.info("MLflow server is reachable") - return True - except Exception as e: - logger.warning(f"MLflow server is not reachable: {e}") - return False - - def get_mlflow_params(self, mlflow_model_id, version=None): - """ - Get MLflow model parameters for a specific version. - - Args: - mlflow_model_id: Model name or "name:version" format - version: Specific version (optional, can be parsed from mlflow_model_id) - - Returns: - dict: Model parameters - """ - # Parse version from identifier if present - if version is None: - if isinstance(mlflow_model_id, str) and ":" in mlflow_model_id: - mlflow_model_id, version = mlflow_model_id.split(":", 1) - else: - version = "1" # Default to version 1 for backward compatibility - - model_version_details = self.client.get_model_version( - name=mlflow_model_id, - version=str(version) - ) - run_id = model_version_details.run_id - - run_info = self.client.get_run(run_id) - params = run_info.data.params - return params - - def get_mlflow_models(self, livemode=False, model_type=None): - """ - Retrieve available MLflow models and create dropdown options. - - Args: - livemode (bool): If True, only include models where exp_type == "live_mode". - If False, exclude models where exp_type == "live_mode" and use custom labels. 
- model_type (str, optional): Filter by run tag 'model_type'. - - Returns: - list: Dropdown options for MLflow models matching the tag filters. - """ - try: - all_versions = self.client.search_model_versions() - - model_map = {} # model name -> latest version info - - for v in all_versions: - try: - current = model_map.get(v.name) - if current and int(v.version) <= int(current.version): - continue - - run = self.client.get_run(v.run_id) - run_tags = run.data.tags - - # Tag-based filtering - exp_type = run_tags.get("exp_type") - if livemode: - if exp_type != "live_mode": - continue - else: - if exp_type == "live_mode": - continue - - if model_type is not None and run_tags.get("model_type") != model_type: - continue - - model_map[v.name] = v - - except Exception as e: - logger.warning(f"Error processing model version {v.name} v{v.version}: {e}") - continue - - # Build dropdown options - model_options = [] - for name in sorted(model_map.keys()): - if livemode: - label = name - else: - try: - parent_id = get_flow_run_parent_id(name) - label = get_flow_run_name(parent_id) - except Exception as e: - logger.warning(f"Failed to get label for model '{name}': {e}") - label = name # fallback - - model_options.append({"label": label, "value": name}) - - return model_options - - except Exception as e: - logger.warning(f"Error retrieving MLflow models: {e}") - return [{"label": "Error loading models", "value": None}] - - def get_model_versions(self, model_name): - """ - Get all available versions for a specific model. - - Args: - model_name (str): Name of the model - - Returns: - list: List of version options sorted by version number (latest first) - """ - try: - versions = self.client.search_model_versions(f"name='{model_name}'") - - if not versions: - return [] - - # Sort versions by version number (descending - latest first) - sorted_versions = sorted( - versions, - key=lambda v: int(v.version), - reverse=True - ) - - # Create dropdown options - version_options = [ - {"label": f"Version {v.version}", "value": v.version} - for v in sorted_versions - ] - - return version_options - - except Exception as e: - logger.error(f"Error retrieving versions for model {model_name}: {e}") - return [] - - def _get_cache_path(self, model_name, version=None): - """Get the cache path for a model""" - # Create a unique filename based on model name and version - if version is None: - # Use a hash of the model name as part of the filename - hash_obj = hashlib.md5(model_name.encode()) - hash_str = hash_obj.hexdigest() - return os.path.join(self.cache_dir, f"{model_name}_{hash_str}") - else: - # Include version in the filename - return os.path.join(self.cache_dir, f"{model_name}_v{version}") - - def load_model(self, model_name, version=None): - """ - Load a model from MLflow by name with disk caching - - Args: - model_name: Name of the model in MLflow - version: Specific version to load (optional, defaults to latest) - - Returns: - The loaded model or None if loading fails - """ - if model_name is None: - logger.error("Cannot load model: model_name is None") - return None - - # Create a cache key that includes version if specified - cache_key = f"{model_name}:{version}" if version else model_name - - # Check in-memory cache first - if cache_key in self._model_cache: - logger.info(f"Using in-memory cached model: {cache_key}") - return self._model_cache[cache_key] - - try: - # Get the specific version or latest version - if version is None: - versions = self.client.search_model_versions(f"name='{model_name}'") - - if not 
versions: - logger.error(f"No versions found for model {model_name}") - return None - - version = max([int(mv.version) for mv in versions]) - - model_uri = f"models:/{model_name}/{version}" - - # Check disk cache - cache_path = self._get_cache_path(model_name, version) - if os.path.exists(cache_path): - logger.info(f"Loading model from disk cache: {cache_path}") - try: - model = mlflow.pyfunc.load_model(cache_path) - self._model_cache[cache_key] = model - logger.info(f"Successfully loaded cached model: {cache_key}") - return model - except Exception as e: - logger.warning(f"Error loading model from cache: {e}") - - # Create cache directory if it doesn't exist - os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True) - - logger.info(f"Downloading model {model_name}, version {version} from MLflow to cache") - - try: - # Download the model directly to the cache location - download_path = mlflow.artifacts.download_artifacts( - artifact_uri=f"models:/{model_name}/{version}", - dst_path=cache_path - ) - logger.info(f"Downloaded model artifacts to: {download_path}") - - # Load the model from the cached location - model = mlflow.pyfunc.load_model(download_path) - logger.info(f"Successfully loaded model from cache: {cache_key}") - - # Store in memory cache - self._model_cache[cache_key] = model - - return model - except Exception as e: - logger.warning(f"Error downloading artifacts: {e}") - - # Fallback: Load the model directly from MLflow - logger.info(f"Falling back to direct model loading from MLflow") - model = mlflow.pyfunc.load_model(model_uri) - logger.info(f"Successfully loaded model: {cache_key}") - - # Store in memory cache - self._model_cache[cache_key] = model - - return model - except Exception as e: - logger.error(f"Error loading model {cache_key}: {e}") - return None - - @classmethod - def clear_memory_cache(cls): - """Clear the in-memory model cache""" - logger.info("Clearing in-memory model cache") - cls._model_cache.clear() - - def clear_disk_cache(self): - """Clear the disk cache""" - logger.info(f"Clearing disk cache at {self.cache_dir}") - try: - # Delete and recreate the cache directory - shutil.rmtree(self.cache_dir) - os.makedirs(self.cache_dir, exist_ok=True) - except Exception as e: - logger.error(f"Error clearing disk cache: {e}") \ No newline at end of file diff --git a/src/utils/prefect.py b/src/utils/prefect.py deleted file mode 100644 index cbe5588..0000000 --- a/src/utils/prefect.py +++ /dev/null @@ -1,43 +0,0 @@ -import asyncio - -from prefect import get_client -from prefect.client.schemas.objects import DeploymentStatus - -# TODO: Move this to mlex_utils - - -async def _check_prefect_ready(): - async with get_client() as client: - healthcheck_result = await client.api_healthcheck() - if healthcheck_result is not None: - raise Exception("Prefect API is not healthy.") - - -def check_prefect_ready(): - return asyncio.run(_check_prefect_ready()) - - -async def _check_prefect_worker_ready(deployment_name: str): - async with get_client() as client: - deployment = await client.read_deployment_by_name(deployment_name) - assert ( - deployment - ), f"No deployment found in config for deployment_name {deployment_name}" - if deployment.status != DeploymentStatus.READY: - raise Exception("Deployment used for training and inference is not ready.") - - -def check_prefect_worker_ready(deployment_name: str): - return asyncio.run(_check_prefect_worker_ready(deployment_name)) - - -async def _get_flow_run_parent_id(flow_run_id): - async with 
get_client() as client: - child_flow_run = await client.read_flow_run(flow_run_id) - parent_task_run_id = child_flow_run.parent_task_run_id - parent_task_run = await client.read_task_run(parent_task_run_id) - return parent_task_run.flow_run_id - - -def get_flow_run_parent_id(flow_run_id): - return asyncio.run(_get_flow_run_parent_id(flow_run_id))
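
For reference, a minimal usage sketch of the replacement client. The import path, constructor keywords, and method names below are taken directly from the call sites and tests touched in this diff; the exact behavior (for example, how load_model resolves versions) is assumed to match the deleted src/utils/mlflow_utils.MLflowClient and should be confirmed against the mlex_utils fork pinned in pyproject.toml. Model names and credentials are hypothetical.

    from mlex_utils.mlflow_utils.mlflow_model_client import MLflowModelClient

    # Explicit configuration, as in live_operator_example/lse_operator.py;
    # calling MLflowModelClient() with no arguments (as in reducer.py and
    # execute.py) is assumed to fall back to the MLFLOW_* environment variables.
    client = MLflowModelClient(
        tracking_uri="http://mlflow:5000",
        username="mlflow-user",          # hypothetical credentials for illustration
        password="mlflow-password",
        cache_dir="/tmp/mlflow_cache",   # optional disk cache location
    )

    if client.check_mlflow_ready():
        # Dropdown options filtered by run tag, as exercised in test_mlflow_client.py
        options = client.get_mlflow_models(model_type="autoencoder")
        # Run parameters for a registered model (hypothetical name)
        params = client.get_mlflow_params("my-autoencoder")
        # version=None is assumed to load the latest registered version
        model = client.load_model("my-autoencoder", version=None)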