
Commit

add training with remove_recommended_ids(obs_mask & obs_next_mask)
yuyq18 committed Nov 16, 2023
1 parent 47d1c4d commit 3dc75df
Showing 24 changed files with 434 additions and 53 deletions.
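This commit threads a new --remove_recommended_ids flag from the example scripts into the training Collector, presumably so that items already shown to the user in the current trajectory can be masked out of the observation (the obs_mask & obs_next_mask mentioned in the title). The Collector internals are not part of this diff, so the snippet below is only a minimal sketch of the masking idea; build_obs_mask and its arguments are illustrative names, not the repository's API.

import numpy as np

def build_obs_mask(num_items, recommended_ids):
    # Illustrative only: True marks items that may still be recommended at this
    # step; items already recommended in the trajectory are masked out.
    mask = np.ones(num_items, dtype=bool)
    mask[list(recommended_ids)] = False
    return mask

# Example: a catalogue of 10 items where items 2 and 7 were already recommended.
print(build_obs_mask(10, {2, 7}))
# [ True  True False  True  True  True  True False  True  True]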
1 change: 1 addition & 0 deletions examples/advance/run_A2C_IPS.py
@@ -91,6 +91,7 @@ def setup_policy_model(args, state_tracker, train_envs, test_envs_dict):
VectorReplayBuffer(args.buffer_size, len(train_envs)),
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
remove_recommended_ids=args.remove_recommended_ids,
)

test_collector_set = CollectorSet(rec_policy, test_envs_dict, args.buffer_size, args.test_num,
1 change: 1 addition & 0 deletions examples/advance/run_DORL.py
@@ -164,6 +164,7 @@ def setup_policy_model(args, state_tracker, train_envs, test_envs_dict):
VectorReplayBuffer(args.buffer_size, len(train_envs)),
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
remove_recommended_ids=args.remove_recommended_ids,
)

test_collector_set = CollectorSet(rec_policy, test_envs_dict, args.buffer_size, args.test_num,
1 change: 1 addition & 0 deletions examples/advance/run_Intrinsic.py
@@ -129,6 +129,7 @@ def setup_policy_model(args, state_tracker, train_envs, test_envs_dict):
VectorReplayBuffer(args.buffer_size, len(train_envs)),
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
remove_recommended_ids=args.remove_recommended_ids,
)

test_collector_set = CollectorSet(rec_policy, test_envs_dict, args.buffer_size, args.test_num,
1 change: 1 addition & 0 deletions examples/advance/run_MOPO.py
@@ -127,6 +127,7 @@ def setup_policy_model(args, state_tracker, train_envs, test_envs_dict):
VectorReplayBuffer(args.buffer_size, len(train_envs)),
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
remove_recommended_ids=args.remove_recommended_ids,
)

test_collector_set = CollectorSet(rec_policy, test_envs_dict, args.buffer_size, args.test_num,
3 changes: 3 additions & 0 deletions examples/policy/policy_utils.py
@@ -48,6 +48,9 @@ def get_args_all():
parser.add_argument("--user_model_name", type=str, default="DeepFM")
parser.add_argument('--seed', default=2022, type=int)
parser.add_argument('--cuda', default=0, type=int)

# training
parser.add_argument('--remove_recommended_ids', action="store_true", default=False)

parser.add_argument('--is_draw_bar', dest='draw_bar', action='store_true')
parser.add_argument('--no_draw_bar', dest='draw_bar', action='store_false')
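Because the new option is a plain argparse store_true flag shared by every example script through get_args_all(), enabling it only requires passing it on the command line, e.g. python examples/policy/run_A2C.py --remove_recommended_ids. A quick, self-contained check of the flag's behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--remove_recommended_ids', action="store_true", default=False)

print(parser.parse_args([]).remove_recommended_ids)                            # False (default)
print(parser.parse_args(["--remove_recommended_ids"]).remove_recommended_ids)  # True when the flag is given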
1 change: 1 addition & 0 deletions examples/policy/run_A2C.py
@@ -89,6 +89,7 @@ def setup_policy_model(args, state_tracker, train_envs, test_envs_dict):
VectorReplayBuffer(args.buffer_size, len(train_envs)),
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
remove_recommended_ids=args.remove_recommended_ids,
)

test_collector_set = CollectorSet(rec_policy, test_envs_dict, args.buffer_size, args.test_num,
1 change: 1 addition & 0 deletions examples/policy/run_C51.py
@@ -125,6 +125,7 @@ def setup_policy_model(args, state_tracker, train_envs, test_envs_dict):
buffer=buf,
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
remove_recommended_ids=args.remove_recommended_ids,
)
# train_collector.collect(n_step=args.batch_size * args.training_num) ## TODO

144 changes: 144 additions & 0 deletions examples/policy/run_ContinuousA2C.py
@@ -0,0 +1,144 @@
import argparse
import os
import sys
import traceback
from gymnasium.spaces import Box
from torch.distributions import Independent, Normal

import torch

sys.path.extend([".", "./src", "./src/DeepCTR-Torch", "./src/tianshou"])

from policy_utils import get_args_all, learn_policy, prepare_dir_log, prepare_user_model, prepare_train_envs, prepare_test_envs, setup_state_tracker

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from core.collector.collector_set import CollectorSet
from core.util.data import get_env_args
from core.collector.collector import Collector
from core.policy.RecPolicy import RecPolicy


from tianshou.data import VectorReplayBuffer

from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.policy import A2CPolicy

# from util.upload import my_upload
import logzero

try:
import envpool
except ImportError:
envpool = None


def get_args_A2C():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="ContinuousA2C")
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--max-grad-norm', type=float, default=None)
parser.add_argument('--gae-lambda', type=float, default=1.)
parser.add_argument('--rew-norm', action="store_true", default=False)

parser.add_argument("--read_message", type=str, default="UM")
parser.add_argument("--message", type=str, default="ContinuousA2C")

args = parser.parse_known_args()[0]
return args


def setup_policy_model(args, state_tracker, train_envs, test_envs_dict):
if args.cpu:
args.device = "cpu"
else:
args.device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available() else "cpu")

# model
net = Net(args.state_dim, hidden_sizes=args.hidden_sizes, device=args.device)

actor = ActorProb(net, state_tracker.emb_dim, unbounded=True,
device=args.device).to(args.device)
critic = Critic(
Net(args.state_dim, hidden_sizes=args.hidden_sizes, device=args.device),
device=args.device
).to(args.device)

optim_RL = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
optim_state = torch.optim.Adam(state_tracker.parameters(), lr=args.lr)
optim = [optim_RL, optim_state]

def dist(*logits):
return Independent(Normal(*logits), 1)

policy = A2CPolicy(
actor,
critic,
optim,
dist,
state_tracker=state_tracker,
discount_factor=args.gamma,
gae_lambda=args.gae_lambda,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
max_grad_norm=args.max_grad_norm,
reward_normalization=args.rew_norm,
action_space=Box(shape=(state_tracker.emb_dim,), low=0, high=1),
action_bound_method="", # not clip
action_scaling=False
)

rec_policy = RecPolicy(args, policy, state_tracker)

# Prepare the collectors and logs
train_collector = Collector(
rec_policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
remove_recommended_ids=args.remove_recommended_ids,
)

test_collector_set = CollectorSet(rec_policy, test_envs_dict, args.buffer_size, args.test_num,
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
force_length=args.force_length)

return rec_policy, train_collector, test_collector_set, optim





def main(args):
# %% 1. Prepare the saved path.
MODEL_SAVE_PATH, logger_path = prepare_dir_log(args)

# %% 2. Prepare user model and environment
ensemble_models = prepare_user_model(args)
env, dataset, train_envs = prepare_train_envs(args, ensemble_models)
test_envs_dict = prepare_test_envs(args)

# %% 3. Setup policy
state_tracker = setup_state_tracker(args, ensemble_models, env, train_envs, test_envs_dict)
policy, train_collector, test_collector_set, optim = setup_policy_model(args, state_tracker, train_envs, test_envs_dict)

# %% 4. Learn policy
learn_policy(args, env, dataset, policy, train_collector, test_collector_set, state_tracker, optim, MODEL_SAVE_PATH,
logger_path)


if __name__ == '__main__':
args_all = get_args_all()
args = get_env_args(args_all)
args_A2C = get_args_A2C()
args_all.__dict__.update(args.__dict__)
args_all.__dict__.update(args_A2C.__dict__)
try:
main(args_all)
except Exception as e:
var = traceback.format_exc()
print(var)
logzero.logger.error(var)
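The dist factory in setup_policy_model wraps the actor's (mu, sigma) output in Independent(Normal(...), 1), so the whole emb_dim-sized action embedding is treated as one event and receives a single log-probability per state. A self-contained sketch of that behaviour (the batch size and emb_dim below are illustrative):

import torch
from torch.distributions import Independent, Normal

mu = torch.zeros(4, 8)     # pretend: batch of 4 states, emb_dim = 8
sigma = torch.ones(4, 8)

dist = Independent(Normal(mu, sigma), 1)  # the last dim becomes the event dim
action = dist.sample()                    # shape (4, 8): one action embedding per state
log_prob = dist.log_prob(action)          # shape (4,): joint log-density over all 8 dims

print(action.shape, log_prob.shape)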
196 changes: 196 additions & 0 deletions examples/policy/run_ContinuousBCQ.py
@@ -0,0 +1,196 @@
import argparse
import functools
import os
import pprint
import sys
import traceback
from gymnasium.spaces import Box

import torch

sys.path.extend([".", "./src", "./src/DeepCTR-Torch", "./src/tianshou"])

from policy_utils import get_args_all, learn_policy, prepare_dir_log, prepare_user_model, prepare_buffer_via_offline_data, prepare_test_envs, setup_state_tracker

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from core.collector.collector_set import CollectorSet
from core.evaluation.evaluator import Evaluator_Feat, Evaluator_Coverage_Count, Evaluator_User_Experience, save_model_fn
from core.evaluation.loggers import LoggerEval_Policy
from core.util.data import get_env_args
from core.policy.RecPolicy import RecPolicy

from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
from tianshou.policy import BCQPolicy
from tianshou.trainer import offline_trainer

# from util.upload import my_upload
import logzero
from logzero import logger

try:
import envpool
except ImportError:
envpool = None


def get_args_BCQ():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="ContinuousBCQ")
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64])
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)

parser.add_argument("--n-step", type=int, default=3)
parser.add_argument('--step-per-epoch', type=int, default=1000)
# parser.add_argument("--update-per-epoch", type=int, default=5000)

parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[32, 32])
# default to 2 * action_dim
parser.add_argument('--latent_dim', type=int, default=None)
parser.add_argument("--gamma", default=0.99)
parser.add_argument("--tau", default=0.005)
# Weighting for Clipped Double Q-learning in BCQ
parser.add_argument("--lmbda", default=0.75)
# Max perturbation hyper-parameter for BCQ
parser.add_argument("--phi", default=0.05)

parser.add_argument("--read_message", type=str, default="UM")
parser.add_argument("--message", type=str, default="ContinuousBCQ")

args = parser.parse_known_args()[0]
return args


def setup_policy_model(args, state_tracker, buffer, test_envs_dict):
if args.cpu:
args.device = "cpu"
else:
args.device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available() else "cpu")

args.action_dim = state_tracker.emb_dim
args.max_action = 1.
print("args.action_dim", args.action_dim)
# model
# perturbation network
net_a = MLP(
input_dim=args.state_dim + args.action_dim,
output_dim=args.action_dim,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = Perturbation(
net_a, max_action=args.max_action, device=args.device, phi=args.phi
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

net_c1 = Net(
args.state_dim,
args.action_dim,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_dim,
args.action_dim,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

# vae
# output_dim = 0, so the last Module in the encoder is ReLU
vae_encoder = MLP(
input_dim=args.state_dim + args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
if not args.latent_dim:
args.latent_dim = args.action_dim * 2
vae_decoder = MLP(
input_dim=args.state_dim + args.latent_dim,
output_dim=args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
vae = VAE(
vae_encoder,
vae_decoder,
hidden_dim=args.vae_hidden_sizes[-1],
latent_dim=args.latent_dim,
max_action=args.max_action,
device=args.device,
).to(args.device)
vae_optim = torch.optim.Adam(vae.parameters())

optim_state = torch.optim.Adam(state_tracker.parameters(), lr=args.lr)
optim = [actor_optim, optim_state]

policy = BCQPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
vae,
vae_optim,
optim_state,
device=args.device,
gamma=args.gamma,
tau=args.tau,
lmbda=args.lmbda,
state_tracker=state_tracker,
buffer=buffer,
action_space=Box(shape=(state_tracker.emb_dim,), low=0, high=1),
)

rec_policy = RecPolicy(args, policy, state_tracker)

# collector
# buffer has been gathered

test_collector_set = CollectorSet(rec_policy, test_envs_dict, args.buffer_size, args.test_num,
# preprocess_fn=state_tracker.build_state,
exploration_noise=args.exploration_noise,
force_length=args.force_length)

return rec_policy, test_collector_set, optim



def main(args):
# %% 1. Prepare the saved path.
MODEL_SAVE_PATH, logger_path = prepare_dir_log(args)

# %% 2. Prepare user model and environment
ensemble_models = prepare_user_model(args)
env, dataset, buffer = prepare_buffer_via_offline_data(args)
test_envs_dict = prepare_test_envs(args)

# %% 3. Setup policy
state_tracker = setup_state_tracker(args, ensemble_models, env, buffer, test_envs_dict, use_buffer_in_train=True)
policy, test_collector_set, optim = setup_policy_model(args, state_tracker, buffer, test_envs_dict)

# %% 4. Learn policy
learn_policy(args, env, dataset, policy, buffer, test_collector_set, state_tracker, optim, MODEL_SAVE_PATH, logger_path, is_offline=True)


if __name__ == '__main__':
args_all = get_args_all()
args = get_env_args(args_all)
args_BCQ = get_args_BCQ()
args_all.__dict__.update(args.__dict__)
args_all.__dict__.update(args_BCQ.__dict__)
try:
main(args_all)
except Exception as e:
var = traceback.format_exc()
print(var)
logzero.logger.error(var)
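run_ContinuousBCQ.py trains offline: the VAE proposes actions close to the logged behaviour, the Perturbation network may only adjust them within phi * max_action, and critic1 chooses among the perturbed candidates. The schematic below follows the standard continuous BCQ selection rule with tianshou-style VAE.decode / Perturbation / Critic call shapes; it is an illustration of the algorithm, not the library's actual implementation:

import torch

def bcq_select_action(state, vae, perturbation, critic1, num_candidates=10):
    # Schematic continuous-BCQ action selection (illustrative; tianshou's
    # BCQPolicy performs the equivalent steps internally).
    states = state.unsqueeze(0).expand(num_candidates, -1)  # (N, state_dim)
    candidates = vae.decode(states)                         # (N, action_dim), near the behaviour policy
    perturbed = perturbation(states, candidates)            # bounded adjustment of at most phi * max_action
    q_values = critic1(states, perturbed).squeeze(-1)       # (N,) Q-estimates for each candidate
    return perturbed[q_values.argmax()]                     # keep the highest-value candidate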
(The remaining changed files in this commit are not shown here.)
