From 00592b210314141ae001fb6a50d2ff756bade4d6 Mon Sep 17 00:00:00 2001 From: YawenDuan <35738259+kmdanielduan@users.noreply.github.com> Date: Wed, 10 Aug 2022 18:15:30 -0700 Subject: [PATCH 1/6] add video saving and uploading support to add train_* scripts --- .../algorithms/preference_comparisons.py | 8 +- src/imitation/scripts/common/common.py | 4 +- src/imitation/scripts/common/train.py | 4 + .../config/train_preference_comparisons.py | 3 +- src/imitation/scripts/train_adversarial.py | 32 ++++++-- src/imitation/scripts/train_imitation.py | 12 +++ .../scripts/train_preference_comparisons.py | 57 +++++++++----- src/imitation/scripts/train_rl.py | 23 ++++++ src/imitation/util/logger.py | 5 +- src/imitation/util/video_wrapper.py | 74 ++++++++++++++++++- tests/scripts/test_scripts.py | 27 ++++++- 11 files changed, 211 insertions(+), 38 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index ee94b77b4..a2c8aef74 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1158,6 +1158,7 @@ def __init__( # for keeping track of the global iteration, in case train() is called # multiple times self._iteration = 0 + self.trajectory_generator_num_steps = 0 self.model = reward_model @@ -1202,7 +1203,7 @@ def train( self, total_timesteps: int, total_comparisons: int, - callback: Optional[Callable[[int], None]] = None, + callback: Optional[Callable[[int, int], None]] = None, ) -> Mapping[str, Any]: """Train the reward model and the policy if applicable. @@ -1286,14 +1287,15 @@ def train( with self.logger.accumulate_means("agent"): self.logger.log(f"Training agent for {num_steps} timesteps") self.trajectory_generator.train(steps=num_steps) + self.trajectory_generator_num_steps += num_steps self.logger.dump(self._iteration) ######################## # Additional Callbacks # ######################## - if callback: - callback(self._iteration) self._iteration += 1 + if callback: + callback(self._iteration, self.trajectory_generator_num_steps) return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy} diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 3a20c59e0..f3cec501d 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -2,7 +2,7 @@ import logging import os -from typing import Any, Mapping, Sequence, Tuple, Union +from typing import Any, Mapping, Optional, Sequence, Tuple, Union import sacred from stable_baselines3.common import vec_env @@ -131,7 +131,7 @@ def make_venv( env_name: str, num_vec: int, parallel: bool, - log_dir: str, + log_dir: Optional[str], max_episode_steps: int, env_make_kwargs: Mapping[str, Any], **kwargs, diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index 88413fdf5..09ce07c34 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -23,6 +23,10 @@ def config(): # Evaluation n_episodes_eval = 50 # Num of episodes for final mean ground truth return + # Visualization + videos = False # save video files + video_kwargs = {} # arguments to VideoWrapper + locals() # quieten flake8 diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 469abd983..482444ba3 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ 
b/src/imitation/scripts/config/train_preference_comparisons.py @@ -28,7 +28,7 @@ def train_defaults(): fragment_length = 100 # timesteps per fragment used for comparisons total_timesteps = int(1e6) # total number of environment timesteps total_comparisons = 5000 # total number of comparisons to elicit - num_iterations = 5 # Arbitrary, should be tuned for the task + num_iterations = 50 # Arbitrary, should be tuned for the task comparison_queue_size = None # factor by which to oversample transitions before creating fragments transition_oversampling = 1 @@ -39,6 +39,7 @@ def train_defaults(): cross_entropy_loss_kwargs = {} reward_trainer_kwargs = { "epochs": 3, + "weight_decay": 0.0, } save_preferences = False # save preference dataset at the end? agent_path = None # path to a (partially) trained agent to load at the beginning diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index a646391f0..b74009cf2 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -9,6 +9,7 @@ import sacred.commands import torch as th from sacred.observers import FileStorageObserver +from stable_baselines3.common import vec_env from imitation.algorithms.adversarial import airl as airl_algo from imitation.algorithms.adversarial import common @@ -18,21 +19,33 @@ from imitation.scripts.common import common as common_config from imitation.scripts.common import demonstrations, reward, rl, train from imitation.scripts.config.train_adversarial import train_adversarial_ex +from imitation.util import video_wrapper logger = logging.getLogger("imitation.scripts.train_adversarial") -def save(trainer, save_path): +def save( + _config: Mapping[str, Any], + trainer: common.AdversarialTrainer, + save_path: str, + eval_venv: vec_env.VecEnv, +) -> None: """Save discriminator and generator.""" # We implement this here and not in Trainer since we do not want to actually # serialize the whole Trainer (including e.g. expert demonstrations). os.makedirs(save_path, exist_ok=True) th.save(trainer.reward_train, os.path.join(save_path, "reward_train.pt")) th.save(trainer.reward_test, os.path.join(save_path, "reward_test.pt")) - serialize.save_stable_model( - os.path.join(save_path, "gen_policy"), - trainer.gen_algo, - ) + policy_path = os.path.join(save_path, "gen_policy") + serialize.save_stable_model(policy_path, trainer.gen_algo) + if _config["train"]["videos"]: + video_wrapper.record_and_save_video( + output_dir=policy_path, + policy=trainer.gen_algo.policy, + eval_venv=eval_venv, + video_kwargs=_config["train"]["video_kwargs"], + logger=trainer.logger, + ) def _add_hook(ingredient: sacred.Ingredient) -> None: @@ -68,6 +81,7 @@ def dummy_config(): def train_adversarial( _run, _seed: int, + _config: Mapping[str, Any], show_config: bool, algo_cls: Type[common.AdversarialTrainer], algorithm_kwargs: Mapping[str, Any], @@ -85,6 +99,7 @@ def train_adversarial( Args: _seed: Random seed. + _config: Sacred configuration dict. show_config: Print the merged config before starting training. This is analogous to the print_config command, but will show config after rather than before merging `algorithm_specific` arguments. 
@@ -117,6 +132,7 @@ def train_adversarial( expert_trajs = demonstrations.load_expert_trajs() venv = common_config.make_venv() + eval_venv = common_config.make_venv(log_dir=None) reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( @@ -153,13 +169,15 @@ def train_adversarial( def callback(round_num): if checkpoint_interval > 0 and round_num % checkpoint_interval == 0: - save(trainer, os.path.join(log_dir, "checkpoints", f"{round_num:05d}")) + save_path = os.path.join(log_dir, "checkpoints", f"{round_num:05d}") + save(_config, trainer, save_path, eval_venv) trainer.train(total_timesteps, callback) # Save final artifacts. if checkpoint_interval >= 0: - save(trainer, os.path.join(log_dir, "checkpoints", "final")) + save_path = os.path.join(log_dir, "checkpoints", "final") + save(_config, trainer, save_path, eval_venv) return { "imit_stats": train.eval_policy(trainer.policy, trainer.venv_train), diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index ca444bbbf..46e120b5d 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -14,6 +14,7 @@ from imitation.policies import serialize from imitation.scripts.common import common, demonstrations, train from imitation.scripts.config.train_imitation import train_imitation_ex +from imitation.util import video_wrapper logger = logging.getLogger(__name__) @@ -98,6 +99,7 @@ def load_expert_policy( @train_imitation_ex.capture def train_imitation( _run, + _config: Mapping[str, Any], bc_kwargs: Mapping[str, Any], bc_train_kwargs: Mapping[str, Any], dagger: Mapping[str, Any], @@ -120,6 +122,7 @@ def train_imitation( """ custom_logger, log_dir = common.setup_logging() venv = common.make_venv() + eval_venv = common.make_venv(log_dir=None) imit_policy = make_policy(venv, agent_path=agent_path) expert_trajs = None @@ -163,6 +166,15 @@ def train_imitation( # TODO(adam): add checkpointing to BC? 
bc_trainer.save_policy(policy_path=osp.join(log_dir, "final.th")) + if _config["train"]["videos"]: + video_wrapper.record_and_save_video( + output_dir=log_dir, + policy=imit_policy, + eval_venv=eval_venv, + video_kwargs=_config["train"]["video_kwargs"], + logger=custom_logger, + ) + return { "imit_stats": train.eval_policy(imit_policy, venv), "expert_stats": rollout.rollout_stats( diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 9d2691113..38ad2e7c0 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -6,11 +6,12 @@ import functools import os +import os.path as osp from typing import Any, Mapping, Optional, Type, Union import torch as th from sacred.observers import FileStorageObserver -from stable_baselines3.common import type_aliases +from stable_baselines3.common import type_aliases, vec_env from imitation.algorithms import preference_comparisons from imitation.data import types @@ -21,33 +22,37 @@ from imitation.scripts.config.train_preference_comparisons import ( train_preference_comparisons_ex, ) - - -def save_model( - agent_trainer: preference_comparisons.AgentTrainer, - save_path: str, -): - """Save the model as model.pkl.""" - serialize.save_stable_model( - output_dir=os.path.join(save_path, "policy"), - model=agent_trainer.algorithm, - ) +from imitation.util import video_wrapper def save_checkpoint( + _config: Mapping[str, Any], trainer: preference_comparisons.PreferenceComparisons, save_path: str, allow_save_policy: Optional[bool], -): + eval_venv: vec_env.VecEnv, +) -> None: """Save reward model and optionally policy.""" os.makedirs(save_path, exist_ok=True) - th.save(trainer.model, os.path.join(save_path, "reward_net.pt")) + th.save(trainer.model, osp.join(save_path, "reward_net.pt")) if allow_save_policy: # Note: We should only save the model as model.pkl if `trajectory_generator` # contains one. Specifically we check if the `trajectory_generator` contains an # `algorithm` attribute. assert hasattr(trainer.trajectory_generator, "algorithm") - save_model(trainer.trajectory_generator, save_path) + policy_dir = osp.join(save_path, "policy") + serialize.save_stable_model( + output_dir=policy_dir, + model=trainer.trajectory_generator.algorithm, + ) + if _config["train"]["videos"]: + video_wrapper.record_and_save_video( + output_dir=policy_dir, + policy=trainer.trajectory_generator.algorithm.policy, + eval_venv=eval_venv, + video_kwargs=_config["train"]["video_kwargs"], + logger=trainer.logger, + ) else: trainer.logger.warn( "trainer.trajectory_generator doesn't contain a policy to save.", @@ -57,6 +62,7 @@ def save_checkpoint( @train_preference_comparisons_ex.main def train_preference_comparisons( _seed: int, + _config: Mapping[str, Any], total_timesteps: int, total_comparisons: int, num_iterations: int, @@ -82,6 +88,7 @@ def train_preference_comparisons( Args: _seed: Random seed. + _config: Sacred configuration dict. 
total_timesteps: number of environment interaction steps total_comparisons: number of preferences to gather in total num_iterations: number of times to train the agent against the reward model @@ -140,6 +147,7 @@ def train_preference_comparisons( """ custom_logger, log_dir = common.setup_logging() venv = common.make_venv() + eval_venv = common.make_venv(log_dir=None) reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( @@ -220,12 +228,19 @@ def train_preference_comparisons( query_schedule=query_schedule, ) - def save_callback(iteration_num): + def save_callback(iteration_num, traj_generator_num_steps): if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0: + save_path = osp.join( + log_dir, + "checkpoints", + f"iter_{iteration_num:04d}_step_{traj_generator_num_steps:08d}", + ) save_checkpoint( + _config, trainer=main_trainer, - save_path=os.path.join(log_dir, "checkpoints", f"{iteration_num:04d}"), + save_path=save_path, allow_save_policy=bool(trajectory_path is None), + eval_venv=eval_venv, ) results = main_trainer.train( @@ -235,14 +250,16 @@ def save_callback(iteration_num): ) if save_preferences: - main_trainer.dataset.save(os.path.join(log_dir, "preferences.pkl")) + main_trainer.dataset.save(osp.join(log_dir, "preferences.pkl")) # Save final artifacts. if checkpoint_interval >= 0: save_checkpoint( + _config, trainer=main_trainer, - save_path=os.path.join(log_dir, "checkpoints", "final"), + save_path=osp.join(log_dir, "checkpoints", "final"), allow_save_policy=bool(trajectory_path is None), + eval_venv=eval_venv, ) # Storing and evaluating policy only useful if we actually generate trajectory data @@ -255,7 +272,7 @@ def save_callback(iteration_num): def main_console(): observer = FileStorageObserver( - os.path.join("output", "sacred", "train_preference_comparisons"), + osp.join("output", "sacred", "train_preference_comparisons"), ) train_preference_comparisons_ex.observers.append(observer) train_preference_comparisons_ex.run_commandline() diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 389d7e1a0..63f77feda 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -24,10 +24,12 @@ from imitation.rewards.serialize import load_reward from imitation.scripts.common import common, rl, train from imitation.scripts.config.train_rl import train_rl_ex +from imitation.util import video_wrapper @train_rl_ex.main def train_rl( + _config: Mapping[str, Any], *, total_timesteps: int, normalize_reward: bool, @@ -96,6 +98,8 @@ def train_rl( venv = common.make_venv( post_wrappers=[lambda env, idx: wrappers.RolloutInfoWrapper(env)], ) + eval_venv = common.make_venv(log_dir=None) + callback_objs = [] if reward_type is not None: reward_fn = load_reward(reward_type, reward_path, venv, **load_reward_kwargs) @@ -122,6 +126,18 @@ def train_rl( save_policy_callback, ) callback_objs.append(save_policy_callback) + + if _config["train"]["videos"]: + save_video_callback = video_wrapper.SaveVideoCallback( + policy_dir, + eval_venv, + video_kwargs=_config["train"]["video_kwargs"], + ) + save_video_callback = callbacks.EveryNTimesteps( + policy_save_interval, + save_video_callback, + ) + callback_objs.append(save_video_callback) callback = callbacks.CallbackList(callback_objs) if agent_path is None: @@ -142,6 +158,13 @@ def train_rl( if policy_save_final: output_dir = os.path.join(policy_dir, "final") serialize.save_stable_model(output_dir, rl_algo) + video_wrapper.record_and_save_video( + 
output_dir=output_dir, + policy=rl_algo.policy, + eval_venv=eval_venv, + video_kwargs=_config["train"]["video_kwargs"], + logger=rl_algo.logger, + ) # Final evaluation of expert policy. return train.eval_policy(rl_algo, venv) diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py index 8875cb211..efa07740b 100644 --- a/src/imitation/util/logger.py +++ b/src/imitation/util/logger.py @@ -202,7 +202,10 @@ def write( if excluded is not None and "wandb" in excluded: continue - self.wandb_module.log({key: value}, step=step) + if key != "video": + self.wandb_module.log({key: value}, step=step) + else: + self.wandb_module.log({"video": self.wandb_module.Video(value)}) self.wandb_module.log({}, commit=True) def close(self) -> None: diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 5fc2ae4f5..9924b8d88 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -1,11 +1,15 @@ """Wrapper to record rendered video frames from an environment.""" import os +import os.path as osp +from typing import Any, Mapping, Optional import gym +import stable_baselines3.common.logger as sb_logger from gym.wrappers.monitoring import video_recorder +from stable_baselines3.common import callbacks, policies, vec_env -from imitation.data import types +from imitation.data import rollout, types class VideoWrapper(gym.Wrapper): @@ -34,7 +38,7 @@ def __init__( self.single_video = single_video self.directory = os.path.abspath(directory) - os.makedirs(self.directory) + os.makedirs(self.directory, exist_ok=True) def _reset_video_recorder(self) -> None: """Creates a video recorder if one does not already exist. @@ -75,3 +79,69 @@ def close(self) -> None: self.video_recorder.close() self.video_recorder = None super().close() + + +def record_and_save_video( + output_dir: str, + policy: policies.BasePolicy, + eval_venv: vec_env.VecEnv, + video_kwargs: Mapping[str, Any], + logger: Optional[sb_logger.Logger] = None, +) -> None: + video_dir = osp.join(output_dir, "videos") + video_venv = VideoWrapper( + eval_venv, + directory=video_dir, + **video_kwargs, + ) + sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=2) + # video.{:06}".format(VideoWrapper.episode_id) will be saved within + # rollout.generate_trajectories() + rollout.generate_trajectories(policy, video_venv, sample_until) + assert "video.000000.mp4" in os.listdir(video_dir) + video_path = osp.join(video_dir, "video.000000.mp4") + if logger: + logger.record("video", video_path) + logger.log(f"Recording and saving video to {video_path} ...") + + +class SaveVideoCallback(callbacks.EventCallback): + """Saves the policy using `save_n_record_video` each time it is called. + + Should be used in conjunction with `callbacks.EveryNTimesteps` + or another event-based trigger. + """ + + def __init__( + self, + policy_dir: str, + eval_venv: vec_env.VecEnv, + video_kwargs: Mapping[str, Any], + *args, + **kwargs, + ): + """Builds SavePolicyCallback. + + Args: + policy_dir: Directory to save checkpoints. + eval_venv: Environment to evaluate the policy on. + video_kwargs: Keyword arguments to pass to `VideoWrapper`. + *args: Passed through to `callbacks.EventCallback`. + **kwargs: Passed through to `callbacks.EventCallback`. 
+ """ + super().__init__(*args, **kwargs) + self.policy_dir = policy_dir + self.eval_venv = eval_venv + self.video_kwargs = video_kwargs + + def _on_step(self) -> bool: + output_dir = os.path.join(self.policy_dir, f"{self.num_timesteps:012d}") + record_and_save_video( + output_dir=output_dir, + policy=self.model.policy, + eval_venv=self.eval_venv, + video_kwargs=self.video_kwargs, + logger=self.model.logger, + ) + + return True diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index f7402d87e..4856ce47c 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -95,6 +95,8 @@ def test_main_console(script_mod): "common": dict(num_vec=8), } +_TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} + PREFERENCE_COMPARISON_CONFIGS = [ {}, { @@ -109,8 +111,9 @@ def test_main_console(script_mod): **_RL_AGENT_LOADING_CONFIGS, }, { + # Test that we can save checkpoints and videos "checkpoint_interval": 1, - # Test that we can save checkpoints + **_TRAIN_VIDEO_CONFIGS, }, ] @@ -356,7 +359,11 @@ def test_train_bc_warmstart(tmpdir): assert isinstance(run_warmstart.result, dict) -TRAIN_RL_PPO_CONFIGS = [{}, _RL_AGENT_LOADING_CONFIGS] +TRAIN_RL_PPO_CONFIGS = [ + {}, + _RL_AGENT_LOADING_CONFIGS, + _TRAIN_VIDEO_CONFIGS, +] @pytest.mark.parametrize("config", TRAIN_RL_PPO_CONFIGS) @@ -522,6 +529,22 @@ def test_train_adversarial_sac(tmpdir, command): _check_train_ex_result(run.result) +def test_train_adversarial_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_adversarial.""" + named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"] + config_updates = { + "common": dict(log_root=tmpdir), + "demonstrations": dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), + **_TRAIN_VIDEO_CONFIGS, + } + run = train_adversarial.train_adversarial_ex.run( + command_name="gail", + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + + def test_train_adversarial_algorithm_value_error(tmpdir): """Error on bad algorithm arguments.""" base_named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["adversarial"] From 0f8c7d72c74f45fafb12651b0bea580436985d87 Mon Sep 17 00:00:00 2001 From: kmdanielduan <35738259+kmdanielduan@users.noreply.github.com> Date: Mon, 15 Aug 2022 17:43:59 +0000 Subject: [PATCH 2/6] add tests for video saving to test_scripts.py --- tests/scripts/test_scripts.py | 103 +++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 25 deletions(-) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 4856ce47c..8e6fed308 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -8,6 +8,7 @@ import collections import filecmp import os +import os.path as osp import pathlib import pickle import shutil @@ -95,8 +96,6 @@ def test_main_console(script_mod): "common": dict(num_vec=8), } -_TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} - PREFERENCE_COMPARISON_CONFIGS = [ {}, { @@ -111,9 +110,8 @@ def test_main_console(script_mod): **_RL_AGENT_LOADING_CONFIGS, }, { - # Test that we can save checkpoints and videos + # Test that we can save checkpoints "checkpoint_interval": 1, - **_TRAIN_VIDEO_CONFIGS, }, ] @@ -359,11 +357,7 @@ def test_train_bc_warmstart(tmpdir): assert isinstance(run_warmstart.result, dict) -TRAIN_RL_PPO_CONFIGS = [ - {}, - _RL_AGENT_LOADING_CONFIGS, - _TRAIN_VIDEO_CONFIGS, -] +TRAIN_RL_PPO_CONFIGS = [{}, _RL_AGENT_LOADING_CONFIGS] @pytest.mark.parametrize("config", TRAIN_RL_PPO_CONFIGS) @@ -529,22 +523,6 @@ def 
test_train_adversarial_sac(tmpdir, command): _check_train_ex_result(run.result) -def test_train_adversarial_video_saving(tmpdir): - """Smoke test for imitation.scripts.train_adversarial.""" - named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"] - config_updates = { - "common": dict(log_root=tmpdir), - "demonstrations": dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), - **_TRAIN_VIDEO_CONFIGS, - } - run = train_adversarial.train_adversarial_ex.run( - command_name="gail", - named_configs=named_configs, - config_updates=config_updates, - ) - assert run.status == "COMPLETED" - - def test_train_adversarial_algorithm_value_error(tmpdir): """Error on bad algorithm arguments.""" base_named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["adversarial"] @@ -929,3 +907,78 @@ def test_convert_trajs(tmpdir: str): assert len(from_pkl) == len(from_npz) for t_pkl, t_npz in zip(from_pkl, from_npz): assert t_pkl == t_npz + + +_TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} +VIDEO_NAME = "video.000000.mp4" +VIDEO_PATH_DICT = dict( + rl=lambda d: osp.join(d, "policies", "final", "videos"), + adversarial=lambda d: osp.join(d, "checkpoints", "final", "gen_policy", "videos"), + pc=lambda d: osp.join(d, "checkpoints", "final", "policy", "videos"), + bc=lambda d: osp.join(d, "videos"), +) + + +def _check_video_exists(log_dir, algo): + video_dir = VIDEO_PATH_DICT[algo](log_dir) + assert os.path.exists(video_dir) + assert VIDEO_NAME in os.listdir(video_dir) + + +def test_train_rl_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_rl.""" + config_updates = dict( + common=dict(log_root=tmpdir), + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_rl.train_rl_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "rl") + + +def test_train_adversarial_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_adversarial.""" + named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"] + config_updates = { + "common": dict(log_root=tmpdir), + "demonstrations": dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), + **_TRAIN_VIDEO_CONFIGS, + } + run = train_adversarial.train_adversarial_ex.run( + command_name="gail", + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "adversarial") + + +def test_train_preference_comparisons_video_saving(tmpdir): + config_updates = dict( + common=dict(log_root=tmpdir), + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "pc") + + +def test_train_bc_video_saving(tmpdir): + config_updates = dict( + common=dict(log_root=tmpdir), + demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH), + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_imitation.train_imitation_ex.run( + command_name="bc", + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "bc") From 904430e5c8a1723b877256553be083dee7deaeac Mon Sep 17 00:00:00 2001 From: YawenDuan <35738259+kmdanielduan@users.noreply.github.com> Date: Mon, 22 Aug 2022 22:06:40 -0700 Subject: [PATCH 3/6] add 
train.save_video to simplify video saving activities --- src/imitation/scripts/common/train.py | 24 ++++++++++++- src/imitation/scripts/train_adversarial.py | 32 +++++++---------- src/imitation/scripts/train_imitation.py | 17 ++++----- .../scripts/train_preference_comparisons.py | 36 ++++++++----------- src/imitation/scripts/train_rl.py | 4 +-- src/imitation/util/video_wrapper.py | 33 +++++++++-------- tests/scripts/test_scripts.py | 2 +- 7 files changed, 77 insertions(+), 71 deletions(-) diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index 09ce07c34..3069bf2c3 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -1,14 +1,16 @@ """Common configuration elements for training imitation algorithms.""" import logging -from typing import Any, Mapping, Union +from typing import Any, Mapping, Optional, Union import sacred +import stable_baselines3.common.logger as sb_logger from stable_baselines3.common import base_class, policies, torch_layers, vec_env import imitation.util.networks from imitation.data import rollout from imitation.policies import base +from imitation.util import video_wrapper train_ingredient = sacred.Ingredient("train") logger = logging.getLogger(__name__) @@ -102,6 +104,26 @@ def eval_policy( return rollout.rollout_stats(trajs) +@train_ingredient.capture +def save_video( + videos: bool, + video_kwargs: Mapping[str, Any], + output_dir: str, + policy: policies.BasePolicy, + eval_venv: vec_env.VecEnv, + logger: Optional[sb_logger.Logger] = None, +) -> None: + """Save video of imitation policy evaluation.""" + if videos: + video_wrapper.record_and_save_video( + output_dir=output_dir, + policy=policy, + eval_venv=eval_venv, + video_kwargs=video_kwargs, + logger=logger, + ) + + @train_ingredient.capture def suppress_sacred_error(policy_kwargs: Mapping[str, Any]): """No-op so Sacred recognizes `policy_kwargs` is used (in `rl` and elsewhere).""" diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index b74009cf2..f3f31307c 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -19,18 +19,18 @@ from imitation.scripts.common import common as common_config from imitation.scripts.common import demonstrations, reward, rl, train from imitation.scripts.config.train_adversarial import train_adversarial_ex -from imitation.util import video_wrapper logger = logging.getLogger("imitation.scripts.train_adversarial") -def save( - _config: Mapping[str, Any], +def save_checkpoint( trainer: common.AdversarialTrainer, - save_path: str, + log_dir: str, eval_venv: vec_env.VecEnv, + round_str: str, ) -> None: """Save discriminator and generator.""" + save_path = os.path.join(log_dir, "checkpoints", round_str) # We implement this here and not in Trainer since we do not want to actually # serialize the whole Trainer (including e.g. expert demonstrations). 
os.makedirs(save_path, exist_ok=True) @@ -38,14 +38,12 @@ def save( th.save(trainer.reward_test, os.path.join(save_path, "reward_test.pt")) policy_path = os.path.join(save_path, "gen_policy") serialize.save_stable_model(policy_path, trainer.gen_algo) - if _config["train"]["videos"]: - video_wrapper.record_and_save_video( - output_dir=policy_path, - policy=trainer.gen_algo.policy, - eval_venv=eval_venv, - video_kwargs=_config["train"]["video_kwargs"], - logger=trainer.logger, - ) + train.save_video( + output_dir=policy_path, + policy=trainer.gen_algo.policy, + eval_venv=eval_venv, + logger=trainer.logger, + ) def _add_hook(ingredient: sacred.Ingredient) -> None: @@ -80,8 +78,6 @@ def dummy_config(): @train_adversarial_ex.capture def train_adversarial( _run, - _seed: int, - _config: Mapping[str, Any], show_config: bool, algo_cls: Type[common.AdversarialTrainer], algorithm_kwargs: Mapping[str, Any], @@ -98,8 +94,6 @@ def train_adversarial( - Generator policies are saved to `f"{log_dir}/checkpoints/{step}/gen_policy/"`. Args: - _seed: Random seed. - _config: Sacred configuration dict. show_config: Print the merged config before starting training. This is analogous to the print_config command, but will show config after rather than before merging `algorithm_specific` arguments. @@ -169,15 +163,13 @@ def train_adversarial( def callback(round_num): if checkpoint_interval > 0 and round_num % checkpoint_interval == 0: - save_path = os.path.join(log_dir, "checkpoints", f"{round_num:05d}") - save(_config, trainer, save_path, eval_venv) + save_checkpoint(trainer, log_dir, eval_venv, round_str=f"{round_num:05d}") trainer.train(total_timesteps, callback) # Save final artifacts. if checkpoint_interval >= 0: - save_path = os.path.join(log_dir, "checkpoints", "final") - save(_config, trainer, save_path, eval_venv) + save_checkpoint(trainer, log_dir, eval_venv, round_str="final") return { "imit_stats": train.eval_policy(trainer.policy, trainer.venv_train), diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 46e120b5d..fb34d521d 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -14,7 +14,6 @@ from imitation.policies import serialize from imitation.scripts.common import common, demonstrations, train from imitation.scripts.config.train_imitation import train_imitation_ex -from imitation.util import video_wrapper logger = logging.getLogger(__name__) @@ -98,8 +97,6 @@ def load_expert_policy( @train_imitation_ex.capture def train_imitation( - _run, - _config: Mapping[str, Any], bc_kwargs: Mapping[str, Any], bc_train_kwargs: Mapping[str, Any], dagger: Mapping[str, Any], @@ -166,14 +163,12 @@ def train_imitation( # TODO(adam): add checkpointing to BC? 
bc_trainer.save_policy(policy_path=osp.join(log_dir, "final.th")) - if _config["train"]["videos"]: - video_wrapper.record_and_save_video( - output_dir=log_dir, - policy=imit_policy, - eval_venv=eval_venv, - video_kwargs=_config["train"]["video_kwargs"], - logger=custom_logger, - ) + train.save_video( + output_dir=log_dir, + policy=imit_policy, + eval_venv=eval_venv, + logger=custom_logger, + ) return { "imit_stats": train.eval_policy(imit_policy, venv), diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 38ad2e7c0..f85594b50 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -22,17 +22,17 @@ from imitation.scripts.config.train_preference_comparisons import ( train_preference_comparisons_ex, ) -from imitation.util import video_wrapper def save_checkpoint( - _config: Mapping[str, Any], trainer: preference_comparisons.PreferenceComparisons, - save_path: str, + log_dir: str, allow_save_policy: Optional[bool], eval_venv: vec_env.VecEnv, + round_str: str, ) -> None: """Save reward model and optionally policy.""" + save_path = osp.join(log_dir, "checkpoints", round_str) os.makedirs(save_path, exist_ok=True) th.save(trainer.model, osp.join(save_path, "reward_net.pt")) if allow_save_policy: @@ -45,14 +45,12 @@ def save_checkpoint( output_dir=policy_dir, model=trainer.trajectory_generator.algorithm, ) - if _config["train"]["videos"]: - video_wrapper.record_and_save_video( - output_dir=policy_dir, - policy=trainer.trajectory_generator.algorithm.policy, - eval_venv=eval_venv, - video_kwargs=_config["train"]["video_kwargs"], - logger=trainer.logger, - ) + train.save_video( + output_dir=policy_dir, + policy=trainer.trajectory_generator.algorithm.policy, + eval_venv=eval_venv, + logger=trainer.logger, + ) else: trainer.logger.warn( "trainer.trajectory_generator doesn't contain a policy to save.", @@ -62,7 +60,6 @@ def save_checkpoint( @train_preference_comparisons_ex.main def train_preference_comparisons( _seed: int, - _config: Mapping[str, Any], total_timesteps: int, total_comparisons: int, num_iterations: int, @@ -88,7 +85,6 @@ def train_preference_comparisons( Args: _seed: Random seed. - _config: Sacred configuration dict. total_timesteps: number of environment interaction steps total_comparisons: number of preferences to gather in total num_iterations: number of times to train the agent against the reward model @@ -230,17 +226,13 @@ def train_preference_comparisons( def save_callback(iteration_num, traj_generator_num_steps): if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0: - save_path = osp.join( - log_dir, - "checkpoints", - f"iter_{iteration_num:04d}_step_{traj_generator_num_steps:08d}", - ) + round_str = f"iter_{iteration_num:04d}_step_{traj_generator_num_steps:08d}" save_checkpoint( - _config, trainer=main_trainer, - save_path=save_path, + log_dir=log_dir, allow_save_policy=bool(trajectory_path is None), eval_venv=eval_venv, + round_str=round_str, ) results = main_trainer.train( @@ -255,11 +247,11 @@ def save_callback(iteration_num, traj_generator_num_steps): # Save final artifacts. 
if checkpoint_interval >= 0: save_checkpoint( - _config, trainer=main_trainer, - save_path=osp.join(log_dir, "checkpoints", "final"), + log_dir=log_dir, allow_save_policy=bool(trajectory_path is None), eval_venv=eval_venv, + round_str="final", ) # Storing and evaluating policy only useful if we actually generate trajectory data diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 63f77feda..106d7e673 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -54,6 +54,7 @@ def train_rl( - Rollouts are saved to `{log_dir}/rollouts/{step}.pkl`. Args: + _config: Configuration dictionary of the current run. total_timesteps: Number of training timesteps in `model.learn()`. normalize_reward: Applies normalization and clipping to the reward function by keeping a running average of training rewards. Note: this is may be @@ -158,11 +159,10 @@ def train_rl( if policy_save_final: output_dir = os.path.join(policy_dir, "final") serialize.save_stable_model(output_dir, rl_algo) - video_wrapper.record_and_save_video( + train.save_video( output_dir=output_dir, policy=rl_algo.policy, eval_venv=eval_venv, - video_kwargs=_config["train"]["video_kwargs"], logger=rl_algo.logger, ) diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 9924b8d88..fe7b7f77e 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -33,13 +33,17 @@ def __init__( metadata), then saving to different files can be useful. """ super().__init__(env) - self.episode_id = 0 + self._episode_id = 0 self.video_recorder = None self.single_video = single_video self.directory = os.path.abspath(directory) os.makedirs(self.directory, exist_ok=True) + @property + def episode_id(self) -> int: + return self._episode_id + def _reset_video_recorder(self) -> None: """Creates a video recorder if one does not already exist. 
@@ -59,14 +63,14 @@ def _reset_video_recorder(self) -> None: env=self.env, base_path=os.path.join( self.directory, - "video.{:06}".format(self.episode_id), + "video.{:06}".format(self._episode_id), ), - metadata={"episode_id": self.episode_id}, + metadata={"episode_id": self._episode_id}, ) def reset(self): self._reset_video_recorder() - self.episode_id += 1 + self._episode_id += 1 return self.env.reset() def step(self, action): @@ -85,28 +89,29 @@ def record_and_save_video( output_dir: str, policy: policies.BasePolicy, eval_venv: vec_env.VecEnv, - video_kwargs: Mapping[str, Any], + video_kwargs: Optional[Mapping[str, Any]] = None, logger: Optional[sb_logger.Logger] = None, ) -> None: video_dir = osp.join(output_dir, "videos") video_venv = VideoWrapper( eval_venv, directory=video_dir, - **video_kwargs, + **(video_kwargs or dict()), ) - sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=2) - # video.{:06}".format(VideoWrapper.episode_id) will be saved within + sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=1) + # video.{:06}.mp4".format(VideoWrapper.episode_id) will be saved within # rollout.generate_trajectories() rollout.generate_trajectories(policy, video_venv, sample_until) - assert "video.000000.mp4" in os.listdir(video_dir) - video_path = osp.join(video_dir, "video.000000.mp4") + video_name = "video.{:06}.mp4".format(video_venv.episode_id - 1) + assert video_name in os.listdir(video_dir) + video_path = osp.join(video_dir, video_name) if logger: logger.record("video", video_path) logger.log(f"Recording and saving video to {video_path} ...") class SaveVideoCallback(callbacks.EventCallback): - """Saves the policy using `save_n_record_video` each time it is called. + """Saves the policy using `record_and_save_video` each time when it is called. Should be used in conjunction with `callbacks.EveryNTimesteps` or another event-based trigger. @@ -116,8 +121,8 @@ def __init__( self, policy_dir: str, eval_venv: vec_env.VecEnv, - video_kwargs: Mapping[str, Any], *args, + video_kwargs: Optional[Mapping[str, Any]] = None, **kwargs, ): """Builds SavePolicyCallback. @@ -125,14 +130,14 @@ def __init__( Args: policy_dir: Directory to save checkpoints. eval_venv: Environment to evaluate the policy on. - video_kwargs: Keyword arguments to pass to `VideoWrapper`. *args: Passed through to `callbacks.EventCallback`. + video_kwargs: Keyword arguments to pass to `VideoWrapper`. **kwargs: Passed through to `callbacks.EventCallback`. 
""" super().__init__(*args, **kwargs) self.policy_dir = policy_dir self.eval_venv = eval_venv - self.video_kwargs = video_kwargs + self.video_kwargs = video_kwargs or dict() def _on_step(self) -> bool: output_dir = os.path.join(self.policy_dir, f"{self.num_timesteps:012d}") diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 8e6fed308..15a9c5769 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -910,7 +910,7 @@ def test_convert_trajs(tmpdir: str): _TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} -VIDEO_NAME = "video.000000.mp4" +VIDEO_NAME = "video.{:06}.mp4".format(0) VIDEO_PATH_DICT = dict( rl=lambda d: osp.join(d, "policies", "final", "videos"), adversarial=lambda d: osp.join(d, "checkpoints", "final", "gen_policy", "videos"), From e8ab769af30720c7b9aff99e7814a3109c4d3232 Mon Sep 17 00:00:00 2001 From: YawenDuan <35738259+kmdanielduan@users.noreply.github.com> Date: Mon, 22 Aug 2022 22:10:20 -0700 Subject: [PATCH 4/6] add minor comment --- tests/scripts/test_scripts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 15a9c5769..bce2ce384 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -911,6 +911,7 @@ def test_convert_trajs(tmpdir: str): _TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} VIDEO_NAME = "video.{:06}.mp4".format(0) +# Change the following if the file structure of checkpoints changed. VIDEO_PATH_DICT = dict( rl=lambda d: osp.join(d, "policies", "final", "videos"), adversarial=lambda d: osp.join(d, "checkpoints", "final", "gen_policy", "videos"), From b5daea6c310f938675c584258bb549e9ceb80e15 Mon Sep 17 00:00:00 2001 From: YawenDuan <35738259+kmdanielduan@users.noreply.github.com> Date: Mon, 22 Aug 2022 23:17:23 -0700 Subject: [PATCH 5/6] fix bugs and code format --- src/imitation/scripts/train_adversarial.py | 21 +++--- src/imitation/scripts/train_imitation.py | 16 ++--- .../scripts/train_preference_comparisons.py | 52 +++++++------- src/imitation/scripts/train_rl.py | 72 +++++++++---------- 4 files changed, 83 insertions(+), 78 deletions(-) diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index a7f0a3273..00cc551cf 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -126,7 +126,6 @@ def train_adversarial( expert_trajs = demonstrations.load_expert_trajs() with common_config.make_venv() as venv: - eval_venv = common_config.make_venv(log_dir=None) reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( @@ -161,16 +160,20 @@ def train_adversarial( **algorithm_kwargs, ) - def callback(round_num): - if checkpoint_interval > 0 and round_num % checkpoint_interval == 0: - save_checkpoint(trainer, log_dir, eval_venv, round_str=f"{round_num:05d}") + with common_config.make_venv(num_vec=1, log_dir=None) as eval_venv: - trainer.train(total_timesteps, callback) - imit_stats = train.eval_policy(trainer.policy, trainer.venv_train) + def callback(round_num): + if checkpoint_interval > 0 and round_num % checkpoint_interval == 0: + round_str = f"{round_num:05d}" + save_checkpoint(trainer, log_dir, eval_venv, round_str=round_str) + + trainer.train(total_timesteps, callback) - # Save final artifacts. - if checkpoint_interval >= 0: - save_checkpoint(trainer, log_dir, eval_venv, round_str="final") + # Save final artifacts. 
+ if checkpoint_interval >= 0: + save_checkpoint(trainer, log_dir, eval_venv, round_str="final") + + imit_stats = train.eval_policy(trainer.policy, trainer.venv_train) return { "imit_stats": imit_stats, diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 920b3da6a..8f79494fb 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -120,7 +120,6 @@ def train_imitation( custom_logger, log_dir = common.setup_logging() with common.make_venv() as venv: - eval_venv = common.make_venv(log_dir=None) imit_policy = make_policy(venv, agent_path=agent_path) expert_trajs = None @@ -164,14 +163,15 @@ def train_imitation( # TODO(adam): add checkpointing to BC? bc_trainer.save_policy(policy_path=osp.join(log_dir, "final.th")) - imit_stats = train.eval_policy(imit_policy, eval_venv) + imit_stats = train.eval_policy(imit_policy, venv) - train.save_video( - output_dir=log_dir, - policy=imit_policy, - eval_venv=eval_venv, - logger=custom_logger, - ) + with common.make_venv(num_vec=1, log_dir=None) as eval_venv: + train.save_video( + output_dir=log_dir, + policy=imit_policy, + eval_venv=eval_venv, + logger=custom_logger, + ) return { "imit_stats": imit_stats, diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 0653142be..56efba4c8 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -144,7 +144,6 @@ def train_preference_comparisons( custom_logger, log_dir = common.setup_logging() with common.make_venv() as venv: - eval_venv = common.make_venv(log_dir=None) reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, @@ -224,22 +223,35 @@ def train_preference_comparisons( query_schedule=query_schedule, ) - def save_callback(iteration_num, traj_gen_num_steps): - if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0: - round_str = f"iter_{iteration_num:04d}_step_{traj_gen_num_steps:08d}" - save_checkpoint( - trainer=main_trainer, - log_dir=log_dir, - allow_save_policy=bool(trajectory_path is None), - eval_venv=eval_venv, - round_str=round_str, + # Create an eval_venv for policy evaluation and maybe visualization. + with common.make_venv(num_vec=1, log_dir=None) as eval_venv: + + def save_callback(iter_num, traj_gen_num_steps): + if checkpoint_interval > 0 and iter_num % checkpoint_interval == 0: + round_str = f"iter_{iter_num:04d}_step_{traj_gen_num_steps:08d}" + save_checkpoint( + trainer=main_trainer, + log_dir=log_dir, + allow_save_policy=bool(trajectory_path is None), + eval_venv=eval_venv, + round_str=round_str, + ) + + results = main_trainer.train( + total_timesteps, + total_comparisons, + callback=save_callback, ) - results = main_trainer.train( - total_timesteps, - total_comparisons, - callback=save_callback, - ) + # Save final artifacts. + if checkpoint_interval >= 0: + save_checkpoint( + trainer=main_trainer, + log_dir=log_dir, + allow_save_policy=bool(trajectory_path is None), + eval_venv=eval_venv, + round_str="final", + ) # Storing and evaluating policy only useful if we generated trajectory data if bool(trajectory_path is None): @@ -249,16 +261,6 @@ def save_callback(iteration_num, traj_gen_num_steps): if save_preferences: main_trainer.dataset.save(osp.join(log_dir, "preferences.pkl")) - # Save final artifacts. 
- if checkpoint_interval >= 0: - save_checkpoint( - trainer=main_trainer, - log_dir=log_dir, - allow_save_policy=bool(trajectory_path is None), - eval_venv=eval_venv, - round_str="final", - ) - return results diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index ff490d829..fffbd77ca 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -98,7 +98,6 @@ def train_rl( post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] with common.make_venv(post_wrappers=post_wrappers) as venv: - eval_venv = common.make_venv(log_dir=None) callback_objs = [] if reward_type is not None: reward_fn = load_reward( @@ -123,34 +122,44 @@ def train_rl( RuntimeWarning, ) - if policy_save_interval > 0: - save_policy_callback = serialize.SavePolicyCallback(policy_dir) - save_policy_callback = callbacks.EveryNTimesteps( - policy_save_interval, - save_policy_callback, - ) - callback_objs.append(save_policy_callback) - - if _config["train"]["videos"]: - save_video_callback = video_wrapper.SaveVideoCallback( - policy_dir, - eval_venv, - video_kwargs=_config["train"]["video_kwargs"], - ) - save_video_callback = callbacks.EveryNTimesteps( + with common.make_venv(num_vec=1, log_dir=None) as eval_venv: + if policy_save_interval > 0: + save_policy_callback = serialize.SavePolicyCallback(policy_dir) + save_policy_callback = callbacks.EveryNTimesteps( policy_save_interval, - save_video_callback, + save_policy_callback, + ) + callback_objs.append(save_policy_callback) + + if _config["train"]["videos"]: + save_video_callback = video_wrapper.SaveVideoCallback( + policy_dir, + eval_venv, + video_kwargs=_config["train"]["video_kwargs"], + ) + save_video_callback = callbacks.EveryNTimesteps( + policy_save_interval, + save_video_callback, + ) + callback_objs.append(save_video_callback) + callback = callbacks.CallbackList(callback_objs) + + if agent_path is None: + rl_algo = rl.make_rl_algo(venv) + else: + rl_algo = rl.load_rl_algo_from_path(agent_path=agent_path, venv=venv) + rl_algo.set_logger(custom_logger) + rl_algo.learn(total_timesteps, callback=callback) + + if policy_save_final: + output_dir = os.path.join(policy_dir, "final") + serialize.save_stable_model(output_dir, rl_algo) + train.save_video( + output_dir=output_dir, + policy=rl_algo.policy, + eval_venv=eval_venv, + logger=rl_algo.logger, ) - callback_objs.append(save_video_callback) - callback = callbacks.CallbackList(callback_objs) - - if agent_path is None: - rl_algo = rl.make_rl_algo(venv) - else: - rl_algo = rl.load_rl_algo_from_path(agent_path=agent_path, venv=venv) - rl_algo.set_logger(custom_logger) - rl_algo.learn(total_timesteps, callback=callback) - # Save final artifacts after training is complete. if rollout_save_final: save_path = osp.join(rollout_dir, "final.pkl") @@ -159,15 +168,6 @@ def train_rl( rollout_save_n_episodes, ) types.save(save_path, rollout.rollout(rl_algo, venv, sample_until)) - if policy_save_final: - output_dir = os.path.join(policy_dir, "final") - serialize.save_stable_model(output_dir, rl_algo) - train.save_video( - output_dir=output_dir, - policy=rl_algo.policy, - eval_venv=eval_venv, - logger=rl_algo.logger, - ) # Final evaluation of expert policy. 
return train.eval_policy(rl_algo, venv) From 0d107934e5cd1a6fa1cfefdec97f9568e262d8dc Mon Sep 17 00:00:00 2001 From: YawenDuan <35738259+kmdanielduan@users.noreply.github.com> Date: Mon, 22 Aug 2022 23:51:01 -0700 Subject: [PATCH 6/6] add test coverage --- tests/scripts/test_scripts.py | 1 + tests/util/test_wb_logger.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index bce2ce384..4d3633399 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -478,6 +478,7 @@ def test_train_adversarial_warmstart(tmpdir, command): config_updates = { "common": dict(log_root=tmpdir), "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH), + "checkpoint_interval": 1, } run = train_adversarial.train_adversarial_ex.run( command_name=command, diff --git a/tests/util/test_wb_logger.py b/tests/util/test_wb_logger.py index fe755ad31..f95bdd82f 100644 --- a/tests/util/test_wb_logger.py +++ b/tests/util/test_wb_logger.py @@ -111,6 +111,10 @@ def test_wandb_output_format(): {"_step": 0, "foo": 42, "fizz": 12}, {"_step": 3, "fizz": 21}, ] + with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"): + log_obj.record("video", 42) + log_obj.dump(step=4) + log_obj.close()
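Usage sketch (not prescribed by the patches themselves): after this series, video capture is toggled through the `train` ingredient's `videos` and `video_kwargs` options. The snippet below mirrors the config updates the new tests pass to the Sacred experiments; the "cartpole" named config comes from those tests, while the `log_root` and `single_video` values are illustrative assumptions (the test suite also appends its ALGO_FAST_CONFIGS presets to keep runs short).

# Minimal sketch of enabling the new video saving via the Sacred experiment
# API, mirroring the config updates used by the new tests. "cartpole" is a
# named config used in the tests; log_root and single_video are illustrative
# choices, not required by the patches.
from imitation.scripts import train_rl

run = train_rl.train_rl_ex.run(
    named_configs=["cartpole"],
    config_updates={
        "common": {"log_root": "output"},  # root under which log_dir is created
        "train": {
            "videos": True,  # turn on VideoWrapper recording at save points
            "video_kwargs": {"single_video": True},  # forwarded to VideoWrapper
        },
    },
)
# For train_rl, the final policy's video lands under
# {log_dir}/policies/final/videos/video.000000.mp4
# (cf. VIDEO_PATH_DICT in tests/scripts/test_scripts.py).
print(run.config["common"]["log_dir"])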