
URL: https://towardsdatascience.com/a-minimal-working-example-for-deep-q-learning-in-tensorflow-2-0-e0ca8a944d5e/

A Minimal Working Example for Deep Q-Learning in TensorFlow 2.0 | Towards Data Science



A multi-armed bandit example to train a Q-network. The update procedure takes just a few lines of code using TensorFlow

Deep, just like Deep Q-learning. Photo by Kris Mikael Krister on Unsplash

Deep Q-learning is a staple in the arsenal of any Reinforcement Learning (RL) practitioner. It neatly circumvents some shortcomings of traditional Q-learning, and leverages the power of neural networks for complex value function approximations.

This article shows how to implement and train a deep Q-network in TensorFlow 2.0, illustrated using the multi-armed bandit problem (a terminating one-shot game). Some extensions towards temporal difference learning are provided as well. I take the ‘minimal’ in minimal working example quite literally though, so the focus is really on a first-ever implementation of deep Q-learning.

Some background

Before we dive into deep Q-learning, I assume you are already familiar with both vanilla Q-learning and artificial neural networks. Without those basics, trying your hand at deep Q-learning will likely be a frustrating experience. The following update mechanism should hold no secrets for you:

Update function for Q-learning [1]
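The image with the update rule is not reproduced here. Assuming it shows the standard one-step Q-learning update from Sutton and Barto [1], it reads as follows, with learning rate α, discount factor γ, observed reward r and next state s’ (symbols not defined elsewhere in this text):

```latex
Q(s,a) \leftarrow Q(s,a) + \alpha \Big[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \Big]
```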

Traditional Q-learning explicitly stores a Q-value – essentially an estimate for the cumulative discounted reward – for each state-action pair in a lookup table. When taking an action in a particular state, the observed reward improves the value estimate. The size of the lookup table is |S|×|A|, where S is the state space and A the action space. Q-learning tends to work well for toy-sized problems, but falls apart for larger ones. Typically, it is not possible to observe anywhere near all state-action pairs.

Example of Q-learning table for moving on a 16 tile grid. In this case, there are 16*4=64 state-action pairs for which a value Q(s,a) should be learned. [image by author]
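To make the lookup table concrete, here is a minimal Python sketch of tabular Q-learning for the 16-tile grid. The learning rate, discount factor and action numbering are illustrative assumptions, not values taken from the article.

```python
N_STATES, N_ACTIONS = 16, 4  # 16-tile grid with four moves

# Lookup table: one Q-value per state-action pair, |S| x |A| = 64 entries
q_table = [[0.0] * N_ACTIONS for _ in range(N_STATES)]

def q_update(q_table, s, a, r, s_next, alpha=0.1, gamma=0.9, terminal=False):
    """Apply one tabular Q-learning update for the transition (s, a, r, s')."""
    future = 0.0 if terminal else max(q_table[s_next])  # max over a' of Q(s', a')
    td_target = r + gamma * future
    q_table[s][a] += alpha * (td_target - q_table[s][a])
    return q_table[s][a]

# Hypothetical transition: action 3 on tile 0 leads to tile 1 with reward 1.0
print(q_update(q_table, s=0, a=3, r=1.0, s_next=1))  # 0.1
```

Every state-action pair needs its own entry and must be visited to be learned, which is exactly what stops this approach from scaling to large state spaces.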

In contrast to vanilla Q-learning, deep Q-learning takes the state as input, passes it through a number of neural network layers, and outputs the Q-value for each action. The deep Q-network can be viewed as a function f: s → [Q(s,a)] for all a ∈ A. By adopting a single representation for all states, deep Q-learning is able to handle large state spaces. It does presuppose a manageable number of actions though, as each action is represented by a node in the output layer (of size |A|).

Example of a deep Q-network. In this example, the input is a one-hot encoding of the grid (16 tiles), whereas the output represents the Q-value for each of the four actions. [image by author]
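A small sketch of this architecture in TensorFlow, assuming a Keras `Sequential` model with a single hidden layer of 32 units chosen purely for illustration: the tile index is one-hot encoded with `tf.one_hot`, and the network returns one Q-value per action.

```python
import tensorflow as tf

N_TILES, N_ACTIONS = 16, 4  # 16-tile grid, four moves

def encode_state(tile: int) -> tf.Tensor:
    """One-hot encode a tile index as a batch of one, shape (1, N_TILES)."""
    return tf.one_hot([tile], depth=N_TILES)

# f: s -> [Q(s, a) for every a]; one hidden layer of 32 units, purely illustrative
q_network = tf.keras.Sequential([
    tf.keras.Input(shape=(N_TILES,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(N_ACTIONS),  # linear output: one node per action
])

q_values = q_network(encode_state(5))  # Q-values for all four actions on tile 5
print(q_values.shape)  # (1, 4)
```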

After passing through the network and obtaining Q-values for all actions, we continue as usual. To balance exploration and exploitation, we utilize a basic ϵ-greedy policy. With probability 1-ϵ we select the best action (an argmax operation on the output layer), with probability ϵ we sample a random action.
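The ϵ-greedy rule could be sketched as follows. The function name `epsilon_greedy` and the use of a NumPy random generator are assumptions for illustration; the input is the row of Q-values the network produces for one state.

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """Pick an action index from a 1-D sequence of Q-values.

    With probability epsilon a uniformly random action is sampled (exploration);
    otherwise the argmax of the Q-values is returned (exploitation).
    """
    q_values = np.asarray(q_values)
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

rng = np.random.default_rng(seed=0)
print(epsilon_greedy([0.2, 1.5, -0.3, 0.9], epsilon=0.0, rng=rng))  # 1
```

With ϵ=0.0 the policy is purely greedy; with ϵ=1.0 every action is equally likely, regardless of the Q-values.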

TensorFlow 2.0 implementation

Defining a Q-network in TensorFlow is not hard. The input dimension equals the length of the state vector, and the output dimension equals the number of actions (if the set of feasible actions is state-dependent, a mask can be applied to the output). Beyond that, a Q-network is a fairly straightforward neural network.
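The original code listing is not included here, so the following is a hedged sketch of such a network rather than the author's exact code. The three hidden layers of 10 nodes mirror the bandit setup described further on; the ReLU activation, the `he_uniform` initializer and the zero-initialized output layer (one way to make all Q-values start at 0) are assumptions. The helper `masked_q_values` is a hypothetical name illustrating one common masking approach: setting the Q-values of infeasible actions to −∞ before taking the argmax.

```python
import tensorflow as tf

def build_q_network(state_dim, n_actions, hidden_units=(10, 10, 10)):
    """Dense Q-network: state vector in, one linear Q-value per action out."""
    inputs = tf.keras.Input(shape=(state_dim,))
    x = inputs
    for units in hidden_units:
        x = tf.keras.layers.Dense(units, activation="relu",
                                  kernel_initializer="he_uniform")(x)
    # Zero weights and bias in the output layer, so every Q-value starts at exactly 0
    outputs = tf.keras.layers.Dense(n_actions, kernel_initializer="zeros",
                                    bias_initializer="zeros")(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

def masked_q_values(q_values, feasible_mask):
    """Replace Q-values of infeasible actions (mask == 0) by -inf so argmax skips them."""
    feasible = tf.cast(feasible_mask, tf.bool)
    return tf.where(feasible, q_values, tf.fill(tf.shape(q_values), float("-inf")))

q_network = build_q_network(state_dim=1, n_actions=4)
q = q_network(tf.constant([[1.0]]))
print(q.numpy())  # [[0. 0. 0. 0.]]

example_q = tf.constant([[0.5, 2.0, 0.1, 1.0]])
print(int(tf.argmax(masked_q_values(example_q, tf.constant([[1, 0, 1, 1]])), axis=1)[0]))  # 3
```

Without the mask, action 1 (Q=2.0) would be selected; with it marked infeasible, the argmax falls to action 3.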

Weight updates are largely handled for you as well, yet you must provide a loss value to the optimizer. The loss represents the error between observation and expectation; a differentiable loss function is needed to properly perform the update. For deep Q-learning, the loss function is typically a simple mean squared error. This is actually a built-in loss function (loss='mse') in TensorFlow, but we will use the GradientTape functionality here, tracing all your operations to compute and apply the gradients[2]. It offers more flexibility and stays close to the underlying mathematics, which is often beneficial when moving towards more complicated RL applications.
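For comparison, here is a sketch of the built-in route the text mentions (`loss='mse'`), assuming a small illustrative network and a hypothetical observation. Because Keras averages the squared error over all output nodes, the targets for the actions that were not played are copied from the current predictions, so only the played action's Q-value is pushed toward the reward. This bookkeeping is part of what makes the explicit GradientTape version easier to reason about.

```python
import numpy as np
import tensorflow as tf

N_ACTIONS = 4
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(N_ACTIONS),
])
# Built-in route: Keras computes the mean squared error and applies the gradients
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss="mse")

state = np.array([[1.0]], dtype=np.float32)
action, reward = 2, 1.7  # hypothetical observation

# The MSE is averaged over all output nodes, so the targets of the actions that
# were not played are set to the current predictions (zero error); only the
# played action's Q-value is pushed toward the observed reward.
target = model.predict(state, verbose=0)
target[0, action] = reward
loss = model.train_on_batch(state, target)
```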

The mean-squared loss function (observe the similarity with the update mechanism mentioned earlier) is denoted as follows:

Mean squared error (MSE) loss function for Deep Q-learning
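The image with the loss is not reproduced here. In its standard form, with network weights θ, the loss for a single observation is the squared difference between the target and the current estimate; the exact notation in the figure may differ:

```latex
\mathcal{L}(\theta) = \Big( \underbrace{r + \gamma \max_{a'} Q(s',a';\theta)}_{\text{target}} - Q(s,a;\theta) \Big)^2
```

The target is treated as a constant when differentiating. For a terminating one-shot game such as the multi-armed bandit there is no next state, and the target reduces to r.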

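The generic TensorFlow implementation of the deep Q-learning update records the forward pass on a GradientTape, which then does its magic under the hood: it computes the gradients of the loss with respect to the network weights, which the optimizer subsequently applies.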
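A minimal sketch of such an update step for a single terminal observation, assuming an illustrative network with one hidden layer and the Adam optimizer; the function name `train_step` is hypothetical. The forward pass runs inside the tape, `tape.gradient` differentiates the squared error with respect to the trainable variables, and `apply_gradients` performs the weight update.

```python
import tensorflow as tf

tf.keras.utils.set_random_seed(0)

# Illustrative Q-network for a single fixed state and four actions
q_network = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(4),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

def train_step(state, action, reward):
    """One deep Q-learning update for a terminal (one-shot) observation."""
    with tf.GradientTape() as tape:
        q_values = q_network(state)       # forward pass is recorded on the tape
        q_sa = q_values[0, action]        # Q-value of the action actually played
        loss = tf.square(reward - q_sa)   # squared error w.r.t. the observed reward
    grads = tape.gradient(loss, q_network.trainable_variables)
    optimizer.apply_gradients(zip(grads, q_network.trainable_variables))
    return loss

state = tf.constant([[1.0]])
before = float(q_network(state)[0, 2])
for _ in range(200):
    train_step(state, action=2, reward=1.0)
after = float(q_network(state)[0, 2])
```

Repeating one fixed observation simply shows Q(s,a) moving toward the reward; in the bandit, each call would use the payoff of the machine that was actually played.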

Multi-armed bandit

The multi-armed bandit problem is a classic in RL [3]. It defines a number of slot machines: every machine i has a mean payoff μ_i and a standard deviation σ_i. At every decision moment, you play a machine and observe the resulting reward. When played often enough, you can estimate the mean reward of each machine. It goes without saying that the best policy is to play the slot machine with the highest average payoff.
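A bandit of this kind takes only a few lines to simulate. The sketch below assumes normally distributed payoffs (the text only specifies a mean and a standard deviation per machine) and uses made-up values; the class name `GaussianBandit` is hypothetical.

```python
import numpy as np

class GaussianBandit:
    """Multi-armed bandit in which machine i pays out a draw from N(mu_i, sigma_i^2)."""

    def __init__(self, means, stds, seed=None):
        self.means = np.asarray(means, dtype=float)
        self.stds = np.asarray(stds, dtype=float)
        self.rng = np.random.default_rng(seed)

    @property
    def n_arms(self) -> int:
        return len(self.means)

    def pull(self, arm: int) -> float:
        """Play machine `arm` once and return the sampled reward."""
        return float(self.rng.normal(self.means[arm], self.stds[arm]))

# Made-up payoffs for four machines (not the values used in the article)
bandit = GaussianBandit(means=[1.0, 2.0, 1.5, 0.5], stds=[1.0, 1.0, 2.0, 0.5], seed=42)

# Playing every machine often enough: sample averages approach the true means
estimates = [np.mean([bandit.pull(i) for _ in range(5000)]) for i in range(bandit.n_arms)]
best_arm = int(np.argmax(estimates))
```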

Let’s put our Q-learning network example into action (full GitHub code here). We define a straightforward neural network with three fully connected hidden layers of 10 nodes each. As input we use a tensor with value 1 (representing a fixed state), and as output four nodes (representing the Q-value of each machine). The network weights are initialized such that all Q-values are 0 initially. For the weight updates, we use the Adam optimizer with a learning rate of 0.001.
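Putting the pieces together, a self-contained sketch of this training loop could look as follows. It follows the setup described above (fixed input 1, three hidden layers of 10 nodes, four outputs starting at 0, Adam with learning rate 0.001), but the payoff values, ϵ=0.1, the ReLU activation, the random seeds and the shortened run of 1,000 iterations are assumptions; the author's GitHub code may differ in its details.

```python
import numpy as np
import tensorflow as tf

# Made-up bandit: four machines with Gaussian payoffs (illustrative values)
MEANS = np.array([1.0, 2.0, 1.5, 0.5])
STDS = np.array([1.0, 1.0, 2.0, 0.5])
N_ARMS = len(MEANS)
EPSILON = 0.1
N_ITERATIONS = 1000  # the article reports results after 10,000 iterations

tf.keras.utils.set_random_seed(0)
rng = np.random.default_rng(0)

# Three fully connected hidden layers of 10 nodes; the output layer starts at
# zero weights and bias so that every Q-value is 0 initially
q_network = tf.keras.Sequential(
    [tf.keras.Input(shape=(1,))]
    + [tf.keras.layers.Dense(10, activation="relu") for _ in range(3)]
    + [tf.keras.layers.Dense(N_ARMS, kernel_initializer="zeros", bias_initializer="zeros")]
)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
state = tf.constant([[1.0]])  # the bandit has a single, fixed state

for _ in range(N_ITERATIONS):
    with tf.GradientTape() as tape:
        q_values = q_network(state)
        # epsilon-greedy: explore with probability EPSILON, otherwise exploit
        if rng.random() < EPSILON:
            action = int(rng.integers(N_ARMS))
        else:
            action = int(np.argmax(q_values[0].numpy()))
        reward = float(rng.normal(MEANS[action], STDS[action]))
        # terminal one-shot game: the target is simply the observed reward
        loss = tf.square(reward - q_values[0, action])
    grads = tape.gradient(loss, q_network.trainable_variables)
    optimizer.apply_gradients(zip(grads, q_network.trainable_variables))

print("Q-values:  ", np.round(q_network(state)[0].numpy(), 2))
print("true means:", MEANS)
```

Changing `EPSILON` lets you reproduce the kind of exploration trade-off shown in the figure: rarely played machines keep Q-values that stay far from their true means.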

Some illustrative results (after 10,000 iterations) are shown in the figure below. The tradeoff between exploration and exploitation can be clearly observed, especially when not exploring at all. Note that the results are not overly accurate; vanilla Q-learning actually performs better for problems like this.

Q-values and true values for multi-armed bandit problem. Results after 10,000 iterations for ϵ=0.0 (top left), ϵ=0.01 (top right), ϵ=0.1 (bottom left) and ϵ=1.0 (bottom right). Less exploration yields closer approximation for the perceived best action. [image by author]

Temporal difference learning

The multi-armed bandit definitely is a minimal working example, but only treats the terminal case where we don’t look beyond the direct reward. Let’s see how we handle the non-terminal case as well. In this case, we deploy temporal difference learning – we use Q(s’,a’) to update Q(s,a).

Obtaining the Q-value corresponding to the next state s’ is not hard per se. You simply insert s’ into the Q-network, and out rolls the set of Q-values. Pick the maximum – always, as this is Q-learning rather than SARSA – and use it to compute the loss function:

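import tensorflow as tf
next_q_values = tf.stop_gradient(q_network(next_state))  # Q(s', a') for all a', no gradient flows back
next_q_value = tf.reduce_max(next_q_values[0])  # greedy max over a' (Q-learning, not SARSA)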

Note that the Q-network is called within a stop_gradient operator [4]. Recall that the GradientTape tracks all operations, and as such would also perform (nonsensical) updates using the next_state input. With the stop_gradient operator, we can safely use the Q-values corresponding to the next state s’ without worrying about erroneous updates!
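The effect of `stop_gradient` can be checked directly. The sketch below, with a hypothetical 2-D state, three actions and a made-up transition, compares the gradients obtained with `stop_gradient`, without it, and with the target computed beforehand as a plain number. Only the first should match that reference, confirming that Q(s’,a’) is treated as a constant.

```python
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(1)
q_network = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(3),
])
state = tf.constant([[0.5, -1.0]])      # s  (hypothetical 2-D state vector)
next_state = tf.constant([[0.2, 0.3]])  # s'
action, reward, gamma = 1, 0.5, 0.9     # made-up transition and discount factor

def td_gradients(use_stop_gradient):
    """Gradients of the squared TD error w.r.t. the network weights."""
    with tf.GradientTape() as tape:
        next_q = q_network(next_state)
        if use_stop_gradient:
            next_q = tf.stop_gradient(next_q)               # Q(s', a') treated as a constant
        target = reward + gamma * tf.reduce_max(next_q[0])  # max over a' (Q-learning)
        loss = tf.square(target - q_network(state)[0, action])
    return tape.gradient(loss, q_network.trainable_variables)

# Reference: the target computed beforehand as a plain Python float
fixed_target = float(reward + gamma * np.max(q_network(next_state)[0].numpy()))
with tf.GradientTape() as tape:
    ref_loss = tf.square(fixed_target - q_network(state)[0, action])
reference = tape.gradient(ref_loss, q_network.trainable_variables)

with_stop = td_gradients(use_stop_gradient=True)
without_stop = td_gradients(use_stop_gradient=False)
print("stop_gradient matches reference:",
      all(np.allclose(g, r, atol=1e-5) for g, r in zip(with_stop, reference)))
print("plain call matches reference:   ",
      all(np.allclose(g, r, atol=1e-5) for g, r in zip(without_stop, reference)))
```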

Some implementation notes

Although the method outlined above can in principle be directly applied to any RL problem, you’ll often find performance quite disappointing. Even for basic problems, don’t be surprised if your vanilla Q-learning implementation outperforms your fancy deep Q-network. In general, neural networks need many observations to learn something, and some level of detail is inherently lost by training a single network for all states that may be encountered.

Aside from good neural network practices (e.g., normalization, one-hot encoding, proper weight initialization), the following adjustments may strongly improve the quality of your algorithm[5]:

  • Mini batches: Rather than updating the network after every single observation, update the Q-network using batches of observations. Training on multiple observations at once often improves stability. The losses per observation are simply averaged. A tf.one_hot mask over the actions taken can be used to select the relevant Q-value in each observation, so a single update covers multiple actions.
  • Experience replay: Build a buffer of prior observations (stored as s,a,r,s’ tuples), sample one (or more, when using mini batches) from the buffer, and plug it into the Q-network. The main benefit of this approach is that it breaks the correlations between consecutive observations.
  • Target network: Create a copy of the neural network that is updated only periodically (say every 100 updates). The target network is used to compute Q(s’,a’), whereas the original network is used to determine Q(s,a). This procedure typically produces more stable updates.
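A compact sketch combining the three adjustments is given below. The buffer capacity, batch size of 32, network size, γ=0.99, the `done` flag for terminal transitions and the random transitions used for the demo are all assumptions; `ReplayBuffer`, `make_q_network` and `update_on_batch` are hypothetical names. The target network is synced every 100 updates, matching the example interval above.

```python
import random
from collections import deque

import numpy as np
import tensorflow as tf

class ReplayBuffer:
    """Fixed-size store of (s, a, r, s', done) transitions; the oldest are dropped first."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        """Uniformly sample a mini batch (without replacement) and stack it into tensors."""
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s_next, done = (np.array(x) for x in zip(*batch))
        return (tf.convert_to_tensor(s.astype(np.float32)),
                tf.convert_to_tensor(a.astype(np.int32)),
                tf.convert_to_tensor(r.astype(np.float32)),
                tf.convert_to_tensor(s_next.astype(np.float32)),
                tf.convert_to_tensor(done.astype(np.float32)))

def make_q_network(state_dim, n_actions):
    return tf.keras.Sequential([
        tf.keras.Input(shape=(state_dim,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(n_actions),
    ])

STATE_DIM, N_ACTIONS, GAMMA, SYNC_EVERY = 4, 2, 0.99, 100
q_network = make_q_network(STATE_DIM, N_ACTIONS)
target_network = make_q_network(STATE_DIM, N_ACTIONS)
target_network.set_weights(q_network.get_weights())  # start as an exact copy
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

def update_on_batch(s, a, r, s_next, done):
    """One mini-batch update; Q(s', a') comes from the target network."""
    # Computed outside the tape, so no gradient flows into the target network
    next_q = tf.reduce_max(target_network(s_next), axis=1)
    targets = r + GAMMA * (1.0 - done) * next_q  # no bootstrapping on terminal transitions
    mask = tf.one_hot(a, N_ACTIONS)             # picks Q(s, a) of the action taken per row
    with tf.GradientTape() as tape:
        q_sa = tf.reduce_sum(q_network(s) * mask, axis=1)
        loss = tf.reduce_mean(tf.square(targets - q_sa))  # per-observation losses averaged
    grads = tape.gradient(loss, q_network.trainable_variables)
    optimizer.apply_gradients(zip(grads, q_network.trainable_variables))
    return float(loss)

# Demo on random, made-up transitions
random.seed(0)
rng = np.random.default_rng(0)
buffer = ReplayBuffer(capacity=10_000)
for _ in range(500):
    buffer.add(rng.normal(size=STATE_DIM), int(rng.integers(N_ACTIONS)),
               float(rng.normal()), rng.normal(size=STATE_DIM), bool(rng.random() < 0.1))

for step in range(1, 301):
    update_on_batch(*buffer.sample(batch_size=32))
    if step % SYNC_EVERY == 0:
        target_network.set_weights(q_network.get_weights())  # periodic target update
```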

Takeaways

  • A deep Q-network is a straightforward neural network, taking the state vector as input and outputting Q-values corresponding to each action. By using a single representation for all states, it can handle much larger state spaces than vanilla Q-learning (which uses a lookup table).
  • TensorFlow’s GradientTape can be used to update the Q-network. The corresponding loss function is a mean squared error that is close to the original Q-learning update mechanism.
  • In temporal difference learning, the estimate for Q(s,a) is updated based on Q(s’,a’). The stop_gradient operator ensures that the gradients corresponding to Q(s’,a’) are ignored.
  • Deep Q-learning comes with some implementation challenges. Don’t be alarmed if vanilla Q-learning actually performs better, especially for toy-sized problems.

The GitHub code for the minimal working example using multi-armed bandits can be found here.

Want to stabilize your deep Q-learning algorithm? The following article might interest you:

How To Model Experience Replay, Batch Learning and Target Networks

Looking to implement policy gradient methods instead? Please check my articles with minimal working examples for the continuous and discrete case:

A Minimal Working Example for Continuous Policy Gradients in TensorFlow 2.0

A Minimal Working Example for Discrete Policy Gradients in TensorFlow 2.0

References

[1] Sutton, R. S., and Barto, A. G. (2018). Reinforcement Learning: An Introduction. MIT Press.

[2] Rosebrock, A. (2020). Using TensorFlow and GradientTape to train a Keras model. https://www.tensorflow.org/api_docs/python/tf/GradientTape

[3] Ryzhov, I. O., Frazier, P. I., and Powell, W. B. (2010). On the robustness of a one-period look-ahead policy in multi-armed bandit problems. Procedia Computer Science, 1(1):1635-1644.

[4] TensorFlow (2021). Obtained 26 July 2021 from https://www.tensorflow.org/api_docs/python/tf/stop_gradient

[5] Wikipedia Contributors (2021). Deep Q-learning. Obtained 26 July 2021 from https://en.wikipedia.org/wiki/Q-learning#Deep_Q-learning


Written By

Wouter van Heeswijk, PhD
