diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py
index 24bc29895..53bae72eb 100644
--- a/src/imitation/algorithms/preference_comparisons.py
+++ b/src/imitation/algorithms/preference_comparisons.py
@@ -1397,6 +1397,7 @@ def __init__(
         # for keeping track of the global iteration, in case train() is called
         # multiple times
         self._iteration = 0
+        self.trajectory_generator_num_steps = 0
 
         self.model = reward_model
 
@@ -1442,7 +1443,7 @@ def train(
         self,
         total_timesteps: int,
         total_comparisons: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: Optional[Callable[[int, int], None]] = None,
     ) -> Mapping[str, Any]:
         """Train the reward model and the policy if applicable.
 
@@ -1526,14 +1527,15 @@ def train(
             with self.logger.accumulate_means("agent"):
                 self.logger.log(f"Training agent for {num_steps} timesteps")
                 self.trajectory_generator.train(steps=num_steps)
+                self.trajectory_generator_num_steps += num_steps
 
             self.logger.dump(self._iteration)
 
             ########################
             # Additional Callbacks #
             ########################
-            if callback:
-                callback(self._iteration)
             self._iteration += 1
+            if callback:
+                callback(self._iteration, self.trajectory_generator_num_steps)
 
         return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy}
diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py
index c215460e2..4352aa0a5 100644
--- a/src/imitation/scripts/common/common.py
+++ b/src/imitation/scripts/common/common.py
@@ -3,7 +3,7 @@
 import contextlib
 import logging
 import os
-from typing import Any, Mapping, Sequence, Tuple, Union
+from typing import Any, Mapping, Optional, Sequence, Tuple, Union
 
 import sacred
 from stable_baselines3.common import vec_env
@@ -134,7 +134,7 @@ def make_venv(
     env_name: str,
     num_vec: int,
     parallel: bool,
-    log_dir: str,
+    log_dir: Optional[str],
     max_episode_steps: int,
     env_make_kwargs: Mapping[str, Any],
     **kwargs,
diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py
index f5aa3c1bb..621af4de4 100644
--- a/src/imitation/scripts/common/train.py
+++ b/src/imitation/scripts/common/train.py
@@ -1,14 +1,16 @@
 """Common configuration elements for training imitation algorithms."""
 
 import logging
-from typing import Any, Mapping, Union
+from typing import Any, Mapping, Optional, Union
 
 import sacred
+import stable_baselines3.common.logger as sb_logger
 from stable_baselines3.common import base_class, policies, torch_layers, vec_env
 
 import imitation.util.networks
 from imitation.data import rollout
 from imitation.policies import base
+from imitation.util import video_wrapper
 
 train_ingredient = sacred.Ingredient("train")
 logger = logging.getLogger(__name__)
@@ -23,6 +25,10 @@ def config():
 
     # Evaluation
     n_episodes_eval = 50  # Num of episodes for final mean ground truth return
 
+    # Visualization
+    videos = False  # save video files
+    video_kwargs = {}  # arguments to VideoWrapper
+
     locals()  # quieten flake8
 
@@ -111,6 +117,26 @@ def eval_policy(
     return rollout.rollout_stats(trajs)
 
 
+@train_ingredient.capture
+def save_video(
+    videos: bool,
+    video_kwargs: Mapping[str, Any],
+    output_dir: str,
+    policy: policies.BasePolicy,
+    eval_venv: vec_env.VecEnv,
+    logger: Optional[sb_logger.Logger] = None,
+) -> None:
+    """Save video of imitation policy evaluation."""
+    if videos:
+        video_wrapper.record_and_save_video(
+            output_dir=output_dir,
+            policy=policy,
+            eval_venv=eval_venv,
+            video_kwargs=video_kwargs,
+            logger=logger,
+        )
+
+
 @train_ingredient.capture
 def suppress_sacred_error(policy_kwargs: Mapping[str, Any]):
     """No-op so Sacred recognizes `policy_kwargs` is used (in `rl` and elsewhere)."""
diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py
index ba4e9483c..f891490ac 100644
--- a/src/imitation/scripts/config/train_preference_comparisons.py
+++ b/src/imitation/scripts/config/train_preference_comparisons.py
@@ -28,7 +28,7 @@ def train_defaults():
     fragment_length = 100  # timesteps per fragment used for comparisons
     total_timesteps = int(1e6)  # total number of environment timesteps
     total_comparisons = 5000  # total number of comparisons to elicit
-    num_iterations = 5  # Arbitrary, should be tuned for the task
+    num_iterations = 50  # Arbitrary, should be tuned for the task
     comparison_queue_size = None
     # factor by which to oversample transitions before creating fragments
     transition_oversampling = 1
@@ -39,6 +39,7 @@ def train_defaults():
     preference_model_kwargs = {}
     reward_trainer_kwargs = {
         "epochs": 3,
+        "weight_decay": 0.0,
     }
     save_preferences = False  # save preference dataset at the end?
     agent_path = None  # path to a (partially) trained agent to load at the beginning
diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py
index 2910babdf..b85b310c4 100644
--- a/src/imitation/scripts/train_adversarial.py
+++ b/src/imitation/scripts/train_adversarial.py
@@ -9,6 +9,7 @@
 import sacred.commands
 import torch as th
 from sacred.observers import FileStorageObserver
+from stable_baselines3.common import vec_env
 
 from imitation.algorithms.adversarial import airl as airl_algo
 from imitation.algorithms.adversarial import common
@@ -22,16 +23,26 @@
 logger = logging.getLogger("imitation.scripts.train_adversarial")
 
 
-def save(trainer, save_path):
+def save_checkpoint(
+    trainer: common.AdversarialTrainer,
+    log_dir: str,
+    eval_venv: vec_env.VecEnv,
+    round_str: str,
+) -> None:
     """Save discriminator and generator."""
+    save_path = os.path.join(log_dir, "checkpoints", round_str)
     # We implement this here and not in Trainer since we do not want to actually
     # serialize the whole Trainer (including e.g. expert demonstrations).
     os.makedirs(save_path, exist_ok=True)
     th.save(trainer.reward_train, os.path.join(save_path, "reward_train.pt"))
     th.save(trainer.reward_test, os.path.join(save_path, "reward_test.pt"))
-    serialize.save_stable_model(
-        os.path.join(save_path, "gen_policy"),
-        trainer.gen_algo,
+    policy_path = os.path.join(save_path, "gen_policy")
+    serialize.save_stable_model(policy_path, trainer.gen_algo)
+    train.save_video(
+        output_dir=policy_path,
+        policy=trainer.gen_algo.policy,
+        eval_venv=eval_venv,
+        logger=trainer.logger,
     )
 
 
@@ -67,7 +78,6 @@ def dummy_config():
 
 @train_adversarial_ex.capture
 def train_adversarial(
     _run,
-    _seed: int,
     show_config: bool,
     algo_cls: Type[common.AdversarialTrainer],
     algorithm_kwargs: Mapping[str, Any],
@@ -84,7 +94,6 @@ def train_adversarial(
       - Generator policies are saved to `f"{log_dir}/checkpoints/{step}/gen_policy/"`.
 
     Args:
-        _seed: Random seed.
        show_config: Print the merged config before starting training. This is
            analogous to the print_config command, but will show config after
            rather than before merging `algorithm_specific` arguments.
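(A minimal usage sketch of the `train.save_video` helper introduced above; illustrative only, not part of the patch. Sacred injects the captured `videos` and `video_kwargs` values from the `train` ingredient, so the helper is a no-op unless a run sets `train.videos=True`.)

# Sketch of how a script calls the captured helper; `trainer`, `policy_path`,
# and `common_config` are the names already used in this patch.
with common_config.make_venv(num_vec=1, log_dir=None) as eval_venv:
    train.save_video(
        output_dir=policy_path,  # e.g. the saved generator policy directory
        policy=trainer.gen_algo.policy,
        eval_venv=eval_venv,
        logger=trainer.logger,
    )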
@@ -117,6 +126,7 @@ def train_adversarial(
     expert_trajs = demonstrations.get_expert_trajectories()
 
     with common_config.make_venv() as venv:
+
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
@@ -150,16 +160,20 @@ def train_adversarial(
         **algorithm_kwargs,
     )
 
-    def callback(round_num):
-        if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
-            save(trainer, os.path.join(log_dir, "checkpoints", f"{round_num:05d}"))
+    with common_config.make_venv(num_vec=1, log_dir=None) as eval_venv:
 
-    trainer.train(total_timesteps, callback)
-    imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)
+        def callback(round_num):
+            if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
+                round_str = f"{round_num:05d}"
+                save_checkpoint(trainer, log_dir, eval_venv, round_str=round_str)
+
+        trainer.train(total_timesteps, callback)
 
-    # Save final artifacts.
-    if checkpoint_interval >= 0:
-        save(trainer, os.path.join(log_dir, "checkpoints", "final"))
+        # Save final artifacts.
+        if checkpoint_interval >= 0:
+            save_checkpoint(trainer, log_dir, eval_venv, round_str="final")
+
+    imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)
 
     return {
         "imit_stats": imit_stats,
diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py
index 8d7085577..07397682f 100644
--- a/src/imitation/scripts/train_imitation.py
+++ b/src/imitation/scripts/train_imitation.py
@@ -63,7 +63,6 @@ def make_policy(
 
 @train_imitation_ex.capture
 def train_imitation(
-    _run,
     bc_kwargs: Mapping[str, Any],
     bc_train_kwargs: Mapping[str, Any],
     dagger: Mapping[str, Any],
@@ -132,6 +131,14 @@ def train_imitation(
 
     imit_stats = train.eval_policy(imit_policy, venv)
 
+    with common.make_venv(num_vec=1, log_dir=None) as eval_venv:
+        train.save_video(
+            output_dir=log_dir,
+            policy=imit_policy,
+            eval_venv=eval_venv,
+            logger=custom_logger,
+        )
+
     return {
         "imit_stats": imit_stats,
         "expert_stats": rollout.rollout_stats(
diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py
index 704bed568..0280b3532 100644
--- a/src/imitation/scripts/train_preference_comparisons.py
+++ b/src/imitation/scripts/train_preference_comparisons.py
@@ -6,11 +6,12 @@
 
 import functools
 import os
+import os.path as osp
 from typing import Any, Mapping, Optional, Type, Union
 
 import torch as th
 from sacred.observers import FileStorageObserver
-from stable_baselines3.common import type_aliases
+from stable_baselines3.common import type_aliases, vec_env
 
 from imitation.algorithms import preference_comparisons
 from imitation.data import types
@@ -23,31 +24,33 @@
 )
 
 
-def save_model(
-    agent_trainer: preference_comparisons.AgentTrainer,
-    save_path: str,
-):
-    """Save the model as model.pkl."""
-    serialize.save_stable_model(
-        output_dir=os.path.join(save_path, "policy"),
-        model=agent_trainer.algorithm,
-    )
-
-
 def save_checkpoint(
     trainer: preference_comparisons.PreferenceComparisons,
-    save_path: str,
+    log_dir: str,
     allow_save_policy: Optional[bool],
-):
+    eval_venv: vec_env.VecEnv,
+    round_str: str,
+) -> None:
     """Save reward model and optionally policy."""
+    save_path = osp.join(log_dir, "checkpoints", round_str)
     os.makedirs(save_path, exist_ok=True)
-    th.save(trainer.model, os.path.join(save_path, "reward_net.pt"))
+    th.save(trainer.model, osp.join(save_path, "reward_net.pt"))
     if allow_save_policy:
         # Note: We should only save the model as model.pkl if `trajectory_generator`
         # contains one. Specifically we check if the `trajectory_generator` contains an
         # `algorithm` attribute.
         assert hasattr(trainer.trajectory_generator, "algorithm")
-        save_model(trainer.trajectory_generator, save_path)
+        policy_dir = osp.join(save_path, "policy")
+        serialize.save_stable_model(
+            output_dir=policy_dir,
+            model=trainer.trajectory_generator.algorithm,
+        )
+        train.save_video(
+            output_dir=policy_dir,
+            policy=trainer.trajectory_generator.algorithm.policy,
+            eval_venv=eval_venv,
+            logger=trainer.logger,
+        )
     else:
         trainer.logger.warn(
             "trainer.trajectory_generator doesn't contain a policy to save.",
@@ -242,46 +245,50 @@ def train_preference_comparisons(
         query_schedule=query_schedule,
     )
 
-    def save_callback(iteration_num):
-        if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0:
+    # Create a single-environment eval_venv for recording evaluation videos.
+    with common.make_venv(num_vec=1, log_dir=None) as eval_venv:
+
+        def save_callback(iter_num, traj_gen_num_steps):
+            if checkpoint_interval > 0 and iter_num % checkpoint_interval == 0:
+                round_str = f"iter_{iter_num:04d}_step_{traj_gen_num_steps:08d}"
+                save_checkpoint(
+                    trainer=main_trainer,
+                    log_dir=log_dir,
+                    allow_save_policy=bool(trajectory_path is None),
+                    eval_venv=eval_venv,
+                    round_str=round_str,
+                )
+
+        results = main_trainer.train(
+            total_timesteps,
+            total_comparisons,
+            callback=save_callback,
+        )
+
+        # Save final artifacts.
+        if checkpoint_interval >= 0:
             save_checkpoint(
                 trainer=main_trainer,
-                save_path=os.path.join(
-                    log_dir,
-                    "checkpoints",
-                    f"{iteration_num:04d}",
-                ),
+                log_dir=log_dir,
                 allow_save_policy=bool(trajectory_path is None),
+                eval_venv=eval_venv,
+                round_str="final",
             )
 
-    results = main_trainer.train(
-        total_timesteps,
-        total_comparisons,
-        callback=save_callback,
-    )
-
     # Storing and evaluating policy only useful if we generated trajectory data
     if bool(trajectory_path is None):
         results = dict(results)
         results["rollout"] = train.eval_policy(agent, venv)
 
     if save_preferences:
-        main_trainer.dataset.save(os.path.join(log_dir, "preferences.pkl"))
-
-    # Save final artifacts.
-    if checkpoint_interval >= 0:
-        save_checkpoint(
-            trainer=main_trainer,
-            save_path=os.path.join(log_dir, "checkpoints", "final"),
-            allow_save_policy=bool(trajectory_path is None),
-        )
+        main_trainer.dataset.save(osp.join(log_dir, "preferences.pkl"))
 
     return results
 
 
 def main_console():
     observer = FileStorageObserver(
-        os.path.join("output", "sacred", "train_preference_comparisons"),
+        osp.join("output", "sacred", "train_preference_comparisons"),
     )
     train_preference_comparisons_ex.observers.append(observer)
     train_preference_comparisons_ex.run_commandline()
diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py
index 7122cd701..eb34df1e0 100644
--- a/src/imitation/scripts/train_rl.py
+++ b/src/imitation/scripts/train_rl.py
@@ -24,10 +24,12 @@
 from imitation.rewards.serialize import load_reward
 from imitation.scripts.common import common, rl, train
 from imitation.scripts.config.train_rl import train_rl_ex
+from imitation.util import video_wrapper
 
 
 @train_rl_ex.main
 def train_rl(
+    _config: Mapping[str, Any],
     *,
     total_timesteps: int,
     normalize_reward: bool,
@@ -52,6 +54,7 @@ def train_rl(
       - Rollouts are saved to `{log_dir}/rollouts/{step}.pkl`.
 
     Args:
+        _config: Configuration dictionary of the current run.
         total_timesteps: Number of training timesteps in `model.learn()`.
         normalize_reward: Applies normalization and clipping to the reward function by
             keeping a running average of training rewards. Note: this is may be
Note: this is may be @@ -119,22 +122,44 @@ def train_rl( RuntimeWarning, ) - if policy_save_interval > 0: - save_policy_callback = serialize.SavePolicyCallback(policy_dir) - save_policy_callback = callbacks.EveryNTimesteps( - policy_save_interval, - save_policy_callback, - ) - callback_objs.append(save_policy_callback) - callback = callbacks.CallbackList(callback_objs) - - if agent_path is None: - rl_algo = rl.make_rl_algo(venv) - else: - rl_algo = rl.load_rl_algo_from_path(agent_path=agent_path, venv=venv) - rl_algo.set_logger(custom_logger) - rl_algo.learn(total_timesteps, callback=callback) - + with common.make_venv(num_vec=1, log_dir=None) as eval_venv: + if policy_save_interval > 0: + save_policy_callback = serialize.SavePolicyCallback(policy_dir) + save_policy_callback = callbacks.EveryNTimesteps( + policy_save_interval, + save_policy_callback, + ) + callback_objs.append(save_policy_callback) + + if _config["train"]["videos"]: + save_video_callback = video_wrapper.SaveVideoCallback( + policy_dir, + eval_venv, + video_kwargs=_config["train"]["video_kwargs"], + ) + save_video_callback = callbacks.EveryNTimesteps( + policy_save_interval, + save_video_callback, + ) + callback_objs.append(save_video_callback) + callback = callbacks.CallbackList(callback_objs) + + if agent_path is None: + rl_algo = rl.make_rl_algo(venv) + else: + rl_algo = rl.load_rl_algo_from_path(agent_path=agent_path, venv=venv) + rl_algo.set_logger(custom_logger) + rl_algo.learn(total_timesteps, callback=callback) + + if policy_save_final: + output_dir = os.path.join(policy_dir, "final") + serialize.save_stable_model(output_dir, rl_algo) + train.save_video( + output_dir=output_dir, + policy=rl_algo.policy, + eval_venv=eval_venv, + logger=rl_algo.logger, + ) # Save final artifacts after training is complete. if rollout_save_final: save_path = osp.join(rollout_dir, "final.pkl") @@ -146,9 +171,6 @@ def train_rl( save_path, rollout.rollout(rl_algo, rl_algo.get_env(), sample_until), ) - if policy_save_final: - output_dir = os.path.join(policy_dir, "final") - serialize.save_stable_model(output_dir, rl_algo) # Final evaluation of expert policy. return train.eval_policy(rl_algo, venv) diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py index 8875cb211..efa07740b 100644 --- a/src/imitation/util/logger.py +++ b/src/imitation/util/logger.py @@ -202,7 +202,10 @@ def write( if excluded is not None and "wandb" in excluded: continue - self.wandb_module.log({key: value}, step=step) + if key != "video": + self.wandb_module.log({key: value}, step=step) + else: + self.wandb_module.log({"video": self.wandb_module.Video(value)}) self.wandb_module.log({}, commit=True) def close(self) -> None: diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 5fc2ae4f5..fe7b7f77e 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -1,11 +1,15 @@ """Wrapper to record rendered video frames from an environment.""" import os +import os.path as osp +from typing import Any, Mapping, Optional import gym +import stable_baselines3.common.logger as sb_logger from gym.wrappers.monitoring import video_recorder +from stable_baselines3.common import callbacks, policies, vec_env -from imitation.data import types +from imitation.data import rollout, types class VideoWrapper(gym.Wrapper): @@ -29,12 +33,16 @@ def __init__( metadata), then saving to different files can be useful. 
""" super().__init__(env) - self.episode_id = 0 + self._episode_id = 0 self.video_recorder = None self.single_video = single_video self.directory = os.path.abspath(directory) - os.makedirs(self.directory) + os.makedirs(self.directory, exist_ok=True) + + @property + def episode_id(self) -> int: + return self._episode_id def _reset_video_recorder(self) -> None: """Creates a video recorder if one does not already exist. @@ -55,14 +63,14 @@ def _reset_video_recorder(self) -> None: env=self.env, base_path=os.path.join( self.directory, - "video.{:06}".format(self.episode_id), + "video.{:06}".format(self._episode_id), ), - metadata={"episode_id": self.episode_id}, + metadata={"episode_id": self._episode_id}, ) def reset(self): self._reset_video_recorder() - self.episode_id += 1 + self._episode_id += 1 return self.env.reset() def step(self, action): @@ -75,3 +83,70 @@ def close(self) -> None: self.video_recorder.close() self.video_recorder = None super().close() + + +def record_and_save_video( + output_dir: str, + policy: policies.BasePolicy, + eval_venv: vec_env.VecEnv, + video_kwargs: Optional[Mapping[str, Any]] = None, + logger: Optional[sb_logger.Logger] = None, +) -> None: + video_dir = osp.join(output_dir, "videos") + video_venv = VideoWrapper( + eval_venv, + directory=video_dir, + **(video_kwargs or dict()), + ) + sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=1) + # video.{:06}.mp4".format(VideoWrapper.episode_id) will be saved within + # rollout.generate_trajectories() + rollout.generate_trajectories(policy, video_venv, sample_until) + video_name = "video.{:06}.mp4".format(video_venv.episode_id - 1) + assert video_name in os.listdir(video_dir) + video_path = osp.join(video_dir, video_name) + if logger: + logger.record("video", video_path) + logger.log(f"Recording and saving video to {video_path} ...") + + +class SaveVideoCallback(callbacks.EventCallback): + """Saves the policy using `record_and_save_video` each time when it is called. + + Should be used in conjunction with `callbacks.EveryNTimesteps` + or another event-based trigger. + """ + + def __init__( + self, + policy_dir: str, + eval_venv: vec_env.VecEnv, + *args, + video_kwargs: Optional[Mapping[str, Any]] = None, + **kwargs, + ): + """Builds SavePolicyCallback. + + Args: + policy_dir: Directory to save checkpoints. + eval_venv: Environment to evaluate the policy on. + *args: Passed through to `callbacks.EventCallback`. + video_kwargs: Keyword arguments to pass to `VideoWrapper`. + **kwargs: Passed through to `callbacks.EventCallback`. 
+ """ + super().__init__(*args, **kwargs) + self.policy_dir = policy_dir + self.eval_venv = eval_venv + self.video_kwargs = video_kwargs or dict() + + def _on_step(self) -> bool: + output_dir = os.path.join(self.policy_dir, f"{self.num_timesteps:012d}") + record_and_save_video( + output_dir=output_dir, + policy=self.model.policy, + eval_venv=self.eval_venv, + video_kwargs=self.video_kwargs, + logger=self.model.logger, + ) + + return True diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index f461dc889..c29b20ea5 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -8,6 +8,7 @@ import collections import filecmp import os +import os.path as osp import pathlib import pickle import platform @@ -508,6 +509,7 @@ def test_train_adversarial_warmstart(tmpdir, command): config_updates = { "common": dict(log_root=tmpdir), "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH), + "checkpoint_interval": 1, } run = train_adversarial.train_adversarial_ex.run( command_name=command, @@ -937,3 +939,79 @@ def test_convert_trajs(tmpdir: str): assert len(from_pkl) == len(from_npz) for t_pkl, t_npz in zip(from_pkl, from_npz): assert t_pkl == t_npz + + +_TRAIN_VIDEO_CONFIGS = {"train": {"videos": True}} +VIDEO_NAME = "video.{:06}.mp4".format(0) +# Change the following if the file structure of checkpoints changed. +VIDEO_PATH_DICT = dict( + rl=lambda d: osp.join(d, "policies", "final", "videos"), + adversarial=lambda d: osp.join(d, "checkpoints", "final", "gen_policy", "videos"), + pc=lambda d: osp.join(d, "checkpoints", "final", "policy", "videos"), + bc=lambda d: osp.join(d, "videos"), +) + + +def _check_video_exists(log_dir, algo): + video_dir = VIDEO_PATH_DICT[algo](log_dir) + assert os.path.exists(video_dir) + assert VIDEO_NAME in os.listdir(video_dir) + + +def test_train_rl_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_rl.""" + config_updates = dict( + common=dict(log_root=tmpdir), + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_rl.train_rl_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "rl") + + +def test_train_adversarial_video_saving(tmpdir): + """Smoke test for imitation.scripts.train_adversarial.""" + named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"] + config_updates = { + "common": dict(log_root=tmpdir), + "demonstrations": dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH), + **_TRAIN_VIDEO_CONFIGS, + } + run = train_adversarial.train_adversarial_ex.run( + command_name="gail", + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "adversarial") + + +def test_train_preference_comparisons_video_saving(tmpdir): + config_updates = dict( + common=dict(log_root=tmpdir), + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_video_exists(run.config["common"]["log_dir"], "pc") + + +def test_train_bc_video_saving(tmpdir): + config_updates = dict( + common=dict(log_root=tmpdir), + demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH), + **_TRAIN_VIDEO_CONFIGS, + ) + run = train_imitation.train_imitation_ex.run( + command_name="bc", + named_configs=["cartpole"] + 
+        config_updates=config_updates,
+    )
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "bc")
diff --git a/tests/util/test_wb_logger.py b/tests/util/test_wb_logger.py
index fe755ad31..f95bdd82f 100644
--- a/tests/util/test_wb_logger.py
+++ b/tests/util/test_wb_logger.py
@@ -111,6 +111,10 @@ def test_wandb_output_format():
         {"_step": 0, "foo": 42, "fizz": 12},
         {"_step": 3, "fizz": 21},
     ]
+    with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"):
+        log_obj.record("video", 42)
+        log_obj.dump(step=4)
+    log_obj.close()
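For reference, a minimal self-contained sketch of exercising the new `imitation.util.video_wrapper.record_and_save_video` helper outside the Sacred scripts. The environment, algorithm, and output path below are illustrative assumptions rather than part of this change, and the wrapped environment must support `rgb_array` rendering for `VideoRecorder` to capture frames.

import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.util import video_wrapper

# Assumed setup: any renderable env and any SB3 policy would do here.
eval_venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
model = PPO("MlpPolicy", eval_venv)  # in practice, a trained or loaded policy

# Rolls out one episode inside a VideoWrapper and writes
# "<output_dir>/videos/video.000000.mp4"; the path is also recorded on the
# logger when one is passed.
video_wrapper.record_and_save_video(
    output_dir="output/demo",
    policy=model.policy,
    eval_venv=eval_venv,
    video_kwargs={"single_video": True},
)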