I just wrote my Pong DQN. It seems to work. I'm looking for a performance review: anything that might slow down training once I move to more complex models.
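For reference, this is roughly how I plan to look for bottlenecks myself. It is only a sketch: `profile_training` and the label names are my own and not part of the code below, and it assumes PyTorch's built-in `torch.profiler` (1.8+).

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

import common  # the module listed below


def profile_training(buffer, net, tgt_net, net_opt, device, gamma=0.99, n_steps=50):
    # Hypothetical helper: run a few update steps under the profiler and report
    # where the time goes (env stepping / batch assembly vs. forward/backward).
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for batch_idx, batch in enumerate(buffer.iterate(1000, 64)):
            with record_function("calc_loss"):
                loss_v = common.calc_loss(batch, net, tgt_net, gamma, device)
            with record_function("optimizer_step"):
                net_opt.zero_grad()
                loss_v.backward()
                net_opt.step()
            if batch_idx >= n_steps:
                break
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
```

Sorting by `cuda_time_total` should show whether the time is dominated by building batches on the CPU or by the forward/backward passes on the GPU.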
main.py:
import gym
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
import wrappers
import model
import common
plt.ion()
plt.show()
ENV_NAME = "PongNoFrameskip-v4"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPSILON_START = 1.
EPSILON_FRAMES = 100000
EPSILON_FINAL = .02
REPLAY_SIZE = 100000
REPLAY_INITIAL = 10000
TGT_NET_SYNC = 1024
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
GAMMA = .99
EVAL_INTERVAL = 10000
if __name__ == "__main__":
    env = wrappers.wrap_env(gym.make(ENV_NAME))
    net = model.ConvModel(env.observation_space.shape, env.action_space.n).to(DEVICE)
    tgt_net = model.ConvModel(env.observation_space.shape, env.action_space.n).to(DEVICE)
    net_opt = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    agent = common.DQNAgent(env, net, DEVICE, EPSILON_START)
    buffer = common.ReplayBuffer(agent, REPLAY_SIZE)
    test_env = gym.wrappers.Monitor(wrappers.wrap_env(gym.make(ENV_NAME)), "./records", force=True)
    test_agent = common.DQNAgent(test_env, net, DEVICE, 0.)
    writer = SummaryWriter(comment="DQN")

    for batch_idx, batch in enumerate(buffer.iterate(REPLAY_INITIAL, BATCH_SIZE)):
        if not batch_idx % TGT_NET_SYNC:
            # periodically copy the online net into the target net
            tgt_net.load_state_dict(net.state_dict())
        if not batch_idx % EVAL_INTERVAL:
            # greedy evaluation, averaged over 8 episodes
            reward = sum(common.evaluate(test_agent) for _ in range(8)) / 8
            writer.add_scalar("reward", reward, batch_idx)
        net_opt.zero_grad()
        loss_v = common.calc_loss(batch, net, tgt_net, GAMMA, DEVICE)
        loss_v.backward()
        net_opt.step()
        # linear epsilon decay over EPSILON_FRAMES training steps
        agent.epsilon = max(EPSILON_FINAL, EPSILON_START - batch_idx / EPSILON_FRAMES)
        writer.add_scalar("loss", loss_v.item(), batch_idx)
        writer.add_scalar("eps", agent.epsilon, batch_idx)
common.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque, namedtuple
Trajectory = namedtuple("Trajectory",
                        ("state", "action", "reward", "next_state", "done", "info"))
class DQNAgent(object):
    def __init__(self, env, net, device, epsilon):
        super(DQNAgent, self).__init__()
        self.env = env
        self.net = net
        self.device = device
        self.epsilon = epsilon
        self.done = True

    def __call__(self, obs):
        # epsilon-greedy action selection
        if np.random.rand() > self.epsilon:
            with torch.no_grad():  # no autograd graph needed for acting
                obs_v = torch.FloatTensor(obs[None, ...]).to(self.device)
                action_vals = self.net(obs_v).squeeze().cpu().numpy()
            action = int(np.argmax(action_vals))
        else:
            action = int(self.env.action_space.sample())
        next_obs, reward, done, info = self.env.step(action)
        return Trajectory(state=obs, action=action, reward=reward,
                          next_state=next_obs, done=done, info=info)
class ReplayBuffer(object):
    def __init__(self, agent, maxlen):
        super(ReplayBuffer, self).__init__()
        self.agent = agent
        self.core = deque(maxlen=maxlen)
        self.last_obs = None
        self.done = True

    def populate_once(self):
        # take one environment step and store the transition
        if self.done:
            self.last_obs = self.agent.env.reset()
            self.done = False
        traj = self.agent(self.last_obs)
        self.last_obs = traj.next_state
        self.done = traj.done
        self.core.append(traj)

    def populate(self, n):
        for _ in range(n):
            self.populate_once()

    def sample(self, n):
        # uniform sampling without replacement
        indices = np.arange(len(self.core))
        np.random.shuffle(indices)
        return [self.core[idx] for idx in indices[:n]]

    def iterate(self, initial, batch_size):
        # fill the buffer first, then alternate one env step per training batch
        self.populate(initial)
        while True:
            yield self.sample(batch_size)
            self.populate_once()
def calc_loss(batch, net, tgt_net, gamma, device):
    states, actions, rewards, next_states, dones = [], [], [], [], []
    for traj in batch:
        states.append(np.array(traj.state, dtype=np.float32, copy=False))
        next_states.append(np.array(traj.next_state, dtype=np.float32, copy=False))
        actions.append(int(traj.action))
        rewards.append(float(traj.reward))
        dones.append(bool(traj.done))
    states_v = torch.FloatTensor(np.array(states, copy=False)).to(device)
    next_states_v = torch.FloatTensor(np.array(next_states, copy=False)).to(device)
    actions_t = torch.LongTensor(np.array(actions, dtype=np.int64)).to(device)
    rewards_v = torch.FloatTensor(np.array(rewards, dtype=np.float32)).to(device)
    dones_t = torch.BoolTensor(dones).to(device)
    with torch.no_grad():
        # max_a' Q(s', a') from the target net; terminal states contribute nothing
        next_states_pred_v = tgt_net(next_states_v).max(dim=1)[0]
        next_states_pred_v[dones_t] = 0.
        bellman_unroll = rewards_v + gamma * next_states_pred_v
    # Q(s, a) of the actions actually taken
    values_pred_v = net(states_v)[range(states_v.shape[0]), actions_t]
    return F.mse_loss(values_pred_v, bellman_unroll)
def evaluate(agent):
    obs = agent.env.reset()
    total_reward = 0.
    while True:
        traj = agent(obs)
        obs = traj.next_state
        total_reward += traj.reward
        if traj.done:
            break
    return total_reward
model.py:
import numpy as np
import torch
import torch.nn as nn
class ConvModel(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(ConvModel, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=16, kernel_size=(8, 8), stride=(4, 4)), nn.LeakyReLU(.2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(4, 4), stride=(2, 2)), nn.LeakyReLU(.2),
            nn.Flatten())
        # infer the flattened conv output size with a dummy forward pass
        conv_out_size = int(np.prod(self.conv(torch.zeros(1, *input_shape)).shape))
        self.fc = nn.Sequential(
            nn.Linear(in_features=conv_out_size, out_features=256), nn.LeakyReLU(.2),
            nn.Linear(in_features=256, out_features=n_actions))

    def forward(self, x):
        return self.fc(self.conv(x))
wrappers.py:
import gym
import numpy as np
import PIL.Image
from collections import deque
class ActionSkip(gym.Wrapper):
    """Repeat each action n_skips times; return the pixel-wise max over the skipped frames."""
    def __init__(self, env, n_skips=4):
        super(ActionSkip, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)
        self.n_skips = n_skips

    def step(self, action):
        queue = []
        all_rewards = 0
        for _ in range(self.n_skips):
            obs, reward, done, _ = self.env.step(action)
            queue.append(obs)
            all_rewards += reward
            if done:
                break
        # max over the skipped frames removes Atari sprite flicker
        return np.max(np.stack(queue), axis=0), all_rewards, done, {}

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
class ImageFormat(gym.ObservationWrapper):
    """Resize to 84x84, convert to single-channel grayscale, scale to [-1, 1]."""
    def __init__(self, env):
        super(ImageFormat, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1, 84, 84), dtype=np.float32)

    def observation(self, observation):
        obs_img = PIL.Image.fromarray(observation)
        obs_resize = obs_img.resize((84, 84))
        # average the RGB channels and add a leading channel axis
        obs_gray = np.asarray(obs_resize).mean(axis=-1)[None, ...]
        return (obs_gray / 127.5 - 1.).astype(np.float32)
class ResetEnv(gym.Wrapper):
    """After every reset, take a few random steps so the ball is actually in play."""
    def __init__(self, env=None):
        super(ResetEnv, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1, 84, 84), dtype=np.float32)

    def fire(self):
        for _ in range(15):
            obs, r, done, _ = self.env.step(self.env.action_space.sample())
            if done:
                return self.reset()
        return obs

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        return self.fire()
class FrameStack(gym.ObservationWrapper):
    """Keep the last 6 frames and return them stacked along the channel axis."""
    def __init__(self, env):
        super(FrameStack, self).__init__(env)
        self.queue = deque(maxlen=6)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(6, 84, 84), dtype=np.float32)

    def observation(self, observation):
        self.queue.append(observation)
        # older slots stay zero until the queue is full
        obs_stack = np.zeros((6, 84, 84), dtype=np.float32)
        for idx, element in enumerate(self.queue):
            obs_stack[idx] = element[0]
        return obs_stack

    def reset(self, **kwargs):
        self.queue.clear()
        obs = self.env.reset(**kwargs)
        # keep the first frame in the queue so it is part of the following stacks too
        self.queue.append(obs)
        obs_stack = np.zeros((6, 84, 84), dtype=np.float32)
        obs_stack[0] = obs[0]
        return obs_stack
def wrap_env(env):
    return FrameStack(ResetEnv(ImageFormat(ActionSkip(env))))