Skip to content

Commit

Permalink
fixed tune/protox/env/mqo/mqo_wrapper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wangpatrick57 committed Sep 2, 2024
1 parent 6228088 commit 96338bc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion tune/protox/env/mqo/mqo_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __init__(
self.logger = logger

def _update_best_observed(
self, query_metric_data: dict[str, BestQueryRun], force_overwrite=False
self, query_metric_data: dict[str, BestQueryRun], force_overwrite: bool=False
) -> None:
if query_metric_data is not None:
for qid, best_run in query_metric_data.items():
Expand All @@ -176,6 +176,7 @@ def _update_best_observed(
None,
)
if self.logger:
assert best_run.runtime is not None
self.logger.get_logger(__name__).debug(
f"[best_observe] {qid}: {best_run.runtime/1e6} (force: {force_overwrite})"
)
Expand Down Expand Up @@ -307,6 +308,7 @@ def transmute(
)

# Execute.
assert self.logger is not None
self.logger.get_logger(__name__).info("MQOWrapper called step_execute()")
success, info = self.unwrapped.step_execute(success, runs, info)
if info["query_metric_data"]:
Expand All @@ -319,6 +321,7 @@ def transmute(
with torch.no_grad():
# Pass the mutilated action back through.
assert isinstance(self.action_space, HolonSpace)
assert info["actions_info"] is not None
info["actions_info"][
"best_observed_holon_action"
] = best_observed_holon_action
Expand Down Expand Up @@ -412,6 +415,7 @@ def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, EnvInfoDict]: # type:

# Update the reward baseline.
if self.unwrapped.reward_utility:
assert self.unwrapped.baseline_metric
self.unwrapped.reward_utility.set_relative_baseline(
self.unwrapped.baseline_metric,
prev_result=metric,
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ class EnvInfoDict(TypedDict, total=False):
attempted_changes: Tuple[list[str], list[str]]

# Metric of this step.
metric: float
metric: Optional[float]
# Reward of this step.
reward: float
reward: Optional[float]
# Whether any queries timed out or the workload as a whole timed out.
did_anything_time_out: bool
# Query metric data.
Expand Down

0 comments on commit 96338bc

Please sign in to comment.