Skip to content

Commit

Permalink
Minor fixes to experiment runner script
Browse files Browse the repository at this point in the history
  • Loading branch information
djfoote committed Oct 22, 2024
1 parent c19688a commit 27fce0a
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions train_reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@


GPU_NUMBER = None
N_ITER = 5
N_COMPARISONS = 1_000
N_ITER = 50
N_COMPARISONS = 20_000

EXP_NAME = "debug_po_refactor_small"

#######################################################################################################################
##################################################### Expt params #####################################################
Expand All @@ -44,7 +45,7 @@
config = {
"environment": {
"name": "StealingGridworld",
"grid_size": 3,
"grid_size": 5,
"horizon": 30,
"reward_for_depositing": 100,
"reward_for_picking_up": 1,
Expand All @@ -62,17 +63,18 @@
"transition_oversampling": 10,
"initial_epoch_multiplier": 1,
"feedback": {
"type": "preference",
"type": "scalar",
},
"trajectory_generator": {
"epsilon": 0.1,
},
"visibility": {
"visibility": "full",
"visibility": "partial",
# Available visibility mask keys:
# "full": All of the grid is visible. Not actually used, but should be set for easier comparison.
# "(n-1)x(n-1)": All but the outermost ring of the grid is visible.
"visibility_mask_key": "full",
# "cross": Only the row and column containing the home location are visible (the mask uses the middle row and column, index grid_size // 2).
"visibility_mask_key": "(n-1)x(n-1)",
},
"reward_trainer": {
"num_epochs": 3,
Expand All @@ -88,21 +90,21 @@
if config["visibility"]["visibility"] == "full" and config["visibility"]["visibility_mask_key"] != "full":
raise ValueError(
f'If visibility is "full", then visibility mask key must be "full".'
f'Instead, it is {wandb.config["visibility"]["visibility_mask_key"]}.'
f'Instead, it is {config["visibility"]["visibility_mask_key"]}.'
)

if config["visibility"]["visibility"] not in ["full", "partial"]:
raise ValueError(
f'Unknown visibility {wandb.config["visibility"]["visibility"]}.' f'Visibility must be "full" or "partial".'
f'Unknown visibility {config["visibility"]["visibility"]}.' f'Visibility must be "full" or "partial".'
)

if config["reward_model"]["type"] != "NonImageCnnRewardNet":
raise ValueError(f'Unknown reward model type {wandb.config["reward_model"]["type"]}.')
raise ValueError(f'Unknown reward model type {config["reward_model"]["type"]}.')

available_visibility_mask_keys = ["full", "(n-1)x(n-1)"]
available_visibility_mask_keys = ["full", "(n-1)x(n-1)", "cross"]
if config["visibility"]["visibility_mask_key"] not in available_visibility_mask_keys:
raise ValueError(
f'Unknown visibility mask key {wandb.config["visibility"]["visibility_mask_key"]}.'
f'Unknown visibility mask key {config["visibility"]["visibility_mask_key"]}.'
f"Available visibility mask keys are {available_visibility_mask_keys}."
)

Expand All @@ -113,7 +115,7 @@
run = wandb.init(
project="assisting-bounded-humans",
notes="finalizing logging pipeline for now",
name="setup_debug_4",
name=EXP_NAME,
tags=[
"debug",
],
Expand All @@ -127,6 +129,11 @@ def construct_visibility_mask(grid_size, visibility_mask_key):
visibility_mask = np.zeros((grid_size, grid_size), dtype=np.bool_)
visibility_mask[1:-1, 1:-1] = True
return visibility_mask
elif visibility_mask_key == "cross":
visibility_mask = np.zeros((grid_size, grid_size), dtype=np.bool_)
visibility_mask[grid_size // 2, :] = True
visibility_mask[:, grid_size // 2] = True
return visibility_mask
else:
raise ValueError(f"Unknown visibility mask key {visibility_mask_key}.")

Expand Down Expand Up @@ -190,6 +197,7 @@ def construct_visibility_mask(grid_size, visibility_mask_key):
loss=MSERewardLoss(),
rng=rng,
epochs=config["reward_trainer"]["num_epochs"],
lr=1e-2,
)
else:
feedback_model = PreferenceModel(model=reward_net)
Expand All @@ -198,6 +206,7 @@ def construct_visibility_mask(grid_size, visibility_mask_key):
loss=CrossEntropyRewardLoss(),
rng=rng,
epochs=config["reward_trainer"]["num_epochs"],
lr=1e-2,
)


Expand Down

0 comments on commit 27fce0a

Please sign in to comment.