Add better GPU support #6

Merged · 3 commits · Dec 16, 2024
5 changes: 3 additions & 2 deletions PredictionServer/README
@@ -1,7 +1,8 @@
 Solubility / Usability prediction server
 
-For installing required packages use:
-pip install -r requirements.txt
+For installing required packages use (for inference on CPU or GPU, respectively):
+pip install -r requirements_cpu.txt
+pip install -r requirements_gpu.txt
 
 All models go into models/
 Model name format: {PREDICTION_TYPE}_{MODEL_TYPE}_{Fold number 0-4}_quantized.onnx
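The naming convention documented in the README can be sketched as a small helper. Note this is an illustrative sketch, not code from the repo: `model_path` is a hypothetical function, and the example values `"Solubility"` and `"ESM1b"` are assumptions based on names appearing elsewhere in this PR.

```python
import os

def model_path(models_dir: str, prediction_type: str, model_type: str, fold: int) -> str:
    """Build a filename following the README's documented pattern:
    {PREDICTION_TYPE}_{MODEL_TYPE}_{Fold number 0-4}_quantized.onnx
    """
    if not 0 <= fold <= 4:
        raise ValueError("fold must be in the range 0-4")
    return os.path.join(models_dir, f"{prediction_type}_{model_type}_{fold}_quantized.onnx")

# Example values are illustrative assumptions:
print(model_path("models", "Solubility", "ESM1b", 2))
```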
5 changes: 3 additions & 2 deletions PredictionServer/predict.py
@@ -49,11 +49,12 @@ def run_model_distilled(embed_dataloader, args, prediction_type, test_df):
     opts.intra_op_num_threads = args.NUM_THREADS
     opts.inter_op_num_threads = args.NUM_THREADS
     opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
 
+    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if torch.cuda.is_available() else ["CPUExecutionProvider"]
+
-    # Adjust session options
     model_paths = [os.path.join(args.MODELS_PATH,
                                 f"{prediction_type}_ESM1b_distilled_quantized.onnx")]
-    ort_sessions = [onnxruntime.InferenceSession(mp, sess_options=opts) for mp in model_paths]
+    ort_sessions = [onnxruntime.InferenceSession(mp, sess_options=opts, providers=providers) for mp in model_paths]
 
     embed_dict = {}
     inputs_names = ort_sessions[0].get_inputs()
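The provider-selection logic this diff introduces can be sketched in isolation, without torch or onnxruntime installed. `select_providers` below is a hypothetical helper mirroring the `torch.cuda.is_available()` check above; ONNX Runtime tries providers in list order, so keeping `CPUExecutionProvider` last gives a safe fallback when the CUDA provider cannot initialize.

```python
def select_providers(cuda_available: bool) -> list:
    """Mirror the PR's logic: prefer CUDA, always keep CPU as a fallback."""
    if cuda_available:
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]

# In the PR this would be driven by: select_providers(torch.cuda.is_available())
print(select_providers(False))  # -> ['CPUExecutionProvider']
```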
@@ -1,6 +1,5 @@
torch>=1.6
onnxruntime>=1.7.0
onnxruntime-gpu
numpy
pandas
fair-esm
5 changes: 5 additions & 0 deletions PredictionServer/requirements_gpu.txt
@@ -0,0 +1,5 @@
+torch>=1.6
+onnxruntime-gpu>=1.7.0
+numpy
+pandas
+fair-esm