Decision Transformer: Reinforcement Learning via Sequence Modeling (website)
TL;DR
They feed sequences of (return-to-go, state, action) tuples $(\hat{R}_{t-1}, s_{t-1}, a_{t-1}, \hat{R}_t, s_t, a_t)$ into a transformer model, which outputs the corresponding actions $a_{t-1}, a_t$ in autoregressive fashion.
- Conventional RL relies on Bellman backups; instead, they model trajectories directly with sequence modeling and use transformers. To illustrate this, they use offline reinforcement learning, where the model is trained from a fixed dataset rather than by collecting experience in the environment, so that RL policies can be trained using the same code as a language model.
- Each modality (return-to-go, state, or action) is passed through its own embedding network (a convolutional encoder for image observations, a linear layer for continuous states). The embeddings are then processed by an autoregressive transformer, trained to predict the next action given the previous tokens via a linear output layer (a sketch follows this list).
- Evaluation: we can initialize with a desired target return (e.g. 1 or 0 for success or failure) and the starting state of the environment. Unrolling the sequence -- similar to standard autoregressive generation in language models -- yields a sequence of actions to execute in the environment.
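A minimal sketch of this token layout and action head, assuming continuous states and actions; the module and parameter names are hypothetical, and the paper's GPT backbone is stood in for by a causally masked `nn.TransformerEncoder`:

```python
# Minimal Decision-Transformer-style model sketch (hypothetical names;
# continuous states/actions; the GPT backbone is replaced by a causally
# masked nn.TransformerEncoder for brevity).
import torch
import torch.nn as nn


class DecisionTransformerSketch(nn.Module):
    def __init__(self, state_dim, act_dim, hidden_dim=128, max_len=1024):
        super().__init__()
        # One embedding per modality, plus a learned timestep embedding.
        self.embed_return = nn.Linear(1, hidden_dim)
        self.embed_state = nn.Linear(state_dim, hidden_dim)
        self.embed_action = nn.Linear(act_dim, hidden_dim)
        self.embed_timestep = nn.Embedding(max_len, hidden_dim)
        layer = nn.TransformerEncoderLayer(hidden_dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=3)
        self.predict_action = nn.Linear(hidden_dim, act_dim)  # linear output head

    def forward(self, returns_to_go, states, actions, timesteps):
        # returns_to_go: (B, T, 1), states: (B, T, state_dim),
        # actions: (B, T, act_dim), timesteps: (B, T) long tensor.
        B, T = states.shape[:2]
        t_emb = self.embed_timestep(timesteps)
        r = self.embed_return(returns_to_go) + t_emb
        s = self.embed_state(states) + t_emb
        a = self.embed_action(actions) + t_emb
        # Interleave tokens as (R_1, s_1, a_1, R_2, s_2, a_2, ...).
        tokens = torch.stack([r, s, a], dim=2).reshape(B, 3 * T, -1)
        mask = nn.Transformer.generate_square_subsequent_mask(3 * T).to(tokens.device)
        h = self.transformer(tokens, mask=mask)
        # Predict a_t from the hidden state at the state token s_t.
        h = h.reshape(B, T, 3, -1)
        return self.predict_action(h[:, :, 1])
```

Training would then minimize a next-token-style loss on the predicted actions (e.g. mean-squared error for continuous actions, cross-entropy for discrete ones).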
Q: is each action actually executed in the environment, and is the next prediction conditioned on the reward received, i.e. is the target return decremented by the obtained reward?
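A hedged sketch of that evaluation loop, which assumes the answer is yes (as I read the paper): the action predicted at the latest state token is executed, and the target return-to-go is decremented by the obtained reward before the next prediction. It assumes a Gymnasium-style `env` and the hypothetical `DecisionTransformerSketch` above; a real implementation would also truncate the history to the model's context length.

```python
# Evaluation rollout sketch (hypothetical API, continuous actions).
import torch


def rollout(model, env, target_return, state_dim, act_dim, max_steps=1000):
    state, _ = env.reset()
    returns, states, actions, timesteps = [float(target_return)], [state], [], []
    total_reward = 0.0
    for t in range(max_steps):
        timesteps.append(t)
        # Pad the action slot for the current, not-yet-chosen step.
        padded_actions = actions + [[0.0] * act_dim]
        pred = model(
            torch.tensor(returns, dtype=torch.float32).reshape(1, -1, 1),
            torch.tensor(states, dtype=torch.float32).reshape(1, -1, state_dim),
            torch.tensor(padded_actions, dtype=torch.float32).reshape(1, -1, act_dim),
            torch.tensor(timesteps, dtype=torch.long).reshape(1, -1),
        )
        action = pred[0, -1].detach().numpy()           # action for the latest state
        state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        actions.append(list(action))
        states.append(state)
        returns.append(returns[-1] - reward)            # decrement the return-to-go
        if terminated or truncated:
            break
    return total_reward
```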
- Decision Transformer can match the performance of well-studied, specialized TD-learning algorithms developed for Atari, OpenAI Gym, and the Minigrid Key-to-Door task.
- Conditional generation: they initialize a trajectory by setting the desired return as input. The Decision Transformer does not yield a single policy; rather, it models a wide distribution of policies. Plotting target return against achieved return shows that the model is relatively well calibrated and learns distinct policies that can match the target.
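The calibration plot described above amounts to sweeping the target return in the rollout sketch; the snippet below reuses the hypothetical `model`, `env`, and `rollout()` from the previous blocks, and the target values are placeholders, not numbers from the paper.

```python
# Target-return sweep behind the calibration plot (illustrative values only).
targets = [250, 500, 1000, 2000, 4000]
for tr in targets:
    achieved = rollout(model, env, tr, state_dim, act_dim)
    print(f"target return {tr:6d} -> achieved return {achieved:8.1f}")
```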
Example: shortest path
The task of finding the shortest path on a fixed graph is posed as a reinforcement learning problem, where the return is the negative sum of the traversed edge weights, so the shortest path achieves the highest return.
In a training dataset consisting of random walks, we observe many suboptimal trajectories. If we train Decision Transformer on these sequences, we can ask the model to generate an optimal path by conditioning on a large return (the target return is passed as input to the model). We find that, despite training only on random walks, Decision Transformer learns to stitch together subsequences from different training trajectories in order to produce optimal trajectories at test time.
This is the same behavior that is desired from the off-policy Q-learning algorithms commonly used in offline reinforcement learning. However, the sequence modeling framework achieves it without introducing TD learning, value pessimism, or behavior regularization.
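A small self-contained sketch of how such a random-walk dataset and its returns-to-go could be built; the graph, the reward convention (negative edge weight per step), and all names here are my own assumptions for illustration:

```python
# Hedged sketch: random-walk training data for the shortest-path example.
# Reward per step is taken to be the negative edge weight (an assumption),
# so a higher return corresponds to a shorter path.
import random

# Weighted directed graph as an adjacency dict: node -> {neighbor: edge weight}.
GRAPH = {
    0: {1: 1.0, 2: 4.0},
    1: {2: 1.0, 3: 5.0},
    2: {3: 1.0},
    3: {},  # goal node
}
GOAL = 3


def random_walk(start, max_steps=20):
    """One suboptimal trajectory as a list of (state, action, reward) triples."""
    traj, node = [], start
    for _ in range(max_steps):
        if node == GOAL or not GRAPH[node]:
            break
        nxt = random.choice(list(GRAPH[node]))
        traj.append((node, nxt, -GRAPH[node][nxt]))  # reward = -edge weight
        node = nxt
    return traj


def returns_to_go(traj):
    """Suffix sums of rewards: the quantity the model is conditioned on."""
    rtg, acc = [], 0.0
    for _, _, r in reversed(traj):
        acc += r
        rtg.append(acc)
    return list(reversed(rtg))


dataset = [random_walk(0) for _ in range(1000)]
# At test time, one would condition on a large target return (e.g. the best
# return seen in the dataset, or slightly above it) to ask for a short path.
best = max(returns_to_go(t)[0] for t in dataset if t)
print("best observed return:", best)
```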