How To Model Experience Replay, Batch Learning and Target Networks
A quick tutorial on three essential tricks for stable and successful Deep Q-learning, using TensorFlow 2.0
If you believe Deep Q-learning is simply a matter of replacing a lookup table with a neural network, you might be in for a rude awakening. Although Deep Q-learning allows handling very large state spaces and complicated non-linear environments, these benefits come at a substantial cost.
For this article, I assume you’re already somewhat familiar with both Q-learning and Deep Q-learning. The update function and loss function below will have to suffice to set the scene. In the remainder, I will zoom in specifically on stability issues and three often-used techniques to mitigate these issues: experience replay, batch learning and target networks.
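As a reminder, the textbook forms of the Q-learning update and the squared-error loss used in deep Q-learning can be written as follows. The symbols α (learning rate), γ (discount factor) and θ (network weights) are notation assumed here for illustration.

```latex
Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_t + \gamma \max_{a \in A} Q(s_{t+1}, a) - Q(s_t, a_t) \right]

L(\theta) = \left( r_t + \gamma \max_{a \in A} Q_\theta(s_{t+1}, a) - Q_\theta(s_t, a_t) \right)^2
```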
Stability of (deep) Q-learning
To some degree, stability is a problem in every learning task. Nevertheless, vanilla Q-learning is fairly stable. When observing the rewards corresponding to some state-action pair (s,a), only the corresponding Q-value Q(s,a) is updated. All other Q-values in the lookup table remain the same.
In contrast, a Q-network can be seen as a parameterized function f_θ: s → [Q(s,a)]_a∈A, mapping a state to a vector of Q-values. The key difference is that the weights θ are shared, so a single update can change the Q-values of every state-action pair. The implications are far-reaching and are exacerbated by the non-linear nature of a neural network (sensitivity to outliers, etc.). Even for seemingly straightforward problems, deep Q-learning is often plagued by stability issues.
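To make the contrast concrete, here is a minimal sketch in plain Python. The linear approximator, its `features` function and all numbers are illustrative assumptions standing in for a real Q-network; the point is only that shared weights make one update spill over to other state-action pairs.

```python
states, actions = [0, 1, 2], [0, 1]
alpha = 0.1   # learning rate (assumed)
target = 1.0  # hypothetical TD target observed for the pair (s=1, a=0)

# Tabular Q-learning: the update touches exactly one table entry
q_table = {(s, a): 0.0 for s in states for a in actions}
table_before = dict(q_table)
q_table[(1, 0)] += alpha * (target - q_table[(1, 0)])
changed_tabular = [k for k in q_table if q_table[k] != table_before[k]]


# Function approximation: Q(s, a) = w . features(s, a) with shared weights w
def features(s, a):
    return [1.0, float(s), float(a), float(s * a)]


def q_value(w, s, a):
    return sum(wi * xi for wi, xi in zip(w, features(s, a)))


weights = [0.0, 0.0, 0.0, 0.0]
approx_before = {(s, a): q_value(weights, s, a) for s in states for a in actions}

# One gradient step on the squared error for the same pair (s=1, a=0)
error = target - q_value(weights, 1, 0)
weights = [wi + alpha * error * xi for wi, xi in zip(weights, features(1, 0))]

approx_after = {(s, a): q_value(weights, s, a) for s in states for a in actions}
changed_approx = [k for k in approx_after if approx_after[k] != approx_before[k]]

print("Tabular entries changed:     ", changed_tabular)
print("Approximated entries changed:", changed_approx)
```

In this toy setting the tabular update alters only the visited pair, whereas the weight update shifts the estimate for every pair, including states that were never observed.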
Experience replay
Reinforcement learning entails making sequential decisions. This often means that subsequent states are closely related (e.g., a single step in a maze, a one-day update in stock price) and therefore rather similar. As a result, sequential observations tend to be highly correlated, which may lead to overfitting the network (e.g., in a suboptimal area of a maze).
However, even though experience is gained sequentially, there is no reason learning should follow the same sequence. Experience replay separates both processes by creating a replay buffer with past observations. Specifically, the replay buffer stores each s,a,r,s' tuple we encounter. Note that the corresponding Q-values are not stored; we determine them at the moment we sample the observation for updating purposes. Concretely, the learning procedure looks as follows:
- Sample a random s,a,r,s' tuple from the replay buffer.
- Feed s into the Q-network to obtain Q_t(s,a), using the stored action a.
- Feed s' into the Q-network to obtain Q_t+1(s',a*), where a*∈A is the optimal action (according to the prevailing Q-values) in state s'. Recall that Q-learning is off-policy, so we don’t use the actual trajectory.
- Use the difference between Q_t(s,a) and r_t+Q_t+1(s',a*) to compute the loss needed for updating the network.
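The steps above can be sketched for a single sampled tuple as follows. The `q_network` function here is a hypothetical stand-in (a fixed lookup returning one Q-value per action) and the discount factor `gamma` is an assumed parameter; with a real network, the resulting loss would be passed to the optimizer.

```python
import random

gamma = 0.9  # discount factor (assumed for illustration)


def q_network(state):
    """Hypothetical stand-in for a Q-network: one Q-value per action."""
    q_values = {"A": [0.0, 1.0], "B": [2.0, 0.5]}
    return q_values[state]


# Replay buffer holding (s, a, r, s') tuples
replay_buffer = [("A", 1, 1.0, "B")]

# Step 1: sample a random tuple from the buffer
state, action, reward, next_state = random.choice(replay_buffer)

# Step 2: Q-value of the stored action in state s
q_sa = q_network(state)[action]

# Step 3: Q-value of the best action in s' (off-policy: the action actually
# taken next is irrelevant)
best_next_q = max(q_network(next_state))

# Step 4: squared difference between expectation and target
target = reward + gamma * best_next_q
loss = (target - q_sa) ** 2
print(f"target={target:.2f}, loss={loss:.2f}")
```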
In addition to breaking correlation and combatting overfitting, a theoretical benefit is that the data is now closer to i.i.d. data, which is typically assumed in supervised learning convergence proofs.
The Python implementation would look something like the snippet below. Note that all we do is store s,a,r,s' in the buffer during the experience collection phase and randomly sample from it during the learning phase. For the latter, we use the convenient random.choices function, which samples with replacement.
"""Experience replay implementation"""
import random
import numpy as np

# Initialize replay buffer
replay_buffer = []
...
# Experience collection phase
# Set state
state = next_state
# Determine action (epsilon-greedy): explore with probability epsilon, e.g. 0.05
if np.random.rand() <= epsilon:
    # Select random action
    action = np.random.choice(action_dim)
else:
    # Select action with highest Q-value
    action = np.argmax(q_values[0])
# Compute and store reward
reward = get_reward(state, action)
# Determine next state
next_state = get_state(state, action)
# Store observation in replay buffer
observation = (state, action, reward, next_state)
replay_buffer.append(observation)
...
# Learning phase
# Select random sample from replay buffer
if len(replay_buffer) >= min_buffer_size:
    observations = random.choices(replay_buffer, k=1)
Storing every past observation and sampling completely at random might not be ideal, and the procedure can be refined. With prioritized replay, we more often sample the experiences from which we expect to learn the most. Another common technique is to cap the size of the replay buffer and delete older observations. After all, we don’t want to keep dwelling on past observations from regions of the state space that we never should have visited in the first place. Naturally, improvements like these introduce additional modeling challenges.
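Both refinements can be sketched in a few lines. The `ReplayBuffer` class below is a hypothetical helper, not part of any library: it uses `collections.deque` with `maxlen` to discard the oldest observation once the buffer is full, and passes per-observation weights to `random.choices` as a simplified stand-in for prioritized replay. How priorities are computed (for instance from the size of the last loss) is left open here.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity buffer; the oldest observation is dropped when full."""

    def __init__(self, capacity):
        self.observations = deque(maxlen=capacity)
        self.priorities = deque(maxlen=capacity)

    def add(self, observation, priority=1.0):
        # Appending to a full deque silently evicts the leftmost (oldest) item
        self.observations.append(observation)
        self.priorities.append(priority)

    def sample_uniform(self, k):
        # Every stored observation is equally likely (with replacement)
        return random.choices(list(self.observations), k=k)

    def sample_prioritized(self, k):
        # Sampling probability is proportional to the stored priority
        return random.choices(
            list(self.observations), weights=list(self.priorities), k=k
        )

    def __len__(self):
        return len(self.observations)
```

A buffer created with `ReplayBuffer(capacity=10_000)` can then replace the plain list: call `add(observation)` during experience collection and one of the two sampling methods during learning.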
Batch learning
Updating the Q-network a single observation at a time may not be a great solution. In many cases, such an observation might not contain much useful information – think of a single step in a maze. Worse, the observation might be an outlier that is not representative of the problem as a whole, yet the update may disastrously impact future decision-making. Ideally, we would like each update to be representative of the problem as a whole.
On the other end of the spectrum, we might perform all training iterations and use the complete batch of observations to fit the Q-network with a single update. Although such a batch would be representative indeed, all observations would be made with our initial (likely very poor) policy, such that we never learn the Q-values corresponding to a good policy.
Thus, large batches are not very useful either. We want to interleave observing and updating to gradually improve our policy. This does not mean we have to update after every observation, though. The obvious compromise is mini-batches, meaning that we frequently update our network with a relatively small number of observations. Combined with experience replay, this is a powerful technique to get stable updates based on a vast pool of previous observations.
As mini-batch learning is an extension of basic experience replay, we may still struggle to obtain representative samples. The main implementation questions are the update frequency (a mini-batch update naturally takes longer than a single-observation update) and the size of the mini-batches.
"""Batch learning implementation"""
import random
import numpy as np
import tensorflow as tf

no_observations = 100
mini_batch_size = 10
loss_value = 0
# Run this block inside a tf.GradientTape context to differentiate the loss
if len(replay_buffer) >= no_observations and i % update_frequency == 0:
    # Randomly sample k observations from buffer
    observations = random.choices(replay_buffer, k=mini_batch_size)
    # Loop over sampled observations
    for observation in observations:
        # Unpack the stored tuple
        state, action, reward, next_state = observation
        # Determine Q-value at time t
        q_values = q_network(state)
        expected_value = q_values[0, action]
        # Determine Q-value at time t+1
        next_q_values = tf.stop_gradient(q_network(next_state))
        next_action = np.argmax(next_q_values[0])
        next_q_value = next_q_values[0, next_action]
        # Add direct reward to obtain target value
        target_value = reward + (gamma * next_q_value)
        # Compute loss value
        loss_value += mse_loss(expected_value, target_value)
    # Compute mean loss value
    loss_value /= mini_batch_size
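The snippet above only accumulates the loss. As a rough sketch of how a complete mini-batch update might look, the example below processes the whole batch in one forward pass and applies the gradients with `tf.GradientTape`. The network size, the Adam optimizer, the randomly generated transitions and the function name `train_on_mini_batch` are all assumptions made for illustration.

```python
import random
import numpy as np
import tensorflow as tf

state_dim, action_dim, gamma = 4, 2, 0.99  # hypothetical problem settings

# Hypothetical small Q-network: state vector in, one Q-value per action out
q_network = tf.keras.Sequential([
    tf.keras.Input(shape=(state_dim,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(action_dim),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Stand-in replay buffer filled with random (s, a, r, s') tuples
rng = np.random.default_rng(0)
replay_buffer = [
    (rng.normal(size=state_dim).astype(np.float32),
     int(rng.integers(action_dim)),
     float(rng.normal()),
     rng.normal(size=state_dim).astype(np.float32))
    for _ in range(100)
]


def train_on_mini_batch(mini_batch_size=10):
    """Sample a mini-batch and perform one gradient step on the mean loss."""
    batch = random.choices(replay_buffer, k=mini_batch_size)
    states = np.stack([obs[0] for obs in batch])
    actions = np.array([obs[1] for obs in batch], dtype=np.int32)
    rewards = tf.constant([obs[2] for obs in batch], dtype=tf.float32)
    next_states = np.stack([obs[3] for obs in batch])

    # Targets are computed outside the tape, so they act as constants
    next_q_values = q_network(next_states)
    targets = rewards + gamma * tf.reduce_max(next_q_values, axis=1)

    with tf.GradientTape() as tape:
        q_values = q_network(states)
        # Keep only the Q-value of the action that was actually taken
        action_mask = tf.one_hot(actions, action_dim)
        predicted = tf.reduce_sum(q_values * action_mask, axis=1)
        loss = tf.reduce_mean(tf.square(targets - predicted))

    gradients = tape.gradient(loss, q_network.trainable_variables)
    optimizer.apply_gradients(zip(gradients, q_network.trainable_variables))
    return float(loss)
```

Calling `train_on_mini_batch()` every `update_frequency` iterations would mirror the condition in the loop-based snippet, with the batched forward pass replacing the Python loop.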
Target networks
By randomly sampling from past observations (experience replay), we tried to break correlations between observations. Note, however, that the observation tuple contains two closely related states – s and s' – which are fed to the same Q-network to obtain the Q-values. In other words, the expectation and target are correlated as well. Every network update also modifies the target, i.e., we are chasing a moving target.
To reduce the correlation between expectation Q(s,a) and target r+Q(s',a'), we may use a different network to determine Q(s',a'). We call this the target network: our target is based on Q^T(s',a') instead of Q(s',a'). We can use TensorFlow’s clone_model function to copy the network architecture of the original Q-network. Note that cloning does not copy the weights; for this we use the set_weights method. Using get_weights, we periodically obtain the most recent weights from the Q-network.
"""Target network implementation"""
# Copy network architecture
target_network = tf.keras.models.clone_model(q_network)
# Copy network weights
target_network.set_weights(q_network.get_weights())
...
# Periodically update target network
if episode % update_frequency_target_network == 0:
    target_network.set_weights(q_network.get_weights())
The key challenge is to find the right update frequency. If updates are few and far between, the target may correspond to an underperforming policy of the past. If they are too frequent, the correlation between target and expectation remains high. As with the other two techniques, the solution resolves a problem, yet also introduces additional complexity.
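To illustrate where the target network enters the update, here is a small sketch under assumed settings. The network size, the SGD optimizer, the random stand-in batch, the sync interval of 5 steps and the function name `update_step` are hypothetical choices. Targets are computed with `target_network`, while gradients only flow through `q_network`.

```python
import numpy as np
import tensorflow as tf

state_dim, action_dim, gamma = 4, 2, 0.99  # hypothetical problem settings

q_network = tf.keras.Sequential([
    tf.keras.Input(shape=(state_dim,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(action_dim),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Clone the architecture, then copy the weights explicitly
target_network = tf.keras.models.clone_model(q_network)
target_network.set_weights(q_network.get_weights())


def update_step(states, actions, rewards, next_states):
    """One gradient step on q_network with targets from target_network."""
    # The target network is only evaluated, never trained
    next_q_values = target_network(next_states)
    targets = rewards + gamma * tf.reduce_max(next_q_values, axis=1)

    with tf.GradientTape() as tape:
        q_values = q_network(states)
        action_mask = tf.one_hot(actions, action_dim)
        predicted = tf.reduce_sum(q_values * action_mask, axis=1)
        loss = tf.reduce_mean(tf.square(targets - predicted))

    gradients = tape.gradient(loss, q_network.trainable_variables)
    optimizer.apply_gradients(zip(gradients, q_network.trainable_variables))
    return float(loss)


# Stand-in batch of random transitions
rng = np.random.default_rng(0)
states = rng.normal(size=(8, state_dim)).astype(np.float32)
next_states = rng.normal(size=(8, state_dim)).astype(np.float32)
actions = rng.integers(action_dim, size=8)
rewards = tf.constant(rng.normal(size=8).astype(np.float32))

update_frequency_target_network = 5  # hypothetical sync interval
for step in range(1, 11):
    update_step(states, actions, rewards, next_states)
    # Periodically copy the latest Q-network weights into the target network
    if step % update_frequency_target_network == 0:
        target_network.set_weights(q_network.get_weights())
```

Between two synchronizations the targets stay fixed while the Q-network moves, which is exactly the trade-off described above: a longer interval gives a steadier but staler target.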
Takeaways
- Experience replay stores all observations – s,a,r,s' tuples – in a buffer, from which random samples can be selected. This breaks the correlation that is often present in sequential observations.
- Batch learning performs network updates based on multiple observations. This approach tends to yield more stable updates than single observations, as the losses better represent the overall problem.
- Target networks reduce correlation between expectation Q(s,a) and target r+Q^T(s',a'). A target network is no more than a periodic copy of the Q-network.
- Each technique introduces new modelling challenges and parameters to tune. In part, this is what makes deep learning so difficult; every solution creates new obstacles and adds to the complexity of the model.
Interested in Deep Q-learning? You might also be interested in the following article:
A Minimal Working Example for Deep Q-Learning in TensorFlow 2.0
Want to know more about the basics of Q-learning (and SARSA)? Check out the article below:
Walking Off The Cliff With Off-Policy Reinforcement Learning