xagents — реализации алгоритмов обучения с подкреплением

Описание

Можно сказать, что эта работа началась и превратилась из автономной реализации DQN, которую я включил в старый вопрос, в мини-библиотеку. агенты, вмещающий 7 многоразовых реализаций алгоритмов DRL на основе тензорного потока. Из-за кодовой базы среднего размера ~ = 6 тыс. Строк, его нельзя включить в один пост из-за ограничений на количество символов. Поэтому я разобью код на более мелкие логические компоненты, которые можно назвать, содержать и просматривать отдельно, и я постараюсь сделать его менее скучным и более интерактивным, насколько это возможно.

Демо

Позвольте мне познакомить вас с функциями в интерактивном режиме …

1. Установка

Я предполагаю, что у вас установлены python3.8 + и virtualenv, и вы используете какую-то оболочку unix. Пожалуйста, запустите в терминале следующее:

mkdir xagent-demo
cd xagent-demo
virtualenv demo-env
source demo-env/bin/activate
git clone https://github.com/schissmantics/xagents
pip install xagents/

Это займет несколько минут, после чего вы можете проверить установку, запустив:

>>> xagents

xagents 1.0

Usage:
    xagents <command> <agent> [options] [args]

Available commands:
    train      Train given an agent and environment
    play       Play a game given a trained agent and environment
    tune       Tune hyperparameters given an agent, hyperparameter specs, and environment

Use xagents <command> to see more info about a command
Use xagents <command> <agent> to see more info about command + agent

Используйте продемонстрированный выше синтаксис, если хотите узнать больше о параметрах команд и агентов. Например, чтобы узнать варианты обучения A2C, запустите xagents train a2c, который должен отображать соответствующие параметры.

2. Обучение

Давайте обучим агента PPO в спортзале OpenAI Корзина среды, чтобы получить целевую награду в 250. Тренировка на процессоре не должна занимать более 2-3 минут. Пожалуйста, запустите следующую команду в терминале:

xagents train ppo --env CartPole-v1 --target-reward 250 --n-envs 16 --n-steps 1024 --checkpoints ppo-cartpole.tf --history-checkpoint ppo-cartpole.parquet --seed 123

После завершения обучения вы должны увидеть Reward achieved in 376832 steps. В результате будут файлы контрольных точек, которые мы будем использовать позже в следующих шагах.

3. Визуализируйте историю тренировок
Визуализация обучения в настоящее время недоступна из командной строки, поэтому вам нужно запустить python в интерактивном режиме, а затем выполнить следующее:

from xagents.utils.common import plot_history
import matplotlib.pyplot as plt

plot_history(['ppo-cartpole.parquet'], ['PPO'], 'CartPole-v1', history_interval=100)
plt.show()

Что должно отображаться:

кривая ppo-cartpole

4. Играть в игру

Теперь мы воспользуемся контрольными точками веса, полученными в результате предыдущей тренировки, чтобы сыграть один эпизод «Картпул». Пожалуйста, запустите в терминале следующее:

xagents play ppo --env CartPole-v1 --render --weights ppo-cartpole.tf --video-dir video

который отобразит эпизод, который воспроизводится обученным агентом PPO, и приведет к следующему видео (я вручную преобразовал его в gif, чтобы загрузить сюда):

геймплей

Это еще не все, есть другие функции, включая настройку гиперпараметров, которые я буду обсуждать в других сообщениях с соответствующим кодом. Для получения дополнительной информации, не стесняйтесь проверить проект ПРОЧТИ МЕНЯ всякий раз, когда вам нужно.

Архитектура проекта

xagents
├── LICENSE
├── MANIFEST.in
├── README.md
├── img
│   ├── bipedal-walker.gif
│   ├── breakout.gif
│   ├── carnival.gif
│   ├── gopher.gif
│   ├── lunar-lander.gif
│   ├── pacman.gif
│   ├── param-importances.png
│   ├── pong.gif
│   ├── step-benchmark.jpg
│   ├── time-benchmark.jpg
│   └── wandb-agents.png
├── requirements.txt
├── scratch_commands.sh
├── setup.py
└── xagents
    ├── __init__.py
    ├── a2c
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── cli.py
    │   └── models
    │       ├── ann-actor-critic.cfg
    │       └── cnn-actor-critic.cfg
    ├── acer
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── cli.py
    │   └── models
    │       └── cnn-actor-critic.cfg
    ├── base.py
    ├── cli.py
    ├── ddpg
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── cli.py
    │   └── models
    │       ├── ann-actor.cfg
    │       └── ann-critic.cfg
    ├── dqn
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── cli.py
    │   └── models
    │       └── cnn.cfg
    ├── ppo
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── cli.py
    │   └── models
    │       ├── ann-actor-critic.cfg
    │       └── cnn-actor-critic.cfg
    ├── td3
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── cli.py
    │   └── models
    │       ├── ann-actor.cfg
    │       └── ann-critic.cfg
    ├── tests
    │   ├── __init__.py
    │   ├── conftest.py
    │   ├── test_base.py
    │   ├── test_buffers.py
    │   ├── test_cli.py
    │   ├── test_common_utils.py
    │   ├── test_tuning.py
    │   └── utils.py
    ├── trpo
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── cli.py
    │   └── models
    │       ├── ann-actor.cfg
    │       ├── ann-critic.cfg
    │       ├── cnn-actor.cfg
    │       └── cnn-critic.cfg
    └── utils
        ├── __init__.py
        ├── buffers.py
        ├── cli.py
        ├── common.py
        └── tuning.py

Для тех, у кого нет опыта в RL, есть 2 типа обучения: политическое и внеполитическое. Для получения дополнительной информации об условиях вы можете проверить это вопрос который должен прояснить, что это такое. Все реализованные агенты наследуются от 2-х базовых агентов: OnPolicy и OffPolicy. Оба наследуются от BaseAgent и все они включены в base.py, о котором идет речь в этой публикации. Все агенты, унаследованные от этих классов, должны реализовывать train_step() абстрактный метод, который определяет логику одного шага поезда. Все агенты разделяют play() , который вызывается для воспроизведения одного эпизода, как показано выше. Кроме того, они разделяют fit() метод, который запускает цикл обучения и вызывается для обучения агента. Остальные реализованные методы являются вспомогательными, и вы можете проверить строки документации, чтобы понять, что они делают.

Вот определение функции, которая вам понадобится:

import pyarrow as pa
import pyarrow.parquet as pq


def write_from_dict(_dict, path):
    """
    Write to .parquet given a dict
    Args:
        _dict: Dictionary of label: [scalar]
        path: Path to .parquet fiile.

    Returns:
        None
    """
    table = pa.Table.from_pydict(_dict)
    pq.write_to_dataset(table, root_path=path, compression='gzip')

base.py

import os
import random
from abc import ABC
from collections import deque
from datetime import timedelta
from pathlib import Path
from time import perf_counter, sleep

import cv2
import gym
import numpy as np
import optuna
import pandas as pd
import tensorflow as tf
import wandb
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete

from xagents.utils.common import write_from_dict


class BaseAgent(ABC):
    """
    Base class for various types of agents.
    """

    def __init__(
        self,
        envs,
        model,
        checkpoints=None,
        reward_buffer_size=100,
        n_steps=1,
        gamma=0.99,
        display_precision=2,
        seed=None,
        log_frequency=None,
        history_checkpoint=None,
        plateau_reduce_factor=0.9,
        plateau_reduce_patience=10,
        early_stop_patience=3,
        divergence_monitoring_steps=None,
        quiet=False,
        trial=None,
    ):
        """
        Initialize base settings.
        Args:
            envs: A list of gym environments.
            model: tf.keras.models.Model that is expected to be compiled
                with an optimizer before training starts.
            checkpoints: A list of paths to .tf filenames under which the trained model(s)
                will be saved.
            reward_buffer_size: Size of the reward buffer that will hold the last n total
                rewards which will be used for calculating the mean reward.
            n_steps: n-step transition for example given s1, s2, s3, s4 and n_step = 4,
                transition will be s1 -> s4 (defaults to 1, s1 -> s2)
            gamma: Discount factor used for gradient updates.
            display_precision: Decimal precision for display purposes.
            seed: Random seed passed to random.seed(), np.random.seed(), tf.random.seed(),
                env.seed()
            log_frequency: Interval of done games to display progress after each,
                defaults to the number of environments given if not specified.
            history_checkpoint: Path to .parquet file to which episode history will be saved.
            plateau_reduce_patience: int, Maximum times of non-improving consecutive model checkpoints.
            plateau_reduce_factor: Factor by which the learning rates of all models in
                self.output_models are multiplied when plateau_reduce_patience is consecutively
                reached / exceeded.
            early_stop_patience: Number of times plateau_reduce_patience is consecutively
                reached / exceeded.
            divergence_monitoring_steps: Number of steps at which reduce on plateau,
                and early stopping start monitoring.
            quiet: If True, all agent messages will be silenced.
            trial: optuna.trial.Trial
        """
        assert envs, 'No environments given'
        self.n_envs = len(envs)
        self.envs = envs
        self.model = model
        self.checkpoints = checkpoints
        self.total_rewards = deque(maxlen=reward_buffer_size)
        self.n_steps = n_steps
        self.gamma = gamma
        self.display_precision = display_precision
        self.seed = seed
        self.output_models = [self.model]
        self.log_frequency = log_frequency or self.n_envs
        self.id = self.__module__.split('.')[1]
        self.history_checkpoint = history_checkpoint
        self.plateau_reduce_factor = plateau_reduce_factor
        self.plateau_reduce_patience = plateau_reduce_patience
        self.early_stop_patience = early_stop_patience
        self.divergence_monitoring_steps = divergence_monitoring_steps
        self.quiet = quiet
        self.trial = trial
        self.reported_rewards = 0
        self.plateau_count = 0
        self.early_stop_count = 0
        self.target_reward = None
        self.max_steps = None
        self.input_shape = self.envs[0].observation_space.shape
        self.n_actions = None
        self.best_reward = -float('inf')
        self.mean_reward = -float('inf')
        self.states = [np.array(0)] * self.n_envs
        self.dones = [False] * self.n_envs
        self.steps = 0
        self.frame_speed = 0
        self.last_reset_step = 0
        self.training_start_time = None
        self.last_reset_time = None
        self.games = 0
        self.episode_rewards = np.zeros(self.n_envs)
        self.done_envs = 0
        self.supported_action_spaces = Box, Discrete
        if seed:
            self.set_seeds(seed)
        self.reset_envs()
        self.set_action_count()
        self.img_inputs = len(self.states[0].shape) >= 2
        self.display_titles = (
            'time',
            'steps',
            'games',
            'speed',
            'mean reward',
            'best reward',
        )

    def assert_valid_env(self, env, valid_type):
        """
        Assert the right type of environment is passed to an agent.
        Args:
            env: gym environment.
            valid_type: gym.spaces class.

        Returns:
            None
        """
        assert isinstance(env.action_space, valid_type), (
            f'Invalid environment: {env.spec.id}. {self.__class__.__name__} supports '
            f'environments with a {valid_type} action space only, got {env.action_space}'
        )

    def display_message(self, *args, **kwargs):
        """
        Display messages to the console.
        Args:
            *args: args passed to print()
            **kwargs: kwargs passed to print()

        Returns:
            None
        """
        if not self.quiet:
            print(*args, **kwargs)

    def set_seeds(self, seed):
        """
        Set random seeds for numpy, tensorflow, random, gym
        Args:
            seed: int, random seed.

        Returns:
            None
        """
        tf.random.set_seed(seed)
        np.random.seed(seed)
        for env in self.envs:
            env.seed(seed)
            env.action_space.seed(seed)
        os.environ['PYTHONHASHSEED'] = f'{seed}'
        random.seed(seed)

    def reset_envs(self):
        """
        Reset all environments in self.envs and update self.states

        Returns:
            None
        """
        for i, env in enumerate(self.envs):
            self.states[i] = env.reset()

    def set_action_count(self):
        """
        Set `self.n_actions` to the number of actions for discrete
        environments or to the shape of the action space for continuous.
        """
        assert (
            type(action_space := self.envs[0].action_space)
            in self.supported_action_spaces
        ), f'Expected one of {self.supported_action_spaces}, got {action_space}'
        if isinstance(action_space, Discrete):
            self.n_actions = action_space.n
        if isinstance(action_space, Box):
            self.n_actions = action_space.shape[0]

    def check_checkpoints(self):
        """
        Ensure the number of given checkpoints matches the number of output models.

        Returns:
            None
        """
        assert (n_models := len(self.output_models)) == (
            n_checkpoints := len(self.checkpoints)
        ), (
            f'Expected {n_models} checkpoints for {n_models} '
            f'given output models, got {n_checkpoints}'
        )

    def checkpoint(self):
        """
        Save model weights if current reward > best reward.

        Returns:
            None
        """
        if self.mean_reward > self.best_reward:
            self.plateau_count = 0
            self.early_stop_count = 0
            self.display_message(
                f'Best reward updated: {self.best_reward} -> {self.mean_reward}'
            )
            if self.checkpoints:
                for model, checkpoint in zip(self.output_models, self.checkpoints):
                    model.save_weights(checkpoint)
        self.best_reward = max(self.mean_reward, self.best_reward)

    def display_metrics(self):
        """
        Display progress metrics to the console when environments complete a full episode each.
        Metrics consist of:
            - time: Time since training started.
            - steps: Time steps so far.
            - games: Finished games / episodes that resulted in a terminal state.
            - speed: Frame speed/s
            - mean reward: Mean game total rewards.
            - best reward: Highest total episode score obtained.

        Returns:
            None
        """
        display_values = (
            timedelta(seconds=perf_counter() - self.training_start_time),
            self.steps,
            self.games,
            f'{round(self.frame_speed)} steps/s',
            self.mean_reward,
            self.best_reward,
        )
        display = (
            f'{title}: {value}'
            for title, value in zip(self.display_titles, display_values)
        )
        self.display_message(', '.join(display))

    def update_metrics(self):
        """
        Update progress metrics which consist of last reset step and time used
        for calculation of fps, and update mean and best rewards. The model is
        saved if there is a checkpoint path specified.

        Returns:
            None
        """
        self.checkpoint()
        if (
            self.divergence_monitoring_steps
            and self.steps >= self.divergence_monitoring_steps
            and self.mean_reward <= self.best_reward
        ):
            self.plateau_count += 1
        if self.plateau_count >= self.plateau_reduce_patience:
            current_lr, new_lr = None, None
            for model in self.output_models:
                current_lr = model.optimizer.learning_rate
                new_lr = current_lr * self.plateau_reduce_factor
            self.display_message(
                f'Learning rate reduced {current_lr.numpy()} ' f'-> {new_lr.numpy()}'
            )
            current_lr.assign(new_lr)
            self.plateau_count = 0
            self.early_stop_count += 1
        self.frame_speed = (self.steps - self.last_reset_step) / (
            perf_counter() - self.last_reset_time
        )
        self.last_reset_step = self.steps
        self.mean_reward = np.around(
            np.mean(self.total_rewards), self.display_precision
        )

    def report_rewards(self):
        """
        Report intermediate rewards or raise an exception to
        prune current trial.

        Returns:
            None

        Raises:
            optuna.exceptions.TrialPruned
        """
        self.trial.report(np.mean(self.total_rewards), self.reported_rewards)
        self.reported_rewards += 1
        if self.trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    def check_episodes(self):
        """
        Check environment done counts to display progress and update metrics.

        Returns:
            None
        """
        if self.done_envs >= self.log_frequency:
            self.update_metrics()
            if self.trial:
                self.report_rewards()
            self.last_reset_time = perf_counter()
            self.display_metrics()
            self.done_envs = 0

    def training_done(self):
        """
        Check whether a target reward or maximum number of steps is reached.

        Returns:
            bool
        """
        if self.early_stop_count >= self.early_stop_patience:
            self.display_message(f'Early stopping')
            return True
        if self.target_reward and self.mean_reward >= self.target_reward:
            self.display_message(f'Reward achieved in {self.steps} steps')
            return True
        if self.max_steps and self.steps >= self.max_steps:
            self.display_message(f'Maximum steps exceeded')
            return True
        return False

    def concat_buffer_samples(self):
        """
        Concatenate samples drawn from each environment respective buffer.
        Args:

        Returns:
            A list of concatenated samples.
        """
        if hasattr(self, 'buffers'):
            batches = []
            for i, env in enumerate(self.envs):
                buffer = self.buffers[i]
                batch = buffer.get_sample()
                batches.append(batch)
            dtypes = (
                self.batch_dtypes
                if hasattr(self, 'batch_dtypes')
                else [np.float32 for _ in range(len(batches[0]))]
            )
            if len(batches) > 1:
                return [
                    np.concatenate(item).astype(dtype)
                    for (item, dtype) in zip(zip(*batches), dtypes)
                ]
            return [item.astype(dtype) for (item, dtype) in zip(batches[0], dtypes)]

    def update_history(self, episode_reward):
        """
        Write 1 episode stats to .parquet history checkpoint.
        Args:
            episode_reward: int, a finished episode reward

        Returns:
            None
        """
        data = {
            'mean_reward': [self.mean_reward],
            'best_reward': [self.best_reward],
            'episode_reward': [episode_reward],
            'step': [self.steps],
            'time': [perf_counter() - self.training_start_time],
        }
        write_from_dict(data, self.history_checkpoint)

    def step_envs(self, actions, get_observation=False, store_in_buffers=False):
        """
        Step environments in self.envs, update metrics (if any done games)
            and return / store results.
        Args:
            actions: An iterable of actions to execute by environments.
            get_observation: If True, a list of [states, actions, rewards, dones, new_states]
                of length self.n_envs each will be returned.
            store_in_buffers: If True, each observation is saved separately in respective buffer.

        Returns:
            A list of observations as numpy arrays or an empty list.
        """
        observations = []
        for (
            (i, env),
            action,
            *items,
        ) in zip(enumerate(self.envs), actions):
            state = self.states[i]
            new_state, reward, done, _ = env.step(action)
            self.states[i] = new_state
            self.dones[i] = done
            self.episode_rewards[i] += reward
            observation = state, action, reward, done, new_state
            if store_in_buffers and hasattr(self, 'buffers'):
                self.buffers[i].append(*observation)
            if get_observation:
                observations.append(observation)
            if done:
                if self.history_checkpoint:
                    self.update_history(self.episode_rewards[i])
                self.done_envs += 1
                self.total_rewards.append(self.episode_rewards[i])
                self.games += 1
                self.episode_rewards[i] = 0
                self.states[i] = env.reset()
            self.steps += 1
        return [np.array(item, np.float32) for item in zip(*observations)]

    def init_from_checkpoint(self):
        """
        Load previous training session metadata and update agent metrics
            to go from there.

        Returns:
            None
        """
        previous_history = pd.read_parquet(self.history_checkpoint)
        expected_columns = {
            'time',
            'mean_reward',
            'best_reward',
            'step',
            'episode_reward',
        }
        assert (
            set(previous_history.columns) == expected_columns
        ), f'Expected the following columns: {expected_columns}, got {set(previous_history.columns)}'
        last_row = previous_history.loc[previous_history['time'].idxmax()]
        self.mean_reward = last_row['mean_reward']
        self.best_reward = previous_history['best_reward'].max()
        history_start_steps = last_row['step']
        history_start_time = last_row['time']
        self.training_start_time = perf_counter() - history_start_time
        self.last_reset_step = self.steps = int(history_start_steps)
        self.total_rewards.append(last_row['episode_reward'])
        self.games = previous_history.shape[0]

    def init_training(self, target_reward, max_steps, monitor_session):
        """
        Initialize training start time, wandb session & models (self.model / self.target_model)
        Args:
            target_reward: Total reward per game value that whenever achieved,
                the training will stop.
            max_steps: Maximum time steps, if exceeded, the training will stop.
            monitor_session: Wandb session name.

        Returns:
            None
        """
        self.target_reward = target_reward
        self.max_steps = max_steps
        if monitor_session:
            wandb.init(name=monitor_session)
        if self.checkpoints:
            self.check_checkpoints()
        self.training_start_time = perf_counter()
        self.last_reset_time = perf_counter()
        if self.history_checkpoint and Path(self.history_checkpoint).exists():
            self.init_from_checkpoint()

    def train_step(self):
        """
        Perform 1 step which controls action_selection, interaction with environments
        in self.envs, batching and gradient updates.

        Returns:
            None
        """
        raise NotImplementedError(
            f'train_step() should be implemented by {self.__class__.__name__} subclasses'
        )

    def get_model_outputs(self, inputs, models, training=True):
        """
        Get single or multiple model outputs.
        Args:
            inputs: Inputs as tensors / numpy arrays that are expected
                by the given model(s).
            models: A tf.keras.Model or a list of tf.keras.Model(s)
            training: `training` parameter passed to model call.

        Returns:
            Outputs as a list in case of multiple models or any other shape
            that is expected from the given model(s).
        """
        if self.img_inputs:
            inputs = tf.cast(inputs, tf.float32) / 255.0
        if isinstance(models, tf.keras.models.Model):
            return models(inputs, training=training)
        elif len(models) == 1:
            return models[0](inputs, training=training)
        return [sub_model(inputs, training=training) for sub_model in models]

    def at_step_start(self):
        """
        Execute steps that will run before self.train_step().

        Returns:
            None
        """
        pass

    def at_step_end(self):
        """
        Execute steps that will run after self.train_step().

        Returns:
            None
        """
        pass

    def get_states(self):
        """
        Get most recent states.

        Returns:
            self.states as numpy array.
        """
        return np.array(self.states)

    def get_dones(self):
        """
        Get most recent game statuses.

        Returns:
            self.dones as numpy array.
        """
        return np.array(self.dones, np.float32)

    @staticmethod
    def concat_step_batches(*args):
        """
        Concatenate n-step batches.
        Args:
            *args: A list of numpy arrays which will be concatenated separately.

        Returns:
            A list of concatenated numpy arrays.
        """
        concatenated = []
        for arg in args:
            if len(arg.shape) == 1:
                arg = np.expand_dims(arg, -1)
            concatenated.append(arg.swapaxes(0, 1).reshape(-1, *arg.shape[2:]))
        return concatenated

    def fit(
        self,
        target_reward=None,
        max_steps=None,
        monitor_session=None,
    ):
        """
        Common training loop shared by subclasses, monitors training status
        and progress, performs all training steps, updates metrics, and logs progress.
        Args:
            target_reward: Target reward, if achieved, the training will stop
            max_steps: Maximum number of steps, if reached the training will stop.
            monitor_session: Session name to use for monitoring the training with wandb.

        Returns:
            None
        """
        assert (
            target_reward or max_steps
        ), '`target_reward` or `max_steps` should be specified when fit() is called'
        self.init_training(target_reward, max_steps, monitor_session)
        while True:
            self.check_episodes()
            if self.training_done():
                break
            self.at_step_start()
            self.train_step()
            self.at_step_end()

    def play(
        self,
        video_dir=None,
        render=False,
        frame_dir=None,
        frame_delay=0.0,
        max_steps=None,
        action_idx=0,
        frame_frequency=1,
    ):
        """
        Play and display a game.
        Args:
            video_dir: Path to directory to save the resulting game video.
            render: If True, the game will be displayed.
            frame_dir: Path to directory to save game frames.
            frame_delay: Delay between rendered frames.
            max_steps: Maximum environment steps.
            action_idx: Index of action output by self.model
            frame_frequency: If frame_dir is specified, save frames every n frames.

        Returns:
            None
        """
        self.reset_envs()
        env_idx = 0
        total_reward = 0
        env_in_use = self.envs[env_idx]
        if video_dir:
            env_in_use = gym.wrappers.Monitor(env_in_use, video_dir)
            env_in_use.reset()
        steps = 0
        agent_id = self.__module__.split('.')[1]
        for dir_name in (video_dir, frame_dir):
            os.makedirs(dir_name or '.', exist_ok=True)
        while True:
            if max_steps and steps >= max_steps:
                self.display_message(f'Maximum steps {max_steps} exceeded')
                break
            if render:
                env_in_use.render()
                sleep(frame_delay)
            if frame_dir and steps % frame_frequency == 0:
                frame = cv2.cvtColor(
                    env_in_use.render(mode="rgb_array"), cv2.COLOR_BGR2RGB
                )
                cv2.imwrite(os.path.join(frame_dir, f'{steps:05d}.jpg'), frame)
            if hasattr(self, 'actor') and agent_id in ['td3', 'ddpg']:
                action = self.actor(self.get_states())[env_idx]
            else:
                action = self.get_model_outputs(
                    self.get_states(), self.output_models, False
                )[action_idx][env_idx].numpy()
            self.states[env_idx], reward, done, _ = env_in_use.step(action)
            total_reward += reward
            if done:
                self.display_message(f'Total reward: {total_reward}')
                break
            steps += 1


class OnPolicy(BaseAgent, ABC):
    """
    Base class for on-policy agents.
    """

    def __init__(self, envs, model, **kwargs):
        """
        Initialize on-policy agent.
        Args:
            envs: A list of gym environments.
            model: tf.keras.models.Model that is expected to be compiled
                with an optimizer before training starts.
            **kwargs: kwargs passed to BaseAgent.
        """
        super(OnPolicy, self).__init__(envs, model, **kwargs)


class OffPolicy(BaseAgent, ABC):
    """
    Base class for off-policy agents.
    """

    def __init__(
        self,
        envs,
        model,
        buffers,
        **kwargs,
    ):
        """
        Initialize off-policy agent.
        Args:
            envs: A list of gym environments.
            model: tf.keras.models.Model that is expected to be compiled
                with an optimizer before training starts.
            buffers: A list of replay buffer objects whose length should match
                `envs`s'.
            **kwargs: kwargs passed to BaseAgent.
        """
        super(OffPolicy, self).__init__(envs, model, **kwargs)
        assert len(envs) == len(buffers), (
            f'Expected equal env and replay buffer sizes, got {self.n_envs} '
            f'and {len(buffers)}'
        )
        self.buffers = buffers

    def fill_buffers(self):
        """
        Fill each buffer in self.buffers up to its initial size.

        Returns:
            None
        """
        total_size = sum(buffer.initial_size for buffer in self.buffers)
        sizes = {}
        for i, env in enumerate(self.envs):
            buffer = self.buffers[i]
            state = self.states[i]
            while buffer.current_size < buffer.initial_size:
                action = env.action_space.sample()
                new_state, reward, done, _ = env.step(action)
                buffer.append(state, action, reward, done, new_state)
                state = new_state
                if done:
                    state = env.reset()
                sizes[i] = buffer.current_size
                filled = sum(sizes.values())
                complete = round((filled / total_size) * 100, self.display_precision)
                self.display_message(
                    f'rFilling replay buffer {i + 1}/{self.n_envs} ==> {complete}% | '
                    f'{filled}/{total_size}',
                    end='',
                )
        self.display_message('')
        self.reset_envs()

    def fit(
        self,
        target_reward=None,
        max_steps=None,
        monitor_session=None,
    ):
        """
        Common training loop shared by subclasses, monitors training status
        and progress, performs all training steps, updates metrics, and logs progress.
        ** Additionally, replay buffers are pre-filled before training starts **
        Args:
            target_reward: Target reward, if achieved, the training will stop
            max_steps: Maximum number of steps, if reached the training will stop.
            monitor_session: Session name to use for monitoring the training with wandb.

        Returns:
            None
        """
        self.fill_buffers()
        super(OffPolicy, self).fit(target_reward, max_steps, monitor_session)

Относительно соответствующего тестового модуля test_base.py Я включу его в отдельный пост, потому что, помимо тестирования базовых агентов, он проверяет дополнительные функции, выходящие за рамки их возможностей, и, конечно же, ограничение на количество символов. Если кто-то заинтересован в этом, дайте мне знать, если у вас возникнут какие-либо вопросы.

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *