DQN implementation

I've just written my Pong DQN. It seems to work. I'm looking for a performance review: anything in the code that could slow down training, especially once I move on to more complex models.
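
To make "performance review" concrete: what I have in mind is a rough timing harness like the sketch below, dropped into the training loop to see where each iteration spends its time. The names here are illustrative only and not part of the project files.

import time
from collections import defaultdict

# Illustrative only: accumulate wall-clock time per named stage of the training loop.
timings = defaultdict(float)


class StageTimer:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()

    def __exit__(self, *exc):
        timings[self.name] += time.perf_counter() - self.start

# Intended usage inside the loop in main.py, e.g.:
#   with StageTimer("sample"):
#       batch = buffer.sample(BATCH_SIZE)
#   with StageTimer("loss+backward"):
#       loss_v = common.calc_loss(batch, net, tgt_net, GAMMA, DEVICE)
#       loss_v.backward()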

main.py:

import gym
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter

import wrappers
import model
import common


plt.ion()
plt.show()


ENV_NAME = "PongNoFrameskip-v4"
DEVICE = torch.device("cuda")

EPSILON_START = 1.
EPSILON_FRAMES = 100000
EPSILON_FINAL = .02

REPLAY_SIZE = 100000
REPLAY_INITIAL = 10000
TGT_NET_SYNC = 1024

BATCH_SIZE = 64
LEARNING_RATE = 1e-4
GAMMA = .99

EVAL_INTERVAL = 10000


if __name__ == "__main__":
    env = wrappers.wrap_env(gym.make(ENV_NAME))

    net = model.ConvModel(env.observation_space.shape, env.action_space.n).to(DEVICE)
    tgt_net = model.ConvModel(env.observation_space.shape, env.action_space.n).to(DEVICE)
    net_opt = optim.Adam(net.parameters(), lr=LEARNING_RATE)

    agent = common.DQNAgent(env, net, DEVICE, EPSILON_START)
    buffer = common.ReplayBuffer(agent, REPLAY_SIZE)

    test_env = gym.wrappers.Monitor(wrappers.wrap_env(gym.make(ENV_NAME)), "./records", force=True)
    test_agent = common.DQNAgent(test_env, net, DEVICE, 0.)

    writer = SummaryWriter(comment="DQN")

    for batch_idx, batch in enumerate(buffer.iterate(REPLAY_INITIAL, BATCH_SIZE)):
        # Periodically sync the target network with the online network
        if not batch_idx % TGT_NET_SYNC:
            tgt_net.load_state_dict(net.state_dict())

        # Greedy evaluation (epsilon = 0), averaged over 8 episodes
        if not batch_idx % EVAL_INTERVAL:
            reward = sum(common.evaluate(test_agent) for _ in range(8)) / 8
            writer.add_scalar("reward", reward, batch_idx)

        net_opt.zero_grad()
        loss_v = common.calc_loss(batch, net, tgt_net, GAMMA, DEVICE)
        loss_v.backward()
        net_opt.step()

        # Linear epsilon decay from EPSILON_START to EPSILON_FINAL over EPSILON_FRAMES steps
        agent.epsilon = max(EPSILON_FINAL, EPSILON_START - batch_idx / EPSILON_FRAMES)

        writer.add_scalar("loss", loss_v.item(), batch_idx)
        writer.add_scalar("eps", agent.epsilon, batch_idx)

common.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque, namedtuple


Trajectory = namedtuple("Trajectory", ("state",
                                       "action",
                                       "reward",
                                       "next_state",
                                       "done",
                                       "info"))


class DQNAgent(object):
    def __init__(self, env, net, device, epsilon):
        super(DQNAgent, self).__init__()
        self.env = env
        self.net = net
        self.device = device
        self.epsilon = epsilon
        self.done = True

    def __call__(self, obs):
        # Epsilon-greedy action selection
        if np.random.rand() > self.epsilon:
            obs_v = torch.FloatTensor(obs[None, ...]).to(self.device)
            with torch.no_grad():
                action_vals = self.net(obs_v).squeeze().cpu().numpy()
            action = int(np.argmax(action_vals))
        else:
            action = int(self.env.action_space.sample())

        next_obs, reward, done, info = self.env.step(action)
        return Trajectory(state=obs, action=action, reward=reward,
                          next_state=next_obs, done=done, info=info)


class ReplayBuffer(object):
    def __init__(self, agent, maxlen):
        super(ReplayBuffer, self).__init__()
        self.agent = agent
        self.core = deque(maxlen=maxlen)

        self.last_obs = None
        self.done = True

    def populate_once(self):
        if self.done:
            self.last_obs = self.agent.env.reset()
            self.done = False

        traj = self.agent(self.last_obs)
        self.last_obs = traj.next_state
        self.done = traj.done
        self.core.append(traj)

    def populate(self, n):
        for _ in range(n):
            self.populate_once()

    def sample(self, n):
        indices = np.arange(len(self.core))
        np.random.shuffle(indices)
        return [self.core[idx] for idx in indices[:n]]

    def iterate(self, initial, batch_size):
        self.populate(initial)
        while True:
            yield self.sample(batch_size)
            self.populate_once()


def calc_loss(batch, net, tgt_net, gamma, device):
    states, actions, rewards, next_states, dones = [], [], [], [], []
    for traj in batch:
        states.append(np.array(traj.state, dtype=np.float32, copy=False))
        next_states.append(np.array(traj.next_state, dtype=np.float32, copy=False))
        actions.append(int(traj.action))
        rewards.append(float(traj.reward))
        dones.append(bool(traj.done))

    states_v = torch.FloatTensor(np.array(states, copy=False)).to(device)
    next_states_v = torch.FloatTensor(np.array(next_states, copy=False)).to(device)
    actions_t = torch.LongTensor(np.array(actions, dtype=np.int64)).to(device)
    rewards_v = torch.FloatTensor(np.array(rewards, dtype=np.float32)).to(device)
    dones_t = torch.BoolTensor(dones).to(device)

    # Bellman target: r + gamma * max_a' Q_tgt(s', a'), zeroed for terminal transitions
    with torch.no_grad():
        next_states_pred_v = tgt_net(next_states_v).max(dim=1)[0]
        next_states_pred_v[dones_t] = 0.
        bellman_unroll = rewards_v + gamma * next_states_pred_v

    # Q-values of the actions that were actually taken
    values_pred_v = net(states_v)[range(states_v.shape[0]), actions_t]

    return F.mse_loss(values_pred_v, bellman_unroll)


def evaluate(agent):
    obs = agent.env.reset()
    total_reward = 0.
    while True:
        traj = agent(obs)
        obs = traj.next_state
        total_reward += traj.reward
        if traj.done:
            break
    return total_reward

model.py:

import numpy as np
import torch
import torch.nn as nn


class ConvModel(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(ConvModel, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=16, kernel_size=8, stride=4),
            nn.LeakyReLU(.2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2),
            nn.LeakyReLU(.2),
            nn.Flatten())

        # Size of the flattened conv output, inferred with a dummy forward pass
        conv_out_size = int(np.prod(self.conv(torch.zeros(1, *input_shape)).shape))
        self.fc = nn.Sequential(
            nn.Linear(in_features=conv_out_size, out_features=256),
            nn.LeakyReLU(.2),
            nn.Linear(in_features=256, out_features=n_actions))

    def forward(self, x):
        return self.fc(self.conv(x))

wrappers.py:

import gym
import numpy as np
import PIL.Image
from collections import deque


class ActionSkip(gym.Wrapper):
    def __init__(self, env, n_skips=4):
        super(ActionSkip, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)

        self.n_skips = n_skips

    def step(self, action):
        # Repeat the chosen action n_skips times; taking the max over the
        # collected frames removes the sprite flicker of the Atari emulator.
        queue = []
        all_rewards = 0
        for _ in range(self.n_skips):
            obs, reward, done, _ = self.env.step(action)
            queue.append(obs)
            all_rewards += reward
            if done:
                break
        return np.max(np.stack(queue), axis=0), all_rewards, done, {}

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class ImageFormat(gym.ObservationWrapper):
    def __init__(self, env):
        super(ImageFormat, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1, 84, 84), dtype=np.float32)

    def observation(self, observation):
        # Resize to 84x84, average the RGB channels, scale to [-1, 1] as float32
        obs_img = PIL.Image.fromarray(observation)
        obs_resize = obs_img.resize((84, 84))
        obs_gray = np.asarray(obs_resize).mean(axis=-1)[None, ...]
        return (obs_gray / 127.5 - 1.).astype(np.float32)


class ResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        super(ResetEnv, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1, 84, 84), dtype=np.float32)

    def fire(self):
        for _ in range(15):
            obs, r, done, _ = self.env.step(self.env.action_space.sample())
            if done:
                return self.reset()
        return obs

    def step(self, action):
        traj = self.env.step(action)
        return traj

    def reset(self):
        self.env.reset()
        return self.fire()


class FrameStack(gym.ObservationWrapper):
    def __init__(self, env):
        super(FrameStack, self).__init__(env)
        self.queue = deque(maxlen=6)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(6, 84, 84), dtype=np.float32)

    def observation(self, observation):
        self.queue.append(observation)
        obs_stack = np.zeros((6, 84, 84), dtype=np.float32)
        for idx, element in enumerate(self.queue):
            obs_stack[idx] = element[0]
        return obs_stack

    def reset(self, **kwargs):
        # Keep the reset frame in the queue so subsequent stacks include it
        self.queue.clear()
        obs = self.env.reset()
        return self.observation(obs)


def wrap_env(env):
    return FrameStack(ResetEnv(ImageFormat(ActionSkip(env))))
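
One thing I already suspect, for what it's worth: ReplayBuffer.sample shuffles an index array over the entire buffer and then indexes into a deque, where access to middle elements is O(N); with REPLAY_SIZE = 100000 that is repeated work on every batch, even if it may be small next to the conv net. The sketch below (illustrative only, not part of the files above) is the kind of list-backed ring buffer I'm considering instead; it samples indices with replacement, which is common for DQN replay.

import numpy as np


class RingReplayBuffer:
    """Illustrative sketch of a replay store backed by a plain list used as
    a ring buffer, so sampling costs O(batch) instead of O(buffer)."""

    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.core = []
        self.pos = 0

    def append(self, traj):
        if len(self.core) < self.maxlen:
            self.core.append(traj)
        else:
            self.core[self.pos] = traj  # overwrite the oldest entry
        self.pos = (self.pos + 1) % self.maxlen

    def sample(self, n):
        # Sampling with replacement draws only n random indices,
        # and list indexing is O(1).
        idx = np.random.randint(len(self.core), size=n)
        return [self.core[i] for i in idx]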
