Skip to content

Commit

Permalink
Extract key checking procedure to _get_key
Browse files Browse the repository at this point in the history
  • Loading branch information
himkt committed Nov 24, 2021
1 parent 69cc663 commit 51212d9
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions optuna/integration/allennlp/_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class _VariableManager:
"study_name": "{}_STUDY_NAME",
"trial_id": "{}_TRIAL_ID",
}
NAME_OF_PATH = "optuna.integration.allennlp._variables._VariableManager.NAME_OF_KEY"

def __init__(self, target_pid: int) -> None:
self.target_pid = target_pid
Expand All @@ -44,20 +45,18 @@ def __init__(self, target_pid: int) -> None:
def prefix(self) -> str:
return "{}_OPTUNA_ALLENNLP".format(self.target_pid)

def _get_key(self, name: str) -> Optional[str]:
return self.NAME_OF_KEY.get(name)
def _get_key(self, name: str) -> str:
key = self.NAME_OF_KEY.get(name)
assert key is not None, f"{name} is not found in `{self.NAME_OF_PATH}`."
return key

def set_value(self, name: str, value: str) -> None:
"""Set values to environment variables.
`set_value` is only invoked in `optuna.integration.allennlp.AllenNLPExecutor`.
"""
key = self._get_key(name)
name_of_path = "optuna.integration.allennlp._variables._VariableManager.NAME_OF_KEY"
assert key is not None, f"{name} is not found in `{name_of_path}`."

key = key.format(self.target_pid)
key = self._get_key(name).format(self.target_pid)
os.environ[key] = value

def get_value(self, name: str) -> Optional[str]:
Expand All @@ -66,11 +65,7 @@ def get_value(self, name: str) -> Optional[str]:
`get_value` is only called in `optuna.integration.allennlp.AllenNLPPruningCallback`.
"""
key = self._get_key(name)
name_of_path = "optuna.integration.allennlp._variables._VariableManager.NAME_OF_KEY"
assert key is not None, f"{name} is not found in `{name_of_path}`."

key = key.format(self.target_pid)
key = self._get_key(name).format(self.target_pid)
value = os.environ.get(key)
assert value is not None, f"{key} is not found in environment variables."

Expand Down

0 comments on commit 51212d9

Please sign in to comment.