Commit c7051d2 (1 parent: 8339e51). Showing 7 changed files with 134 additions and 187 deletions.
calvin_models/calvin_agent/evaluation/evaluate_policy_singlestep.py (256 changes: 87 additions & 169 deletions)
@@ -1,186 +1,104 @@
import argparse
from collections import Counter
import logging
from pathlib import Path
import typing

from calvin_agent.evaluation.utils import format_sftp_path, get_checkpoint, imshow_tensor, print_task_log
from calvin_agent.models.play_lmp import PlayLMP
import time

from calvin_agent.evaluation.multistep_sequences import get_sequences
from calvin_agent.evaluation.utils import (
    DefaultLangEmbeddings,
    get_default_model_and_env,
    get_eval_env_state,
    imshow_tensor,
)
from calvin_agent.utils.utils import get_last_checkpoint
import hydra
import numpy as np
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything
from termcolor import colored
import torch
from tqdm.auto import tqdm

from calvin_env.envs.play_table_env import get_env

logger = logging.getLogger(__name__)


@hydra.main(config_path="../../conf/inference", config_name="config_inference")
def test_policy(input_cfg: DictConfig) -> None:
    """
    Run inference on trained policy.
    Arguments:
        train_folder (str): path of trained model.
        load_checkpoint (str): optional model checkpoint. If not specified, the last checkpoint is taken by default.
        +datamodule.root_data_dir (str): /path/dataset when running inference on a different machine than the one used for training.
        visualize (bool): whether to visualize the policy rollouts (default True).
""" | ||
# when mounting remote folder with sftp, format path | ||
format_sftp_path(input_cfg) | ||
# load config used during training | ||
train_cfg_path = Path(input_cfg.train_folder) / ".hydra/config.yaml" | ||
train_cfg = OmegaConf.load(train_cfg_path) | ||
|
||
# merge configs to keep current cmd line overrides | ||
cfg = OmegaConf.merge(train_cfg, input_cfg) | ||
seed_everything(cfg.seed) | ||
|
||
device = torch.device("cuda:0") | ||
|
||
checkpoint = get_checkpoint(cfg) | ||
task_to_id_dict = torch.load(checkpoint)["task_to_id_dict"] | ||
id_to_task_dict = torch.load(checkpoint)["id_to_task_dict"] | ||
|
||
# since we don't use the trainer during inference, manually set up data_module | ||
data_module = hydra.utils.instantiate(cfg.datamodule, num_workers=4) | ||
data_module.prepare_data() | ||
data_module.setup() | ||
dataloader = data_module.val_dataloader() | ||
dataset = dataloader.dataset.datasets["vis"] | ||
lang_dataset = dataloader.dataset.datasets["lang"] | ||
env = hydra.utils.instantiate(cfg.callbacks.rollout.env_cfg, dataset, device, show_gui=False) | ||
|
||
embeddings = np.load( | ||
lang_dataset.abs_datasets_dir / lang_dataset.lang_folder / "embeddings.npy", allow_pickle=True | ||
).item() | ||
|
||
task_checker = hydra.utils.instantiate(cfg.callbacks.rollout.tasks) | ||
|
||
logger.info("Loading model from checkpoint.") | ||
model = PlayLMP.load_from_checkpoint(checkpoint) | ||
model.freeze() | ||
if train_cfg.model.decoder.get("load_action_bounds", False): | ||
model.action_decoder._setup_action_bounds(cfg.datamodule.root_data_dir, None, None, True) | ||
model = model.cuda(device) | ||
|
||
logger.info("Successfully loaded model.") | ||
demo_task_counter = Counter() # type: typing.Counter[str] | ||
live_task_counter = Counter() # type: typing.Counter[str] | ||
for task_name, ids in task_to_id_dict.items(): | ||
print() | ||
print(f"Evaluate {task_name}: {embeddings[task_name]['ann']}") | ||
print() | ||
def evaluate_policy(model, env, datamodule, lang_embeddings, args):
    conf_dir = Path(__file__).absolute().parents[2] / "conf"
    task_cfg = OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml")
    task_oracle = hydra.utils.instantiate(task_cfg)
    val_annotations = OmegaConf.load(conf_dir / "annotations/new_playtable_validation.yaml")

    task_to_id_dict = torch.load(args.checkpoint)["task_to_id_dict"]
    dataset = datamodule.val_dataloader().dataset.datasets["vis"]

    results = Counter()

    for task, ids in task_to_id_dict.items():
        for i in ids:
            episode = dataset[int(i)]
            rollout(
                model=model,
                episode=episode,
                env=env,
                tasks=task_checker,
                demo_task_counter=demo_task_counter,
                live_task_counter=live_task_counter,
                modalities=["lang"],
                cfg=cfg,
                device=device,
                id_to_task_dict=id_to_task_dict,
                embeddings=embeddings,
            )
        print_task_log(demo_task_counter, live_task_counter, "lang")

def rollout(
    model,
    episode,
    env,
    tasks,
    demo_task_counter,
    live_task_counter,
    modalities,
    cfg,
    device,
    id_to_task_dict=None,
    embeddings=None,
):
    """
    Args:
        model: PlayLMP model
        episode: Batch from dataloader
            state_obs: Tensor,
            rgb_obs: tuple(Tensor, ),
            depth_obs: tuple(Tensor, ),
            actions: Tensor,
            lang: Tensor,
            reset_info: Dict
            idx: int
        env: play_lmp_wrapper(play_table_env)
        tasks: Tasks
        demo_task_counter: Counter[str]
        live_task_counter: Counter[str]
        visualize: visualize images
    """
    state_obs, rgb_obs, depth_obs, actions, _, reset_info, idx = episode
    seq_len_max = state_obs.shape[0] - 1
    for mod in modalities:
        groundtruth_task = id_to_task_dict[int(idx)]
        # reset env to state of first step in the episode
        obs = env.reset(robot_obs=reset_info["robot_obs"][0], scene_obs=reset_info["scene_obs"][0])
        start_info = env.get_info()
        demo_task_counter += Counter(groundtruth_task)
        current_img_obs = obs["rgb_obs"]
        current_depth_obs = obs["depth_obs"]
        current_state_obs = obs["state_obs"]

        start_img_obs = [img.clone() for img in current_img_obs]

        # language goal is sampled from the ground-truth task annotations
        _task = np.random.choice(list(groundtruth_task))
        task_embeddings = embeddings[_task]["emb"]
        goal_lang = torch.from_numpy(embeddings[_task]["emb"]).to(device).squeeze(0)

        # goal image is last step of the episode
        goal_imgs = [rgb_ob[-1].unsqueeze(0).to(device) for rgb_ob in rgb_obs]
        goal_depths = [depth_ob[-1].unsqueeze(0).to(device) for depth_ob in depth_obs]
        goal_state = state_obs[-1].unsqueeze(0).to(device)

        for step in range(cfg.ep_len):
            # replan every replan_freq steps (default 30, i.e. every second)
            if step % cfg.replan_freq == 0:
                if mod == "lang":
                    plan, latent_goal = model.get_pp_plan_lang(
                        current_img_obs, current_depth_obs, current_state_obs, goal_lang
                    )  # type: ignore
                else:
                    plan, latent_goal = model.get_pp_plan_vision(
                        current_img_obs,
                        current_depth_obs,
                        goal_imgs,
                        goal_depths,
                        current_state_obs,
                        goal_state,
                    )  # type: ignore
            if cfg.visualize:
                imshow_tensor("start_img", start_img_obs[0], wait=1)
                imshow_tensor("goal_img", goal_imgs[0], wait=1)
                imshow_tensor("current_img", current_img_obs[0], wait=1)
                imshow_tensor("dataset_img", rgb_obs[0][np.clip(step, 0, seq_len_max)], wait=1)

            # use plan to predict actions with current observations
            action = model.predict_with_plan(current_img_obs, current_depth_obs, current_state_obs, latent_goal, plan)
            obs, _, _, current_info = env.step(action)
            # check if current step solves a task
            current_task_info = tasks.get_task_info_for_set(start_info, current_info, groundtruth_task)
            # check if a task was achieved and if that task is a subset of the original tasks
            # we do not just want to solve any task, we want to solve the task that was proposed
            if len(current_task_info) > 0:
                live_task_counter += Counter(current_task_info)
                # skip current sequence if task was achieved
                break
            # update current observation
            current_img_obs = obs["rgb_obs"]
            current_depth_obs = obs["depth_obs"]
            current_state_obs = obs["state_obs"]
            results[task] += rollout(env, model, episode, task_oracle, args, task, lang_embeddings, val_annotations)
        print(f"{task}: {results[task]} / {len(ids)}")

    print(f"SR: {sum(results.values()) / sum(len(x) for x in task_to_id_dict.values()) * 100:.1f}%")

def rollout(env, model, episode, task_oracle, args, task, lang_embeddings, val_annotations):
    state_obs, rgb_obs, depth_obs = episode["robot_obs"], episode["rgb_obs"], episode["depth_obs"]
    reset_info = episode["state_info"]
    idx = episode["idx"]
    obs = env.reset(robot_obs=reset_info["robot_obs"][0], scene_obs=reset_info["scene_obs"][0])
    # get lang annotation for subtask
    lang_annotation = val_annotations[task][0]
    # get language goal embedding
    goal = lang_embeddings.get_lang_goal(lang_annotation)
    model.reset()
    start_info = env.get_info()

    for step in range(args.ep_len):
        action = model.step(obs, goal)
        obs, _, _, current_info = env.step(action)
        if args.debug:
            env.render()
            # time.sleep(0.1)
        # check if current step solves a task
        current_task_info = task_oracle.get_task_info_for_set(start_info, current_info, {task})
        if len(current_task_info) > 0:
            if args.debug:
                print(colored("S", "green"), end=" ")
            return True
    if args.debug:
        print(colored("F", "red"), end=" ")
    return False

if __name__ == "__main__":
    test_policy()
    parser = argparse.ArgumentParser(description="Evaluate a trained model on multistep sequences with language goals.")
    parser.add_argument("--dataset_path", type=str, help="Path to the dataset root directory.")

    # arguments for loading default model
    parser.add_argument(
        "--train_folder", type=str, help="If calvin_agent was used to train, specify path to the log dir."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Manually specify checkpoint path (default is latest). Only used for calvin_agent.",
    )

    parser.add_argument("--debug", action="store_true", help="Print debug info and visualize environment.")

    args = parser.parse_args()

    # Do not change
    args.ep_len = 240
    model, env, datamodule = get_default_model_and_env(args.train_folder, args.dataset_path, args.checkpoint)

    if args.checkpoint is None:
        args.checkpoint = get_last_checkpoint(Path(args.train_folder))

    lang_embeddings = DefaultLangEmbeddings(args.dataset_path)  # type: ignore
    evaluate_policy(model, env, datamodule, lang_embeddings, args)
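
For reference, the new entry point can also be driven from Python rather than through the CLI. The snippet below is a minimal sketch under stated assumptions: the module is importable as calvin_agent.evaluation.evaluate_policy_singlestep, the dataset and training-log paths are placeholders, and all helper names are taken from the diff above, not verified beyond it.

from argparse import Namespace
from pathlib import Path

from calvin_agent.evaluation.evaluate_policy_singlestep import evaluate_policy
from calvin_agent.evaluation.utils import DefaultLangEmbeddings, get_default_model_and_env
from calvin_agent.utils.utils import get_last_checkpoint

# Placeholder paths: point these at a real CALVIN dataset and training log directory.
args = Namespace(
    dataset_path="/path/to/calvin/dataset",
    train_folder="/path/to/training/logs",
    checkpoint=None,  # None means "use the latest checkpoint", as in the script
    debug=False,
    ep_len=240,       # fixed rollout horizon, mirroring the "Do not change" line above
)

# Mirror the __main__ block: build model, env and datamodule, resolve the checkpoint,
# load the default language embeddings, then run the per-task evaluation loop.
model, env, datamodule = get_default_model_and_env(args.train_folder, args.dataset_path, args.checkpoint)
if args.checkpoint is None:
    args.checkpoint = get_last_checkpoint(Path(args.train_folder))
lang_embeddings = DefaultLangEmbeddings(args.dataset_path)
evaluate_policy(model, env, datamodule, lang_embeddings, args)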