A federated learning system for embryo viability prediction. This system allows multiple IVF clinics to collaboratively train a machine learning model for embryo classification while keeping their patient data private and secure.
- Privacy-First Approach: All patient data remains at each clinic; only model updates are shared
- Collaborative Learning: Multiple clinics can contribute to a single powerful model
- Real-time Visualization: Monitor training progress through an interactive UI
- External Connectivity: Connect clinics across different networks using ngrok
- Real Embryo Data Support: Works with real human embryo images
The system consists of several key components:
- Core Utilities (embryo_fl_utils.py): Common functions and classes for the federated learning system
- Federated Learning Demo (federated_embryo_demo.py): The server and client implementation
- Integrated App (integrated_embryo_fl_app.py): A user-friendly Gradio interface
- Sample Data Creation (create_sample_data.py): Script to generate test data
- Real Data Preparation (prepare_real_embryo_data.py): Scripts to prepare real embryo images
Install the dependencies:

    pip install torch torchvision flwr gradio pyngrok numpy pillow matplotlib

Either use synthetic data:

    python create_sample_data.py

Or use real embryo data:

    python download_real_data.py        # Download real embryo data from Google Drive
    python prepare_real_embryo_data.py  # Process the data for binary classification

Then launch the app:

    python integrated_embryo_fl_app.py

This will open a Gradio interface in your browser (typically at http://127.0.0.1:7860).
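Before launching the app, it can help to confirm that the packages from the install command are importable. The following is a minimal sketch, not part of the project; the helper name `missing_modules` is hypothetical, and the module list simply mirrors the pip command above (the pillow distribution is imported as `PIL`).

```python
import importlib.util

# Import names for the packages in the pip command above.
# The "pillow" distribution is imported under the name "PIL".
REQUIRED_MODULES = [
    "torch", "torchvision", "flwr", "gradio",
    "pyngrok", "numpy", "PIL", "matplotlib",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

If anything is reported missing, rerunning the pip command in the same Python environment is usually the first thing to try.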
- Start Server:
  - Go to the "Server Control" tab
  - Set the server port (default: 8090)
  - Check "Use ngrok for external connections" to allow other clinics to connect over the internet
  - Click "Start Server"
- Share Connection Details:
  - Copy the connection command shown in the Server Output box
  - Share this command with other participating clinics
- Start Local Test Clinics (Optional):
  - Go to the "IVF Clinic Control" tab
  - Enter the server address (usually 127.0.0.1:8090)
  - Set a Clinic ID (start with 1)
  - Click "Start Local IVF Clinic"
  - Repeat with different Clinic IDs if desired
- Monitor Progress:
  - Go to the "Visualization" tab to see training progress
  - Check "System Status" periodically to see connected clinics and training status
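The app prints the connection command for you, so the following is only an illustration of what that command contains. It is a sketch under assumptions: that the tunnel is reported as a URL of the form `tcp://host:port`, and that the command follows the client command format shown in this README. Both function names are hypothetical and do not come from the project.

```python
from urllib.parse import urlsplit


def to_server_address(public_url: str) -> str:
    """Convert a tunnel URL such as 'tcp://host:port' into 'host:port'."""
    parts = urlsplit(public_url)
    if parts.hostname is None or parts.port is None:
        raise ValueError(f"expected a URL with host and port, got {public_url!r}")
    return f"{parts.hostname}:{parts.port}"


def client_command(server_address: str, client_id: int) -> str:
    """Build a client command in the format this README uses (assumed, not verified)."""
    return (
        "python federated_embryo_demo.py client "
        f"--server_address={server_address} --client_id={client_id}"
    )


if __name__ == "__main__":
    # The URL below is a made-up example, not a real tunnel.
    address = to_server_address("tcp://0.tcp.ngrok.io:12345")
    print(client_command(address, client_id=2))
```

Each participating clinic would receive the same address but a different `client_id`.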
Participating clinics then follow these steps:
- Get the Required Files:
  - federated_embryo_demo.py
  - embryo_fl_utils.py
  - create_sample_data.py (to generate sample data)
- Prepare Local Data:
  - Generate sample data:

        python create_sample_data.py

  - Or use your own embryo images in the required format
- Connect to the Server:
  - Run the command provided by the host clinic:

        python federated_embryo_demo.py client --server_address=X.ngrok.io:YYYY --client_id=Z

  - Replace Z with a unique clinic ID (2, 3, 4, etc.)
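To make the shape of the client command concrete, here is a sketch of how a command line like the one above could be parsed with `argparse`. This is not the actual parser in federated_embryo_demo.py, which is not shown here; the function names and the validation rules are assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the 'client' command shown above."""
    parser = argparse.ArgumentParser(prog="federated_embryo_demo.py")
    sub = parser.add_subparsers(dest="mode", required=True)
    client = sub.add_parser("client", help="run as a participating clinic")
    client.add_argument("--server_address", required=True, help="host:port of the server")
    client.add_argument("--client_id", type=int, required=True, help="unique clinic ID")
    return parser


def parse_client_args(argv):
    """Return (host, port, client_id) or raise ValueError on a malformed address."""
    args = build_parser().parse_args(argv)
    # Split on the last colon so the host part is kept whole.
    host, sep, port = args.server_address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"server address must look like host:port, got {args.server_address!r}")
    return host, int(port), args.client_id


if __name__ == "__main__":
    example = ["client", "--server_address=127.0.0.1:8090", "--client_id=2"]
    print(parse_client_args(example))
```

Checking the address early gives a clearer error than a failed connection attempt, which is the main reason to validate it before starting the client.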
- Model Definition: A CNN model architecture for embryo classification is defined
- Server Initialization: The central server coordinates training but never sees patient data
- Client Training: Each clinic trains the model on its own local data
- Aggregation: The server combines all clinic model updates into a global model
- Distribution: The improved global model is sent back to all clinics
- Repeat: This process continues for multiple rounds, improving accuracy
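The aggregation step in the list above can be sketched as a weighted average of the clinics' parameters, in the style of federated averaging (FedAvg). This README does not state which aggregation rule the project uses, so treat the choice of FedAvg, the function name `federated_average`, and the flat-list representation of model weights as assumptions made for illustration.

```python
def federated_average(updates):
    """Weighted average of client parameter vectors (FedAvg-style sketch).

    `updates` is a list of (weights, num_examples) pairs, where `weights`
    is a flat list of floats with the same length for every client.
    Clients with more training examples contribute proportionally more.
    """
    if not updates:
        raise ValueError("need at least one client update")
    total = sum(num_examples for _, num_examples in updates)
    if total <= 0:
        raise ValueError("total number of examples must be positive")
    size = len(updates[0][0])
    averaged = [0.0] * size
    for weights, num_examples in updates:
        if len(weights) != size:
            raise ValueError("all clients must send vectors of the same length")
        for i, value in enumerate(weights):
            averaged[i] += value * num_examples / total
    return averaged


if __name__ == "__main__":
    # Two hypothetical clinics: one with 1 example, one with 3.
    clinic_updates = [([1.0, 2.0], 1), ([3.0, 4.0], 3)]
    print(federated_average(clinic_updates))
```

Only these parameter vectors and example counts reach the server in this sketch; the images themselves stay at each clinic, which is the privacy property the list describes.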
This project is designed to showcase federated learning in a healthcare setting, particularly for IVF clinics where privacy concerns are paramount. The system demonstrates how clinics can leverage collective data insights while respecting patient privacy.
This project is open source and available under the MIT License.