New inference-time approach for Private MedHelm Tasks #3913
base: main
Conversation
Review comment on line 1 of the new file (`# MedHELM RunSpecs for the private benchmarks from Stanford.`):

@yifanmai what are your thoughts on adding this file?
sronaghi left a comment:

I've made edits based on @MiguelAFH's comments.
yifanmai left a comment:
In general:

- The files need more documentation, which can be placed as a module-level docstring in `proxy_tuning_client.py`, in comments in `model_metadata.yaml` and `model_deployments.yaml`, and in a comment at the top of `run_entries_medhelm_private_proxy_tuning.conf`.
- If this is experimental code, rather than intended for general use, your documentation should clearly say so.
- Please run the linter:

  pip install black==24.3.0 mypy==1.16.0 flake8==5.0.4
  ./pre-commit.sh

I did not look at your model code too closely; let me know if there are any specific things you would like me to look at.
This addition allows the proxy tuning class to run for MedHELM scenarios. After creating the conda environment, you only need to run `pip install -U "crfm-helm[proxy_tuning]"`.
@yifanmai @MiguelAFH @aunell @suhana13 @HennyJie I ran the formatting check and added documentation. Please let me know what else to do for this PR!
This PR provides code for testing a new inference-time approach that combines general-domain and clinical-domain LMs for some private MedHELM tasks.
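For context, proxy tuning is typically formulated as steering a large base model's next-token logits by the difference between a small tuned expert and its untuned anti-expert. A minimal sketch of that logit-combination step, assuming this PR follows the standard formulation (the function and variable names here are hypothetical, not taken from the PR's code):

```python
import numpy as np

def proxy_tuned_logits(base, expert, anti_expert, alpha=1.0):
    """Combine next-token logits in the standard proxy-tuning way:
    shift the large base model's logits by the gap between a small
    tuned expert and its untuned anti-expert."""
    return np.asarray(base) + alpha * (np.asarray(expert) - np.asarray(anti_expert))

# Toy vocabulary of 4 tokens.
base = np.array([2.0, 1.0, 0.5, 0.0])     # large general-domain model
expert = np.array([1.0, 3.0, 0.0, 0.0])   # small clinical-tuned model
anti = np.array([1.0, 1.0, 0.0, 0.0])     # same small model, untuned

combined = proxy_tuned_logits(base, expert, anti)
print(combined)  # [2.  3.  0.5 0. ] — token 1 is boosted by the expert/anti-expert gap
```

At decode time, sampling proceeds from a softmax over `combined` instead of `base`, which is what makes this an inference-time (no-finetuning) method.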
I want to test my method on CLEAR, PatientInstruct, and NoteExtract.
Running the models requires downloading the following models locally and changing the model paths at the top of `proxy_tuning_client.py`. I can provide a script to download them onto Carina as well. Here are the models and their download locations:
Below are the model configurations and the number of A100 40GB GPUs each uses:
I have added each model configuration to the model_metadata.yaml, model_deployments.yaml, and tokenizer_config.yaml files in both prod_env and src/helm/config. run_entries_medhelm_private_proxy_tuning.conf contains the model run entries for each task. I can also create separate conf files based on the number of GPUs needed.
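For readers unfamiliar with HELM's config layout, a deployment entry might look roughly like the sketch below. This is a hypothetical illustration only: the deployment name, tokenizer, sequence length, and client class path are all assumptions, not copied from the PR.

```yaml
# Hypothetical model_deployments.yaml entry; field names follow HELM's
# existing deployment entries, but every value here is a placeholder.
model_deployments:
  - name: local/proxy-tuned-clinical-model
    model_name: local/proxy-tuned-clinical-model
    tokenizer_name: local/proxy-tuned-clinical-model
    max_sequence_length: 4096
    client_spec:
      class_name: "helm.clients.proxy_tuning_client.ProxyTuningClient"
```

A matching entry would then go in model_metadata.yaml and tokenizer_configs.yaml so that helm-run can resolve the deployment name used in the conf file.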
Each model takes me ~7-22 hours per task. I run the models with the `-n 1` flag because my code doesn't support multi-threading.
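Putting the conf file and the threading constraint together, an invocation might look like the following. The suite name is a placeholder, and this command is a sketch of typical helm-run usage rather than the exact command used in this PR:

```shell
# Hypothetical invocation; --suite value is a placeholder.
helm-run \
  --conf-paths run_entries_medhelm_private_proxy_tuning.conf \
  --suite medhelm-proxy-tuning \
  -n 1  # single-threaded: the proxy tuning client does not support multi-threading
```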
I ended up using basic_summarization_metrics because I couldn't configure what was needed in my helm_env while maintaining compatibility with my code. If there are conda environment issues, I can share my environment file and the modified run specs.