Skip to content

Commit

Permalink
Fixing CEM tests (facebookresearch#508)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#508

Reviewed By: czxttkl

Differential Revision: D29805519

fbshipit-source-id: dbcde11f8292eb167a0b7a66384e0d1d723b38e4
  • Loading branch information
kittipatv authored and facebook-github-bot committed Jul 21, 2021
1 parent 35da394 commit ba06d68
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 12 deletions.
56 changes: 56 additions & 0 deletions reagent/gym/datasets/replay_buffer_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,59 @@ def __iter__(self):
)

logger.info(f"Episode rewards during training: {rewards}")


class OfflineReplayBufferDataset(torch.utils.data.IterableDataset):
    """
    Iterable dataset that simply samples batches from a replay buffer.

    Each iteration pass yields ``num_batches`` batches, each obtained by
    calling ``replay_buffer.sample_transition_batch(batch_size=batch_size)``
    and, when a trainer preprocessor is configured, passing the sampled batch
    through it before yielding.
    """

    def __init__(
        self,
        env: EnvWrapper,
        replay_buffer: ReplayBuffer,
        batch_size: int,
        num_batches: int,
        trainer_preprocessor=None,
    ):
        """
        Args:
            env: environment wrapper; stored on the instance (not read by
                ``__iter__`` in this class).
            replay_buffer: buffer that batches are sampled from.
            batch_size: value forwarded as ``batch_size`` to
                ``sample_transition_batch``.
            num_batches: number of batches yielded per iteration pass.
            trainer_preprocessor: optional callable applied to every sampled
                batch; ``None`` means batches are yielded as sampled.
        """
        super().__init__()
        self._env = env
        self._replay_buffer = replay_buffer
        self._batch_size = batch_size
        self._num_batches = num_batches
        self._trainer_preprocessor = trainer_preprocessor

    # TODO: Just use kwargs here?
    @classmethod
    def create_for_trainer(
        cls,
        trainer,
        env: EnvWrapper,
        replay_buffer: ReplayBuffer,
        batch_size: int,
        num_batches: int,
        trainer_preprocessor=None,
        device=None,
    ):
        """
        Build a dataset whose batches are preprocessed for ``trainer``.

        Args:
            trainer: trainer handed to
                ``make_replay_buffer_trainer_preprocessor`` when no
                preprocessor is supplied.
            env: environment wrapper, forwarded to the preprocessor factory
                and to the dataset constructor.
            replay_buffer: buffer that batches are sampled from.
            batch_size: size of each sampled batch.
            num_batches: number of batches yielded per iteration pass.
            trainer_preprocessor: optional ready-made preprocessor; when
                ``None`` one is created from ``trainer``, ``device`` and
                ``env``.
            device: device passed to the preprocessor factory; defaults to
                the CPU device when ``None``.

        Returns:
            A new instance of ``cls``.
        """
        if device is None:
            device = torch.device("cpu")
        if trainer_preprocessor is None:
            trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
                trainer, device, env
            )

        return cls(
            env=env,
            replay_buffer=replay_buffer,
            batch_size=batch_size,
            num_batches=num_batches,
            trainer_preprocessor=trainer_preprocessor,
        )

    def __iter__(self):
        """Yield ``num_batches`` sampled (and optionally preprocessed) batches."""
        for _ in range(self._num_batches):
            train_batch = self._replay_buffer.sample_transition_batch(
                batch_size=self._batch_size
            )
            # Compare against None explicitly (the same sentinel
            # create_for_trainer uses) so a callable that happens to be
            # falsy is still applied instead of being silently skipped.
            if self._trainer_preprocessor is not None:
                train_batch = self._trainer_preprocessor(train_batch)
            yield train_batch
36 changes: 24 additions & 12 deletions reagent/gym/tests/test_gym_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import os
import pprint
import unittest
import uuid

import numpy as np
import pytest
import pytorch_lightning as pl
import torch
from parameterized import parameterized
from reagent.core.tensorboardX import summary_writer_context
from reagent.gym.agents.agent import Agent
from reagent.gym.datasets.replay_buffer_dataset import OfflineReplayBufferDataset
from reagent.gym.envs import Gym
from reagent.gym.policies.random_policies import make_random_policy_for_env
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
Expand Down Expand Up @@ -82,6 +85,11 @@ def evaluate_cem(env, manager, trainer_module, num_eval_episodes: int):
)


def identity_collate(batch):
    """
    Collate function that unwraps a single-element list.

    Asserts that ``batch`` is a list holding exactly one item and returns
    that item unchanged.
    """
    is_singleton_list = isinstance(batch, list) and len(batch) == 1
    assert is_singleton_list, f"Got {batch}"
    sole_item = batch[0]
    return sole_item


def run_test_offline(
env_name: str,
model: ModelManager__Union,
Expand Down Expand Up @@ -121,18 +129,22 @@ def run_test_offline(
)

device = torch.device("cuda") if use_gpu else None
# pyre-fixme[6]: Expected `device` for 2nd param but got `Optional[torch.device]`.
trainer_preprocessor = make_replay_buffer_trainer_preprocessor(trainer, device, env)

writer = SummaryWriter()
with summary_writer_context(writer):
for epoch in range(num_train_epochs):
logger.info(f"Evaluating before epoch {epoch}: ")
eval_rewards = evaluate_cem(env, manager, trainer, 1)
for _ in tqdm(range(num_batches_per_epoch)):
train_batch = replay_buffer.sample_transition_batch()
preprocessed_batch = trainer_preprocessor(train_batch)
trainer.train(preprocessed_batch)
dataset = OfflineReplayBufferDataset.create_for_trainer(
trainer,
env,
replay_buffer,
batch_size=minibatch_size,
num_batches=num_batches_per_epoch,
device=device,
)
data_loader = torch.utils.data.DataLoader(dataset, collate_fn=identity_collate)
pl_trainer = pl.Trainer(
max_epochs=num_train_epochs,
gpus=int(use_gpu),
deterministic=True,
default_root_dir=f"lightning_log_{str(uuid.uuid4())}",
)
pl_trainer.fit(trainer, data_loader)

logger.info(f"Evaluating after training for {num_train_epochs} epochs: ")
eval_rewards = evaluate_cem(env, manager, trainer, num_eval_episodes)
Expand Down

0 comments on commit ba06d68

Please sign in to comment.