# TreeQN

Code: https://github.com/oxwhirl/treeqn (PyTorch)

TreeQN is a differentiable, recursive, tree-structured model that serves as a drop-in replacement for any value function network in deep RL with discrete actions.

## The problem

- A promising approach to improving model-free deep reinforcement learning (RL) is to combine it with on-line planning. One strategy for on-line planning is to use look-ahead tree search with a simulator or environment model.
- A simple approach to learning environment models is to maximise a similarity metric between model predictions and the ground truth in observation space.
    - However, this objective causes significant model capacity to be devoted to predicting irrelevant aspects of the environment dynamics, such as noisy backgrounds, at the expense of value-critical features that may occupy only a small part of the observation space.
- Another strategy is to train a model such that, when it is used to predict a value function, the error in those predictions is minimised. Doing so can encourage the model to focus on the features of the observations that are relevant for the control task.
    - An example is [[The Predictron]], where the model is used to aid policy evaluation without addressing control.
    - [[Value Prediction Network]] takes a similar approach but uses the model to construct a look-ahead tree only when constructing bootstrap targets and selecting actions. The model is not embedded in a planning algorithm during optimisation.

## The solution

- By formulating the tree look-ahead in a differentiable way and integrating it directly into the Q-function or policy, TreeQN trains the entire agent, including its learned transition model, end-to-end. This ensures that the model is optimised for the correct goal and is suitable for on-line planning during execution of the policy.
- TreeQN encodes an inductive bias based on the prior knowledge that the environment is a stationary Markov process, which facilitates faster learning of better policies.

![[treeqn.jpg]]

## The details

- Uses a version of n-step Q-learning ([[Deep Q-Learning]], but it can be used with other algorithms such as [[Soft Actor-Critic]]) with synchronous environment threads. In particular, starting at a timestep $t$, roll forward $n_{\mathrm{env}}=16$ threads for $n=5$ timesteps each. Then bootstrap off the final states only and gather all $n_{\mathrm{env}} \times n = 80$ transitions into a single batch, minimising the loss (a minimal sketch of this target computation follows right below):
$$
\mathcal{L}_{\text{nstep-}Q}=\sum_{\text{envs}} \sum_{j=1}^{n}\left(\sum_{k=1}^{j}\left[\gamma^{j-k} r_{t+n-k}\right]+\gamma^{j} \max_{a^{\prime}} Q\left(\mathbf{s}_{t+n}, a^{\prime}, \theta^{-}\right)-Q\left(\mathbf{s}_{t+n-j}, a_{t+n-j}, \theta\right)\right)^{2}
$$
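A minimal PyTorch sketch of this n-step target and loss computation, assuming hypothetical `q_online` / `q_target_net` modules that map a batch of observations to per-action Q-values; the names, shapes, and the omission of episode-termination masking are illustrative simplifications, not taken from the official repo:

```python
import torch

def nstep_q_loss(q_online, q_target_net, obs, actions, rewards, final_obs, gamma=0.99):
    """n-step Q-learning loss over a synchronous rollout.

    obs:       (n, n_env, *obs_shape)  states s_t ... s_{t+n-1}
    actions:   (n, n_env)              actions taken in those states (long tensor)
    rewards:   (n, n_env)              rewards r_t ... r_{t+n-1}
    final_obs: (n_env, *obs_shape)     bootstrap states s_{t+n}
    """
    n, n_env = rewards.shape
    with torch.no_grad():
        # Bootstrap only from the final state of each rollout, using the target network.
        ret = q_target_net(final_obs).max(dim=-1).values          # (n_env,)
        targets = []
        for i in reversed(range(n)):                              # build n-step returns backwards
            ret = rewards[i] + gamma * ret                        # (episode-end masking omitted)
            targets.append(ret)
        targets = torch.stack(targets[::-1])                      # (n, n_env), aligned with obs

    # Q-values of the actions actually taken, from the online network.
    q = q_online(obs.reshape(n * n_env, *obs.shape[2:]))          # (n * n_env, |A|)
    q_taken = q.gather(1, actions.reshape(-1, 1)).view(n, n_env)

    # Sum of squared errors over environments and timesteps, as in the loss above.
    return ((targets - q_taken) ** 2).sum()
```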
- Instead of directly estimating the value $Q(\mathbf{s}_t, a)$ from the current encoded state $\mathbf{z}_t$, TreeQN uses a recursive tree-structured neural network between the encoded state $\mathbf{z}_t$ and the predicted state-action values.
- It learns an action-dependent transition function that, given a state representation $\mathbf{z}_{l \mid t}$, predicts the next state representation $\mathbf{z}_{l+1 \mid t}^{a_{i}}$ for each action $a_{i} \in \mathcal{A}$, and the corresponding reward $\hat{r}_{l \mid t}^{a_{i}}$. Here $\mathbf{z}_{l \mid t}$ denotes the encoded state at time $t$ after $l$ internal transitions, starting with $\mathbf{z}_{0 \mid t}$ for the encoding of $\mathbf{s}_{t}$. This transition function is recursively applied to construct a tree containing the state representations and rewards received for all possible sequences of actions up to some predefined depth $d$.
- The value of each predicted state is estimated as
$$
\begin{aligned}
Q^{l}\left(\mathbf{z}_{l \mid t}, a_{i}\right) &= r\left(\mathbf{z}_{l \mid t}, a_{i}\right)+\gamma V^{(\lambda)}\left(\mathbf{z}_{l+1 \mid t}^{a_{i}}\right) \\
V^{(\lambda)}\left(\mathbf{z}_{l \mid t}\right) &= \begin{cases}
V\left(\mathbf{z}_{l \mid t}\right) & l=d \\
(1-\lambda)\, V\left(\mathbf{z}_{l \mid t}\right)+\lambda\, \mathbf{b}\left(Q^{l}\left(\mathbf{z}_{l \mid t}, a_{j}\right)\right) & l<d
\end{cases}
\end{aligned}
$$
where $\mathbf{b}$ is a backup function that is applied recursively up the tree.
- TreeQN imposes significant structure on the value function by decomposing it into a sum of action-conditional reward and next-state value, and by using a shared value function to evaluate each next-state representation.
- Encoder function
    - A series of convolutional layers produces an embedding of the observed state: $\mathbf{z}_{0 \mid t}=\text{encode}\left(\mathbf{s}_{t}\right)$.
- Transition function
    - First, a single fully connected layer, shared by all actions, is applied to the current state embedding: $\mathbf{z}_{l+1 \mid t}^{\mathrm{env}}=\mathbf{z}_{l \mid t}+\tanh \left(\boldsymbol{W}^{\mathrm{env}} \mathbf{z}_{l \mid t}+\mathbf{b}^{\mathrm{env}}\right)$. This generates an intermediate representation $\mathbf{z}_{l+1 \mid t}^{\mathrm{env}}$ that can carry information about action-agnostic changes to the environment.
    - Then a fully connected layer per action is applied to the intermediate representation to calculate a next-state representation that carries information about the effect of taking action $a_i$: $\mathbf{z}_{l+1 \mid t}^{a_{i}}=\mathbf{z}_{l+1 \mid t}^{\mathrm{env}}+\tanh \left(\boldsymbol{W}^{a_{i}} \mathbf{z}_{l+1 \mid t}^{\mathrm{env}}\right)$.
- Reward function
    - In addition to predicting the next state, the immediate reward for every action $a_{i} \in \mathcal{A}$ in state $\mathbf{z}_{l \mid t}$ is predicted with a two-layer network:
$$
\hat{\mathbf{r}}\left(\mathbf{z}_{l \mid t}\right)=\boldsymbol{W}_{2}^{r} \operatorname{ReLU}\left(\boldsymbol{W}_{1}^{r} \mathbf{z}_{l \mid t}+\mathbf{b}_{1}^{r}\right)+\mathbf{b}_{2}^{r}
$$
- Value function
    - The value of a state representation $\mathbf{z}$ is estimated with a linear layer: $V(\mathbf{z})=\mathbf{w}^{\top} \mathbf{z}+b$.
- Backup function
    - The following function can be applied recursively to calculate the tree backup: $\mathrm{b}(\mathbf{x})=\sum_{i} x_{i} \operatorname{softmax}(\mathbf{x})_{i}$, a softmax-weighted average of the child Q-values that acts as a differentiable relaxation of the hard max.
- All the components are learned jointly, which ensures that they are useful for planning on-line (see the sketch of the tree expansion and backup below).
- Auxiliary objectives based on minimising the error in predicting rewards or observations could improve performance by helping to ground the transition and reward functions in the environment.
- TreeQN without auxiliary objectives (reward/state grounding regularisation) can be seen as a model-free approach that draws inspiration from tree-search planning to encode valuable inductive biases into the neural network architecture. At the other extreme, perfect grounded reward and transition models could in principle be learned; using them in this architecture would then correspond to standard model-based look-ahead planning.
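To make the recursive structure concrete, below is a minimal PyTorch sketch of the TreeQN head: tree expansion with the shared and per-action transition layers, the reward and value heads, and the softmax backup. It operates on an already-encoded $\mathbf{z}_{0 \mid t}$ and omits details such as state-embedding normalisation and the paper's hyperparameters; all names are illustrative, not taken from the official repo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TreeQNHead(nn.Module):
    """Sketch of TreeQN's tree-structured Q-value head (encoder and training loop omitted)."""

    def __init__(self, k, num_actions, depth=2, gamma=0.99, lam=0.8):
        super().__init__()
        self.num_actions, self.depth, self.gamma, self.lam = num_actions, depth, gamma, lam
        self.w_env = nn.Linear(k, k)                                      # shared, action-agnostic transition
        self.w_act = nn.Parameter(torch.randn(num_actions, k, k) * 0.01)  # per-action transition matrices
        self.reward = nn.Sequential(nn.Linear(k, k), nn.ReLU(),           # reward head: one reward per action
                                    nn.Linear(k, num_actions))
        self.value = nn.Linear(k, 1)                                      # linear value head V(z)

    def transition(self, z):
        """(B, k) -> (B, A, k): child representations of every node, one per action."""
        z_env = z + torch.tanh(self.w_env(z))                             # residual, shared across actions
        # Apply each per-action matrix W^{a_i} to the intermediate representation.
        return z_env.unsqueeze(1) + torch.tanh(torch.einsum('aij,bj->bai', self.w_act, z_env))

    @staticmethod
    def backup(q):
        """Soft backup b(x) = sum_i x_i softmax(x)_i over the action dimension."""
        return (q * F.softmax(q, dim=-1)).sum(dim=-1)

    def q_values(self, z, l=0):
        """Q^l(z, a_i) for every action of a batch of nodes at depth l (recursive)."""
        r = self.reward(z)                                                # (B, A) predicted rewards
        children = self.transition(z)                                     # (B, A, k)
        flat = children.reshape(-1, children.shape[-1])                   # (B*A, k)
        v = self.value(flat).reshape(r.shape)                             # V(child), (B, A)
        if l + 1 < self.depth:
            q_child = self.q_values(flat, l + 1)                          # expand one level deeper
            v_lam = (1 - self.lam) * v + self.lam * self.backup(q_child).reshape(r.shape)
        else:
            v_lam = v                                                     # leaf nodes: plain value estimate
        return r + self.gamma * v_lam                                     # Q^l = r + gamma * V^(lambda)

# Example: a depth-2 tree over a batch of 4 encoded states with 3 actions.
head = TreeQNHead(k=32, num_actions=3, depth=2)
q = head.q_values(torch.randn(4, 32))    # (4, 3) Q-values, differentiable end-to-end
```

Because the whole tree is expanded and backed up with differentiable operations, gradients from the n-step loss above flow jointly through the value, reward, and transition parameters.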
## The results

- TreeQN has a clear advantage over DQN in all environments tested in the paper, except Frostbite.
- TreeQN also substantially speeds up learning. This is believed to be a regularisation effect of the greater structure imposed on the value function.

---

## References

1. Farquhar, G., Rocktäschel, T., Igl, M., Whiteson, S. "TreeQN and ATreeC: Differentiable Tree-Structured Models for Deep Reinforcement Learning." ICLR 2018. https://arxiv.org/pdf/1710.11417.pdf