Skip to content

Commit

Permalink
add data quantity, use unique exp name, add config load, add data filt…
Browse files Browse the repository at this point in the history
…er and saute
  • Loading branch information
ZhengYinan-AIR committed Dec 9, 2023
1 parent d6a0bf0 commit 6164990
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 32 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ results/

*.png

.vscode
.vscode

*.hdf5
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,5 @@ Please also cite the JAXRL repo as well if you use this repo
year = {2021}
}
```

python launcher/viz/viz_map.py --model_location 'results/PointRobot/ddpm_feasibility_hj_N16_minqc_2023-12-09_s54_486'
6 changes: 3 additions & 3 deletions configs/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

def get_config(config_string):
base_real_config = dict(
project='FISOR-mainresult',
seed=1,
project='FISOR',
seed=-1,
max_steps=1000001,
eval_episodes=20,
batch_size=2048, #Actor batch size x 2 (so really 1024), critic is fixed to 256
Expand All @@ -18,7 +18,7 @@ def get_config(config_string):

base_data_config = dict(
cost_scale=25,
pr_data='env/point_robot-expert-random-100k.h5py', # The location of point_robot data
pr_data='data/point_robot-expert-random-100k.hdf5', # The location of point_robot data
)

possible_structures = {
Expand Down
Binary file removed env/point_robot-expert-random-100k.h5py
Binary file not shown.
115 changes: 115 additions & 0 deletions filter_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import gym
import numpy as np
import dsrl
from collections import defaultdict
import h5py
from tqdm.auto import trange # noqa

# Use for Saute RL dataset
def state_augmentation(dataset_dict, cost_limit, output_path='OfflineCarPush1Gymnasium-v0-10.hdf5'):
    """Augment observations with a running safety budget (Saute RL).

    Every observation gains one extra dimension holding the remaining cost
    budget, normalised by ``cost_limit``. The budget starts at 1.0 at the
    beginning of each trajectory, is decreased by the per-step cost and then
    divided by 0.99 (presumably the discount factor -- TODO confirm). Once
    the budget drops to or below zero the reward for that step is overwritten
    with -10 (the Saute penalty).

    The augmented dataset is written to ``output_path`` as HDF5 and the
    (reward-modified) ``dataset_dict`` is returned.

    Args:
        dataset_dict: dict with keys 'observations', 'next_observations',
            'actions', 'rewards', 'costs', 'terminals', 'timeouts';
            values are arrays aligned on the first (transition) axis.
        cost_limit: normalisation constant for the per-step cost.
        output_path: where to write the augmented HDF5 file
            (defaults to the original hard-coded location).

    Returns:
        The input ``dataset_dict`` with rewards modified in place.
    """
    data_num = dataset_dict['observations'].shape[0]

    observations = []
    next_observations = []

    is_start = True
    for i in trange(data_num, desc='data_processing'):
        if is_start:
            # New trajectory: reset the safety budget to its full value.
            safe_state = 1.0
            is_start = False

        # State at step i carries the budget *before* this step's cost.
        observations.append(np.hstack([dataset_dict['observations'][i], safe_state]))
        safe_state -= dataset_dict['costs'][i] / cost_limit
        safe_state /= 0.99  # presumably discounting the budget (gamma=0.99) -- TODO confirm
        next_observations.append(np.hstack([dataset_dict['next_observations'][i], safe_state]))

        if safe_state <= 0:
            # Budget exhausted: apply the Saute penalty reward.
            dataset_dict['rewards'][i] = -10

        if dataset_dict['terminals'][i] or dataset_dict['timeouts'][i]:
            is_start = True

    observations = np.array(observations)
    next_observations = np.array(next_observations)

    print(observations.shape)
    print(next_observations.shape)

    keys = [
        'actions', 'rewards', 'costs', 'terminals',
        'timeouts'
    ]

    # Context manager guarantees the file is closed even if a write fails.
    with h5py.File(output_path, 'w') as outf:
        for k in keys:
            outf.create_dataset(k, data=dataset_dict[k], compression='gzip')
        outf.create_dataset('observations', data=observations, compression='gzip')
        outf.create_dataset('next_observations', data=next_observations, compression='gzip')

    return dataset_dict

# Use for data quantity exp
def filter_dataset(data_dict, ratio, output_path=None):
    """Subsample whole trajectories from an offline dataset.

    Splits ``data_dict`` into trajectories at terminal/timeout boundaries,
    randomly samples ``int(num_trajs * ratio)`` of them and writes the
    concatenated result to an HDF5 file.

    NOTE(review): ``np.random.randint`` samples trajectory indices WITH
    replacement, so the subset may contain duplicate trajectories -- confirm
    this is intended (otherwise use ``np.random.choice(..., replace=False)``).

    Args:
        data_dict: dict with keys 'observations', 'next_observations',
            'actions', 'rewards', 'costs', 'terminals', 'timeouts'.
        ratio: fraction of trajectories to keep (0 < ratio <= 1).
        output_path: destination HDF5 file; if None, the original
            hard-coded naming scheme is used.

    Returns:
        None. The filtered dataset is written to disk only.
    """
    # Indices of the last transition of each trajectory.
    done_idx = np.where(
        (data_dict["terminals"] == 1) | (data_dict["timeouts"] == 1)
    )[0]

    trajs = []
    for i in range(done_idx.shape[0]):
        start = 0 if i == 0 else done_idx[i - 1] + 1
        end = done_idx[i] + 1
        traj = {k: data_dict[k][start:end] for k in data_dict.keys()}
        trajs.append(traj)

    print(
        f"before filter: traj num = {len(trajs)}, transitions num = {data_dict['observations'].shape[0]}"
    )

    traj_idx = np.random.randint(0, len(trajs), size=int(len(trajs) * ratio))

    processed_data_dict = defaultdict(list)
    for k in data_dict.keys():
        for i in traj_idx:
            processed_data_dict[k].append(trajs[i][k])
    processed_data_dict = {
        k: np.concatenate(v)
        for k, v in processed_data_dict.items()
    }

    # Fixed: this reports the *post*-filter statistics (was "before filter").
    print(
        f"after filter: traj num = {traj_idx.shape[0]}, transitions num = {processed_data_dict['observations'].shape[0]}"
    )

    keys = [
        'observations', 'next_observations', 'actions', 'rewards', 'costs', 'terminals',
        'timeouts'
    ]

    if output_path is None:
        output_path = 'data/SafeMetaDrive-hardsparse-v0-85-' + str(traj_idx.shape[0]) + '-' + str(ratio) + '.hdf5'

    # Context manager guarantees the file is closed even if a write fails.
    with h5py.File(output_path, 'w') as outf:
        for k in keys:
            outf.create_dataset(k, data=processed_data_dict[k], compression='gzip')


# Script entry: load the MetaDrive offline dataset, cast all arrays to
# float32, then keep a 1% trajectory subsample on disk.
env = gym.make("OfflineMetadrive-hardsparse-v0")
dataset_dict = env.get_dataset()

# Cast every array in the dataset to float32, mutating the dict in place.
for key in dataset_dict:
    dataset_dict[key] = dataset_dict[key].astype(np.float32)

dataset_dict = filter_dataset(dataset_dict, 0.01)






2 changes: 1 addition & 1 deletion jaxrl5/agents/fisor/fisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def update_actor(agent, batch: DatasetDict) -> Tuple[Agent, Dict[str, float]]:
cost_exp_adv = jnp.exp((vc-qc) * agent.cost_temperature)
reward_exp_adv = jnp.exp((q - v) * agent.reward_temperature)

unsafe_weights = unsafe_condition * jnp.clip(cost_exp_adv, 1, agent.cost_ub) ## ignore vc >0, qc>vc
unsafe_weights = unsafe_condition * jnp.clip(cost_exp_adv, 0, agent.cost_ub) ## ignore vc >0, qc>vc
safe_weights = safe_condition * jnp.clip(reward_exp_adv, 0, 100)

weights = unsafe_weights + safe_weights
Expand Down
17 changes: 13 additions & 4 deletions jaxrl5/data/dsrl_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import gymnasium as gym
import dsrl
import numpy as np
Expand All @@ -6,12 +7,13 @@


class DSRLDataset(Dataset):
def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5, critic_type="qc", data_location=None, cost_scale=1.):
def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5, critic_type="qc", data_location=None, cost_scale=1., ratio = 1.0):

if data_location is not None:
# Point Robot
dataset_dict = {}
print('=========Data loading=========')
print('Load data from:', data_location)
print('Load point robot data from:', data_location)
f = h5py.File(data_location, 'r')
dataset_dict["observations"] = np.array(f['state'])
dataset_dict["actions"] = np.array(f['action'])
Expand All @@ -27,19 +29,26 @@ def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5, cr

else:
# DSRL
dataset_dict = env.get_dataset()
if ratio == 1.0:
dataset_dict = env.get_dataset()
else:
_, dataset_name = os.path.split(env.dataset_url)
file_list = dataset_name.split('-')
ratio_num = int(float(file_list[-1].split('.')[0]) * ratio)
dataset_ratio = '-'.join(file_list[:-1]) + '-' + str(ratio_num) + '-' + str(ratio) + '.hdf5'
dataset_dict = env.get_dataset(dataset_ratio)
print('max_episode_reward', env.max_episode_reward,
'min_episode_reward', env.min_episode_reward,
'mean_episode_reward', env._max_episode_steps * np.mean(dataset_dict['rewards']))
print('max_episode_cost', env.max_episode_cost,
'min_episode_cost', env.min_episode_cost,
'mean_episode_cost', env._max_episode_steps * np.mean(dataset_dict['costs']))
print('data_num', dataset_dict['actions'].shape[0])
dataset_dict['dones'] = np.logical_or(dataset_dict["terminals"],
dataset_dict["timeouts"]).astype(np.float32)
del dataset_dict["terminals"]
del dataset_dict['timeouts']


if critic_type == "hj":
dataset_dict['costs'] = np.where(dataset_dict['costs']>0, 1*cost_scale, -1)

Expand Down
29 changes: 20 additions & 9 deletions launcher/examples/train_offline.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import sys
sys.path.append('.')
import random
import numpy as np
from absl import app, flags
import datetime
import yaml
from ml_collections import config_flags
from ml_collections import config_flags, ConfigDict
import wandb
from tqdm.auto import trange # noqa
import gymnasium as gym
Expand All @@ -15,10 +16,13 @@
from jaxrl5.agents import FISOR
from jaxrl5.data.dsrl_datasets import DSRLDataset
from jaxrl5.evaluation import evaluate, evaluate_pr
import json


FLAGS = flags.FLAGS
flags.DEFINE_integer('env_id', 30, 'Choose env')
flags.DEFINE_float('ratio', 1.0, 'dataset ratio')
flags.DEFINE_string('project', '', 'project name for wandb')
flags.DEFINE_string('experiment_name', '', 'experiment name for wandb')
config_flags.DEFINE_config_file(
"config",
Expand All @@ -27,10 +31,15 @@
lock_config=False,
)

def to_dict(config):
if isinstance(config, ConfigDict):
return {k: to_dict(v) for k, v in config.items()}
return config


def call_main(details):
wandb.init(project=details['project'], name=details['experiment_name'], group=details['group'])
wandb.config.update(details)
details['agent_kwargs']['cost_scale'] = details['dataset_kwargs']['cost_scale']
wandb.init(project=details['project'], name=details['experiment_name'], group=details['group'], config=details['agent_kwargs'])

if details['env_name'] == 'PointRobot':
assert details['dataset_kwargs']['pr_data'] is not None, "No data for Point Robot"
Expand All @@ -39,7 +48,7 @@ def call_main(details):
ds = DSRLDataset(env, critic_type=details['agent_kwargs']['critic_type'], data_location=details['dataset_kwargs']['pr_data'])
else:
env = gym.make(details['env_name'])
ds = DSRLDataset(env, critic_type=details['agent_kwargs']['critic_type'], cost_scale=details['dataset_kwargs']['cost_scale'])
ds = DSRLDataset(env, critic_type=details['agent_kwargs']['critic_type'], cost_scale=details['dataset_kwargs']['cost_scale'], ratio=details['ratio'])
env_max_steps = env._max_episode_steps
env = wrap_gym(env, cost_limit=details['agent_kwargs']['cost_limit'])
ds.normalize_returns(env.max_episode_reward, env.min_episode_reward, env_max_steps)
Expand All @@ -49,6 +58,7 @@ def call_main(details):
config_dict['env_max_steps'] = env_max_steps

model_cls = config_dict.pop("model_cls")
config_dict.pop("cost_scale")
agent = globals()[model_cls].create(
details['seed'], env.observation_space, env.action_space, **config_dict
)
Expand Down Expand Up @@ -77,17 +87,18 @@ def call_main(details):

def main(_):
parameters = FLAGS.config
np.random.seed(parameters['seed'])

if FLAGS.project != '':
parameters['project'] = FLAGS.project
parameters['env_name'] = env_list[FLAGS.env_id]
parameters['ratio'] = FLAGS.ratio
parameters['group'] = parameters['env_name']

parameters['experiment_name'] = parameters['agent_kwargs']['sampling_method'] + '_' \
+ parameters['agent_kwargs']['actor_objective'] + '_' \
+ parameters['agent_kwargs']['critic_type'] + '_N' \
+ str(parameters['agent_kwargs']['N']) + '_' \
+ parameters['agent_kwargs']['extract_method'] if FLAGS.experiment_name == '' else FLAGS.experiment_name
parameters['experiment_name'] += '_' + str(datetime.date.today()) + '_' + str(parameters['seed'])
parameters['experiment_name'] += '_' + str(datetime.date.today()) + '_s' + str(parameters['seed']) + '_' + str(random.randint(0,1000))

if parameters['env_name'] == 'PointRobot':
parameters['max_steps'] = 100001
Expand All @@ -102,8 +113,8 @@ def main(_):

if not os.path.exists(f"./results/{parameters['group']}/{parameters['experiment_name']}"):
os.makedirs(f"./results/{parameters['group']}/{parameters['experiment_name']}")
with open(f"./results/{parameters['group']}/{parameters['experiment_name']}/config.yaml", "w") as f:
yaml.dump(dict(parameters), f, default_flow_style=False, allow_unicode=True)
with open(f"./results/{parameters['group']}/{parameters['experiment_name']}/config.json", "w") as f:
json.dump(to_dict(parameters), f, indent=4)

call_main(parameters)

Expand Down
27 changes: 13 additions & 14 deletions launcher/viz/viz_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
sys.path.append('.')
from absl import app, flags
import re
import json
import numpy as np
from ml_collections import config_flags
from ml_collections import config_flags, ConfigDict
import matplotlib.pyplot as plt
from matplotlib import colors
import jax
Expand All @@ -13,15 +14,13 @@


FLAGS = flags.FLAGS
flags.DEFINE_integer('env_id', 30, 'Choose env')
flags.DEFINE_integer('seed', -1, '')
flags.DEFINE_string('experiment_name', '', 'experiment name for wandb')
config_flags.DEFINE_config_file(
"config",
"configs/train_config.py:fisor",
"File path to the training hyperparameter configuration.",
lock_config=False,
)
flags.DEFINE_string('model_location', '', 'model location for point robot model')


def to_config_dict(d):
if isinstance(d, dict):
return ConfigDict({k: to_config_dict(v) for k, v in d.items()})
return d

hazard_position_list = [np.array([0.4, -1.2]), np.array([-0.4, 1.2])]

Expand Down Expand Up @@ -190,7 +189,8 @@ def plot_pic(env, agent, model_location):

def load_diffusion_model(model_location):

cfg = FLAGS.config
with open(os.path.join(model_location, 'config.json'), 'r') as file:
cfg = to_config_dict(json.load(file))

env = eval('PointRobot')(id=0, seed=0)

Expand Down Expand Up @@ -227,10 +227,9 @@ def get_model_file():

def main(_):

diffusion_model_location = 'results/PointRobot/ddpm_feasibility_hj_N16_minqc_2023-10-29_208' # expert_random
env, diffusion_agent = load_diffusion_model(diffusion_model_location)
env, diffusion_agent = load_diffusion_model(FLAGS.model_location)

plot_pic(env, diffusion_agent, diffusion_model_location)
plot_pic(env, diffusion_agent, FLAGS.model_location)


if __name__ == '__main__':
Expand Down
7 changes: 7 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=False
export CUDA_VISIBLE_DEVICES=0
# export JAX_PLATFORM_NAME=cpu

# python launcher/examples/train_offline.py --env_id 0 --config configs/train_config.py:fisor_cb1

# python launcher/examples/train_offline.py --env_id 0 --config configs/train_config.py:fisor_cb1
# python launcher/examples/train_offline.py --env_id 21 --config configs/train_config.py:fisor --ratio 0.1
# python launcher/examples/train_offline.py --env_id 21 --config configs/train_config.py:fisor --ratio 0.1
#

python launcher/examples/train_offline.py --env_id 29 --config configs/train_config.py:fisor

0 comments on commit 6164990

Please sign in to comment.