Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use user HOME dir for cache, config, models #253

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions rfdiffusion/inference/model_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from rfdiffusion.model_input_logger import pickle_function_call
import sys

SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__))
from rfdiffusion.util import USER_DIR


TOR_INDICES = util.torsion_indices
TOR_CAN_FLIP = util.torsion_can_flip
Expand Down Expand Up @@ -63,7 +64,8 @@ def initialize(self, conf: DictConfig) -> None:
if conf.inference.model_directory_path is not None:
model_directory = conf.inference.model_directory_path
else:
model_directory = f"{SCRIPT_DIR}/../../models"
# set default model weigths directory under env var control. fallback to user $HOME/rfdiffusion/models
model_directory = os.environ.get('RFD_MODELS',f"{USER_DIR}/models")

print(f"Reading models from {model_directory}")

Expand Down Expand Up @@ -122,7 +124,7 @@ def initialize(self, conf: DictConfig) -> None:
if conf.inference.schedule_directory_path is not None:
schedule_directory = conf.inference.schedule_directory_path
else:
schedule_directory = f"{SCRIPT_DIR}/../../schedules"
schedule_directory = f"{USER_DIR}/schedules"

# Check for cache schedule
if not os.path.exists(schedule_directory):
Expand Down
11 changes: 11 additions & 0 deletions rfdiffusion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@
from rfdiffusion.chemical import *
from rfdiffusion.scoring import *

import sys
from pathlib import Path
# define RFdiffusion directory located in user HOME directory.
USER_HOME=str(Path.home())
USER_DIR=f"{USER_HOME}/rfdiffusion"

try:
Path(USER_DIR).mkdir(parents=True, exist_ok=True)
except FileExistsError as msg:
print(f'{USER_DIR} already exist and is a file.')
sys.exit(1)

def generate_Cbeta(N, Ca, C):
# recreate Cb given N,Ca,C
Expand Down
7 changes: 5 additions & 2 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,24 @@
from omegaconf import OmegaConf
import hydra
import logging
from rfdiffusion.util import writepdb_multi, writepdb
from rfdiffusion.util import writepdb_multi, writepdb, USER_DIR
from rfdiffusion.inference import utils as iu
from hydra.core.hydra_config import HydraConfig
import numpy as np
import random
import glob

# set hyfra inference config dir under environment variable control
hydra_cfg_dir=os.environ.get('RFD_HYDRA_CFG', f"{USER_DIR}/config/inference")


def make_deterministic(seed=0):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


@hydra.main(version_base=None, config_path="../config/inference", config_name="base")
@hydra.main(version_base=None, config_path=hydra_cfg_dir, config_name="base")
def main(conf: HydraConfig) -> None:
log = logging.getLogger(__name__)
if conf.inference.deterministic:
Expand Down