feat: add pre-commit
FrankTianTT committed Oct 7, 2022
1 parent f5cfb00 commit 97e510b
Showing 71 changed files with 1,969 additions and 1,277 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,3 +1,3 @@
/cmrl.egg-info/
/exp/
/stable-baselines3/
/stable-baselines3/
38 changes: 8 additions & 30 deletions .pre-commit-config.yaml
@@ -1,33 +1,11 @@
-# See https://pre-commit.com for more information
-# See https://pre-commit.com/hooks.html for more hooks
-# pre-commit==2.13.0
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.0.1
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.3.0
   hooks:
-  - id: trailing-whitespace
-  - id: end-of-file-fixer
-
-- repo: https://github.com/PyCQA/isort
-  rev: 5.8.0
+  - id: check-yaml
+  - id: end-of-file-fixer
+  - id: trailing-whitespace
+- repo: https://github.com/psf/black
+  rev: 22.10.0
   hooks:
-  - id: isort
-
-- repo: https://github.com/PyCQA/flake8
-  rev: 3.9.2
-  hooks:
-  - id: flake8
-
-- repo: https://github.com/psf/black
-  rev: 22.3.0
-  hooks:
-  - id: black
-
-
-exclude: |
-  (?x)^(
-  etc|
-  .*?/migrations|
-  bmiss/settings/instance.*|
-  .*?proto.*
-  )
+  - id: black
28 changes: 28 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,28 @@
# Contributing

In order to contribute to this repository, you will need developer access to it.
To learn more about the project, read the [README](README.md) first.

## Pre-commit hooks

Pre-commit hooks have been configured for this project using the
[pre-commit](https://pre-commit.com/) library:

- [black](https://github.com/psf/black) python formatter
- [flake8](https://flake8.pycqa.org/en/latest/) python linter

To get them going on your side, first install pre-commit:

```bash
pip install pre-commit
```

Then run the following commands from the root directory of this repository:

```bash
pre-commit install
pre-commit run --all-files
```

These pre-commit hooks are applied to all files except the tmp/ directory
(see .pre-commit-config.yaml).
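
A few other pre-commit commands are worth knowing once the hooks are installed. The sketch below is supplementary and not part of the committed CONTRIBUTING.md; it only uses standard pre-commit and git CLI options:

```bash
# Run a single hook (black, as pinned in .pre-commit-config.yaml) against the whole repo.
pre-commit run black --all-files

# Run every configured hook against specific files only.
pre-commit run --files cmrl/agent/core.py cmrl/algorithms/offline/mopo.py

# Bump each hook's rev in .pre-commit-config.yaml to the latest tagged release.
pre-commit autoupdate

# Skip the hooks for a single work-in-progress commit.
git commit --no-verify -m "wip"
```

If a hook such as black or trailing-whitespace rewrites a file, the commit is aborted; re-stage the modified files and commit again.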
2 changes: 1 addition & 1 deletion cmrl/agent/__init__.py
@@ -1 +1 @@
from .core import complete_agent_cfg, load_agent, Agent, RandomAgent
from .core import Agent, RandomAgent, complete_agent_cfg, load_agent
54 changes: 36 additions & 18 deletions cmrl/agent/core.py
@@ -1,12 +1,12 @@
import abc
import pathlib
from typing import Any, Union, Optional
from typing import Any, Optional, Union

import gym
import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np
import omegaconf
from omegaconf import DictConfig, OmegaConf

import cmrl.models
import cmrl.types
@@ -37,9 +37,7 @@ def act(self, obs: np.ndarray, **kwargs) -> np.ndarray:
return self.env.action_space.sample()


def complete_agent_cfg(
env: gym.Env, agent_cfg: omegaconf.DictConfig
):
def complete_agent_cfg(env: gym.Env, agent_cfg: omegaconf.DictConfig):
obs_shape = env.observation_space.shape
act_shape = env.action_space.shape

@@ -52,14 +50,16 @@ def _create_numpy_config(array):
return {
"_target_": "numpy.array",
"object": array.tolist(),
"dtype": str(array.dtype)
"dtype": str(array.dtype),
}

_check_and_replace("num_inputs", obs_shape[0], agent_cfg)
if "action_space" in agent_cfg.keys() and isinstance(
agent_cfg.action_space, omegaconf.DictConfig
agent_cfg.action_space, omegaconf.DictConfig
):
_check_and_replace("low", _create_numpy_config(env.action_space.low), agent_cfg.action_space)
_check_and_replace(
"low", _create_numpy_config(env.action_space.low), agent_cfg.action_space
)
_check_and_replace(
"high", _create_numpy_config(env.action_space.high), agent_cfg.action_space
)
@@ -80,25 +80,41 @@ def _create_numpy_config(array):
agent_cfg.action_ub = _create_numpy_config(env.action_space.high)

if "env" in agent_cfg.keys():
_check_and_replace("low", _create_numpy_config(env.action_space.low), agent_cfg.env.action_space)
_check_and_replace(
"high", _create_numpy_config(env.action_space.high), agent_cfg.env.action_space
"low",
_create_numpy_config(env.action_space.low),
agent_cfg.env.action_space,
)
_check_and_replace(
"high",
_create_numpy_config(env.action_space.high),
agent_cfg.env.action_space,
)
_check_and_replace("shape", env.action_space.shape, agent_cfg.env.action_space)

_check_and_replace("low", _create_numpy_config(env.observation_space.low), agent_cfg.env.observation_space)
_check_and_replace(
"high", _create_numpy_config(env.observation_space.high), agent_cfg.env.observation_space
"low",
_create_numpy_config(env.observation_space.low),
agent_cfg.env.observation_space,
)
_check_and_replace(
"high",
_create_numpy_config(env.observation_space.high),
agent_cfg.env.observation_space,
)
_check_and_replace(
"shape", env.observation_space.shape, agent_cfg.env.observation_space
)
_check_and_replace("shape", env.observation_space.shape, agent_cfg.env.observation_space)

return agent_cfg


def load_agent(agent_path: Union[str, pathlib.Path],
env: gym.Env,
type: Optional[str] = "best",
device: Optional[str] = None) -> Agent:
def load_agent(
agent_path: Union[str, pathlib.Path],
env: gym.Env,
type: Optional[str] = "best",
device: Optional[str] = None,
) -> Agent:
"""Loads an agent from a Hydra config file at the given path.
For agent of type "pytorch_sac.agent.sac.SACAgent", the directory
@@ -127,7 +143,9 @@ def load_agent(agent_path: Union[str, pathlib.Path],

complete_agent_cfg(env, cfg.algorithm.agent)
agent: pytorch_sac.SAC = hydra.utils.instantiate(cfg.algorithm.agent)
agent.load_checkpoint(ckpt_path=agent_path / "sac_{}.pth".format(type), device=device)
agent.load_checkpoint(
ckpt_path=agent_path / "sac_{}.pth".format(type), device=device
)
return SACAgent(agent)
else:
raise ValueError("Invalid agent configuration.")
2 changes: 1 addition & 1 deletion cmrl/agent/sac_wrapper.py
@@ -11,7 +11,7 @@ def __init__(self, sac_agent: pytorch_sac.SAC):
self.sac_agent = sac_agent

def act(
self, obs: np.ndarray, sample: bool = False, batched: bool = False, **kwargs
self, obs: np.ndarray, sample: bool = False, batched: bool = False, **kwargs
) -> np.ndarray:
"""Issues an action given an observation.
103 changes: 58 additions & 45 deletions cmrl/algorithms/offline/mopo.py
@@ -1,25 +1,30 @@
import os
from typing import Optional, Sequence, cast

import gym
import emei
import gym
import hydra.utils
import numpy as np
import omegaconf
import torch

import cmrl.constants
import cmrl.agent
import cmrl.constants
import cmrl.models
import cmrl.models.dynamics
import cmrl.third_party.pytorch_sac as pytorch_sac
import cmrl.types
import cmrl.util
import cmrl.util.creator as creator
from cmrl.agent.sac_wrapper import SACAgent
from cmrl.algorithms.util import (
evaluate,
maybe_load_trained_offline_model,
maybe_replace_sac_buffer,
rollout_model_and_populate_sac_buffer,
truncated_linear,
)
from cmrl.util.video import VideoRecorder
from cmrl.algorithms.util import evaluate, rollout_model_and_populate_sac_buffer, maybe_replace_sac_buffer, \
truncated_linear, maybe_load_trained_offline_model

MBPO_LOG_FORMAT = cmrl.constants.RESULTS_LOG_FORMAT + [
("epoch", "E", "int"),
@@ -32,13 +37,13 @@


def train(
env: emei.EmeiEnv,
test_env: emei.EmeiEnv,
termination_fn: Optional[cmrl.types.TermFnType],
reward_fn: Optional[cmrl.types.RewardFnType],
cfg: omegaconf.DictConfig,
silent: bool = False,
work_dir: Optional[str] = None,
env: emei.EmeiEnv,
test_env: emei.EmeiEnv,
termination_fn: Optional[cmrl.types.TermFnType],
reward_fn: Optional[cmrl.types.RewardFnType],
cfg: omegaconf.DictConfig,
silent: bool = False,
work_dir: Optional[str] = None,
) -> np.float32:
"""Train agent by MOPO algorithm.
@@ -75,18 +80,21 @@ def train(
)
logger.register_group(
"model_eval",
[("obs{}".format(o), "O{}".format(o), "float") for o in range(obs_shape[0])] + [
("reward", "R", "float")] + MODEL_EVAL_LOG_FORMAT,
[("obs{}".format(o), "O{}".format(o), "float") for o in range(obs_shape[0])]
+ [("reward", "R", "float")]
+ MODEL_EVAL_LOG_FORMAT,
color="green",
dump_frequency=1,
disable_console_dump=True
disable_console_dump=True,
)
save_video = cfg.get("save_video", False)
video_recorder = VideoRecorder(work_dir if save_video else None)
numpy_generator = np.random.default_rng(seed=cfg.seed)

# -------------- Create initial dataset --------------
dynamics = creator.create_dynamics(cfg.dynamics, obs_shape, act_shape, logger=logger)
dynamics = creator.create_dynamics(
cfg.dynamics, obs_shape, act_shape, logger=logger
)
replay_buffer = creator.create_replay_buffer(
cfg,
obs_shape,
@@ -97,33 +105,37 @@ def train(
if hasattr(env, "get_dataset"):
params, dataset_type = cfg.task.env.split("___")[-2:]
data_dict = env.get_dataset("{}-{}".format(params, dataset_type))
replay_buffer.add_batch(data_dict["observations"],
data_dict["actions"],
data_dict["next_observations"],
data_dict["rewards"],
data_dict["terminals"].astype(bool) | data_dict["timeouts"].astype(bool))
replay_buffer.add_batch(
data_dict["observations"],
data_dict["actions"],
data_dict["next_observations"],
data_dict["rewards"],
data_dict["terminals"].astype(bool) | data_dict["timeouts"].astype(bool),
)
else:
raise NotImplementedError

# ---------------------------------------------------------
# --------------------- Training Loop ---------------------
rollout_batch_size = (
cfg.task.effective_model_rollouts_per_step * cfg.algorithm.freq_train_model
)
trains_per_epoch = int(
np.ceil(cfg.task.epoch_length / cfg.task.freq_train_model)
cfg.task.effective_model_rollouts_per_step * cfg.algorithm.freq_train_model
)
trains_per_epoch = int(np.ceil(cfg.task.epoch_length / cfg.task.freq_train_model))
updates_made = 0
env_steps = 0
fake_env = cmrl.models.FakeEnv(env,
dynamics,
reward_fn,
termination_fn,
generator=numpy_generator,
penalty_coeff=cfg.algorithm.penalty_coeff)
model_env = gym.wrappers.TimeLimit(cmrl.models.GymBehaviouralFakeEnv(fake_env=fake_env, real_env=test_env),
max_episode_steps=test_env.spec.max_episode_steps,
new_step_api=True)
fake_env = cmrl.models.FakeEnv(
env,
dynamics,
reward_fn,
termination_fn,
generator=numpy_generator,
penalty_coeff=cfg.algorithm.penalty_coeff,
)
model_env = gym.wrappers.TimeLimit(
cmrl.models.GymBehaviouralFakeEnv(fake_env=fake_env, real_env=test_env),
max_episode_steps=test_env.spec.max_episode_steps,
new_step_api=True,
)

if hasattr(env, "causal_graph"):
oracle_causal_graph = env.causal_graph
@@ -133,21 +145,18 @@ def train(
if isinstance(dynamics, cmrl.models.dynamics.ConstraintBasedDynamics):
dynamics.set_oracle_mask("transition", oracle_causal_graph[:-1])

existed_trained_model = maybe_load_trained_offline_model(dynamics, cfg, obs_shape, act_shape,
work_dir=work_dir)
existed_trained_model = maybe_load_trained_offline_model(
dynamics, cfg, obs_shape, act_shape, work_dir=work_dir
)
if not existed_trained_model:
dynamics.learn(replay_buffer,
**cfg.dynamics,
work_dir=work_dir)
dynamics.learn(replay_buffer, **cfg.dynamics, work_dir=work_dir)

best_eval_reward = -np.inf
sac_buffer = None

for epoch in range(cfg.task.num_steps // cfg.task.epoch_length):
rollout_length = int(
truncated_linear(
*(cfg.task.rollout_schedule + [epoch + 1])
)
truncated_linear(*(cfg.task.rollout_schedule + [epoch + 1]))
)
sac_buffer_capacity = rollout_length * rollout_batch_size * trains_per_epoch
sac_buffer_capacity *= cfg.task.num_epochs_to_retain_sac_buffer
@@ -169,7 +178,7 @@ def train(
rollout_length,
rollout_batch_size,
logger,
epoch
epoch,
)

if debug_mode:
@@ -185,7 +194,7 @@ def train(
use_real_data = numpy_generator.random() < cfg.algorithm.real_data_ratio
which_buffer = replay_buffer if use_real_data else sac_buffer
if (env_steps + 1) % cfg.task.sac_updates_every_steps != 0 or len(
which_buffer
which_buffer
) < cfg.task.sac_batch_size:
break # only update every once in a while

@@ -224,11 +233,15 @@ def train(
"rollout_length": rollout_length,
},
)
agent.sac_agent.save_checkpoint(ckpt_path=os.path.join(work_dir, "sac_final.pth"), silence=True)
agent.sac_agent.save_checkpoint(
ckpt_path=os.path.join(work_dir, "sac_final.pth"), silence=True
)
if real_rewards.mean() > best_eval_reward:
video_recorder.save(f"{epoch}.mp4")
best_eval_reward = real_rewards.mean()
agent.sac_agent.save_checkpoint(ckpt_path=os.path.join(work_dir, "sac_best.pth"))
agent.sac_agent.save_checkpoint(
ckpt_path=os.path.join(work_dir, "sac_best.pth")
)

env_steps += 1
return np.float32(best_eval_reward)
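
Most of the Python churn in this commit is mechanical reformatting by black (the 22.10.0 release pinned in the new .pre-commit-config.yaml), plus import reordering. As a rough local check, assuming black is installed, the following reports whether a tree already matches black's formatting without modifying anything:

```bash
# --check sets a non-zero exit code if reformatting is needed; --diff prints the changes it would make.
black --check --diff cmrl/
```

Running `pre-commit run --all-files` from the repository root exercises the exact hook versions pinned in the config instead of whatever black happens to be on the PATH.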
(Diffs for the remaining 64 changed files are not shown.)
