Training an LLM (In a Nutshell!)

A large language model (LLM) learns how to reply to a conversation by learning to predict the next token (akin to a word) given the preceding conversation. This is framed as a multi-class classification problem, where each token in the vocabulary (i.e. the set of all possible tokens) represents a different class. The LLM outputs the likelihood of each token in the vocabulary being the next token; these likelihoods are the parameters of a categorical distribution. The next token is predicted by sampling from this categorical distribution over the vocabulary. The predicted token is then appended to the context and used, together with the preceding conversation, to predict the token after it.
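
Below is a minimal NumPy sketch of this generation loop. The tiny vocabulary and the fake_next_token_logits function are illustrative assumptions standing in for a real tokeniser and a real LLM forward pass; the point is only the softmax-then-sample structure.

    import numpy as np

    rng = np.random.default_rng(0)

    # Toy vocabulary and made-up "model": illustrative assumptions only,
    # not the internals of any real LLM.
    vocab = ["hello", "world", "how", "are", "you", "<eos>"]

    def fake_next_token_logits(context_ids):
        # Stand-in for an LLM forward pass: one logit per vocabulary token.
        return rng.normal(size=len(vocab))

    def softmax(logits):
        z = logits - logits.max()          # subtract max for numerical stability
        e = np.exp(z)
        return e / e.sum()

    def generate(prompt_ids, max_new_tokens=5):
        ids = list(prompt_ids)
        for _ in range(max_new_tokens):
            probs = softmax(fake_next_token_logits(ids))  # parameters of a categorical distribution
            next_id = rng.choice(len(vocab), p=probs)     # sample the next token
            ids.append(next_id)                           # feed it back in as context
            if vocab[next_id] == "<eos>":
                break
        return [vocab[i] for i in ids]

    print(generate([vocab.index("hello")]))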

Training an LLM typically consists of three main stages (different LLMs may have different training schemes):

  1. Unsupervised pre-training
  2. Supervised fine-tuning
  3. Reinforcement learning

Stages 2 and 3 are commonly referred to collectively as the fine-tuning stage.

Unsupervised pre-training

In the unsupervised pre-training stage, massive amounts of text data scraped from the Internet (e.g. GitHub, Wikipedia, ...) are used to train the LLM to predict the next token based on a fixed-length context window of tokens before it. This produces a base model that can generate text when prompted. However, there is no guarantee that the generated text will be an answer to the prompt, as the base model only predicts the most probable tokens to come after the prompt.
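
The sketch below illustrates this objective, assuming the corpus has already been tokenised into ids and using a made-up model_probs function in place of a real LLM forward pass: a fixed-length window slides over the text, and the cross-entropy loss of the true next token is accumulated.

    import numpy as np

    vocab_size = 50
    context_window = 8
    rng = np.random.default_rng(0)

    def model_probs(context_ids):
        # Stand-in for an LLM forward pass: a probability for each possible next token.
        logits = rng.normal(size=vocab_size)
        e = np.exp(logits - logits.max())
        return e / e.sum()

    # A long stream of scraped text, already tokenised into ids (random here).
    corpus_ids = rng.integers(0, vocab_size, size=1000)

    # Slide a fixed-length window over the corpus; the target is always the next token.
    total_loss, n = 0.0, 0
    for t in range(context_window, len(corpus_ids)):
        context = corpus_ids[t - context_window:t]
        target = corpus_ids[t]
        probs = model_probs(context)
        total_loss += -np.log(probs[target])   # cross-entropy for a multi-class problem
        n += 1

    print("average next-token loss:", total_loss / n)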

Supervised fine-tuning

In order for the LLM to respond to prompts in the style of a chat assistant, supervised fine-tuning (SFT) of the LLM has to be performed. In SFT, a dataset of example conversations between a user and an assistant is written by human labellers. Special tokens such as <|user|> and <|assistant|> are used to indicate speaker roles in a conversation. This supervised dataset is used to fine-tune the LLM with the same next-token prediction objective as in pre-training.
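
As a rough illustration, the snippet below serialises one labelled conversation into a single training string using the role tokens mentioned above. The exact chat template and special tokens (including the <|end|> marker used here) differ between models and are assumptions made for the example.

    # Sketch of turning a labelled conversation into an SFT training example.
    conversation = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]

    def to_training_text(turns):
        parts = []
        for turn in turns:
            # Wrap each turn with its speaker-role token.
            parts.append(f"<|{turn['role']}|>{turn['content']}")
        return "".join(parts) + "<|end|>"   # hypothetical end-of-conversation token

    print(to_training_text(conversation))
    # The resulting string is tokenised and trained on with the same
    # next-token prediction objective as in pre-training.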

Reinforcement learning

In the last stage, a technique known as reinforcement learning from human feedback (RLHF) is used to align the LLM chat assistant with human preferences. In RLHF, a prompt is given to the fine-tuned LLM and several different replies generated by the LLM to the same prompt are collected. A human labeller ranks the replies, and each ranking is expanded into pairwise comparisons, e.g. A > B > C → {A > B, A > C, B > C}. A reward model (commonly a neural network) is trained to predict a reward, and is used in place of the hand-specified reward function typically found in reinforcement learning. The reward model takes as input the chat history and one of the LLM assistant's replies, and outputs a single reward score. The loss for each pair is the negative log-sigmoid of the difference between the reward scores of the preferred and the less-preferred reply (e.g. for A > B, the difference between the scores of A and B). See the article by Aritra or the video by Ari Seff for further details.
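
The sketch below shows this pairwise loss for one ranked triple, assuming a stand-in reward_model function in place of a trained network: the ranking is expanded into pairs, and the loss for each pair is the negative log-sigmoid of the score difference.

    import numpy as np

    rng = np.random.default_rng(0)

    def reward_model(chat_history, reply):
        # Stand-in for a trained network: returns a single scalar reward score.
        return rng.normal()

    def pairwise_loss(chat_history, preferred, other):
        r_preferred = reward_model(chat_history, preferred)
        r_other = reward_model(chat_history, other)
        # Negative log-sigmoid of the score difference: the loss is small when
        # the preferred reply receives the higher reward.
        return -np.log(1.0 / (1.0 + np.exp(-(r_preferred - r_other))))

    # A > B > C expands into the pairs (A, B), (A, C), (B, C).
    replies = {"A": "reply A", "B": "reply B", "C": "reply C"}
    pairs = [("A", "B"), ("A", "C"), ("B", "C")]
    losses = [pairwise_loss("chat history", replies[w], replies[l]) for w, l in pairs]
    print("mean pairwise loss:", np.mean(losses))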

The reward model is then used in reinforcement learning with Proximal Policy Optimisation (PPO) to fine-tune the weights of the SFT model. Training of the LLM is now framed as a reinforcement learning problem, where the probability distribution over the vocabulary for the next-token prediction task is interpreted as a stochastic policy. The goal of reinforcement learning is for the agent to learn the optimal policy that maximises expected returns. Repeated rounds of reinforcement learning are performed, where both the policy model (i.e. the LLM) and the reward model are iteratively updated. This is required because the inputs to the reward model change as the policy model is optimised.
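
The core of PPO is a clipped surrogate objective that discourages the updated policy from moving too far from the policy that generated the data. The sketch below shows that objective for a single token with made-up numbers; in RLHF the advantage would be derived from the reward model's score (typically combined with a penalty for drifting from the SFT model), which is beyond this sketch.

    import numpy as np

    # PPO's clipped surrogate objective for a single sampled token,
    # treating the LLM's next-token distribution as a stochastic policy.
    def ppo_clip_objective(logp_new, logp_old, advantage, clip_eps=0.2):
        ratio = np.exp(logp_new - logp_old)            # pi_new(a|s) / pi_old(a|s)
        clipped = np.clip(ratio, 1 - clip_eps, 1 + clip_eps)
        # Take the more pessimistic of the two terms, which limits how much
        # a single update can change the policy.
        return np.minimum(ratio * advantage, clipped * advantage)

    # Illustrative numbers only, not from a real training run.
    print(ppo_clip_objective(logp_new=-1.1, logp_old=-1.3, advantage=0.7))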

References

OpenAI - Introducing ChatGPT

[YouTube] Ari Seff - How ChatGPT is Trained
