Training an LLM (In a Nutshell!)
A large language model (LLM) learns how to reply to a conversation by learning how to predict the next token (akin to a word), given the preceding conversation. This is framed as a multi-class classification problem, where each token represents a different class. The LLM outputs, for every token in the vocabulary (i.e. the set of all possible tokens), the probability of it being the next token; these probabilities are the parameters of a categorical distribution. The next token is predicted by sampling from this categorical distribution. After a token is predicted, it is appended to the preceding conversation and used to predict the token after it.
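A minimal sketch of this sampling loop's core step is shown below, assuming we already have the model's raw scores (logits) for every token in the vocabulary; the vocabulary size is an illustrative placeholder.

```python
import torch
import torch.nn.functional as F

vocab_size = 32_000                  # hypothetical vocabulary size
logits = torch.randn(vocab_size)     # stand-in for the LLM's output scores

# Softmax turns the logits into the parameters of a categorical distribution.
probs = F.softmax(logits, dim=-1)

# Predicting the next token = sampling one class from that distribution.
next_token_id = torch.multinomial(probs, num_samples=1).item()

# The sampled token is appended to the context and the process repeats.
```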
Training an LLM typically consists of three main stages (different LLMs may have different training schemes):
- Unsupervised pre-training
- Supervised fine-tuning
- Reinforcement learning
Stages 2 and 3 are commonly referred to collectively as the fine-tuning stage.
Unsupervised pre-training
In the unsupervised pre-training stage, massive amounts of text data scraped from the Internet (e.g. GitHub, Wikipedia, ...) are used to train the LLM to predict the next token based on a fixed-length context window of the tokens preceding it. This produces a base model that can generate text when prompted. However, there is no guarantee that the generated text will actually answer the prompt, as the LLM only predicts the most probable continuation of the prompt.
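A minimal sketch of the pre-training objective, assuming a `model` that maps a window of token ids to next-token logits (the function and parameter names here are illustrative, not from any specific library):

```python
import torch.nn.functional as F

def pretraining_step(model, token_ids, context_length=1024):
    # Crop to the model's fixed-length context window (plus one for targets).
    token_ids = token_ids[:, : context_length + 1]
    # Inputs are the tokens; targets are the same sequence shifted one
    # position to the left (i.e. the "next" token at every position).
    inputs = token_ids[:, :-1]
    targets = token_ids[:, 1:]
    logits = model(inputs)                      # (batch, seq, vocab_size)
    # Multi-class classification loss over the whole vocabulary.
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
    )
    return loss
```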
Supervised fine-tuning
In order for the LLM to respond to prompts in the style of a chat assistant, supervised fine-tuning (SFT) of the LLM has to be performed. In SFT, human labellers create a dataset of example conversations between a chat assistant and a user. Markers such as “User:” and “Assistant:” are used to delimit the turns in the conversation history. This supervised dataset is used to fine-tune the LLM with the same multi-class, next-token prediction objective.
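As a rough sketch, an SFT example might be serialised into plain text like this, assuming the “User:” / “Assistant:” convention above (the exact format varies between models):

```python
conversation = [
    {"role": "User", "content": "What is the capital of France?"},
    {"role": "Assistant", "content": "The capital of France is Paris."},
]

# Serialise each turn as "Role: content", one turn per line.
sft_text = "\n".join(f"{turn['role']}: {turn['content']}" for turn in conversation)

# The resulting string is tokenised and trained on with the same
# next-token prediction objective as in pre-training.
print(sft_text)
```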
Reinforcement learning
In the last stage, a technique known as reinforcement learning from human feedback (RLHF) is used to align the LLM chat assistant with human preferences. In RLHF, a prompt is given to the fine-tuned LLM and several different replies to the same prompt are collected. A human labeller ranks the different outputs, and the rankings are expanded into pairwise comparisons, i.e. A > B > C → {A > B, A > C, B > C}. A reward model (commonly a neural network) is trained to predict a reward, and is used in place of the pre-determined reward function that reinforcement learning typically assumes. The reward model takes as inputs the chat history and the LLM assistant’s reply, and outputs a single reward score. It is trained using a pairwise ranking loss similar to binary classification, which pushes the reward of the preferred reply above that of the rejected one. For more details, Ari Seff explains the process best in his video.
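A minimal sketch of that pairwise loss, assuming a `reward_model` that maps a (chat history, reply) pair to a single scalar; the names are illustrative:

```python
import torch.nn.functional as F

def reward_loss(reward_model, history, preferred_reply, rejected_reply):
    r_preferred = reward_model(history, preferred_reply)   # scalar reward
    r_rejected = reward_model(history, rejected_reply)     # scalar reward
    # Binary-classification-style loss: maximise the log-sigmoid of the
    # reward gap, i.e. push the preferred reply's reward above the rejected one's.
    return -F.logsigmoid(r_preferred - r_rejected).mean()
```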
The reward model is then used in reinforcement learning with Proximal Policy Optimisation (PPO) to fine-tune the weights of the LLM after SFT. Training of the LLM is now framed as a reinforcement learning problem, where the probability outputs of the categorical distribution over the entire vocabulary are treated as the parameters of a stochastic policy. The goal of reinforcement learning is for an agent to learn the optimal policy that maximises expected returns. Repeated rounds of reinforcement learning are performed, where both the policy model (i.e. the LLM) and the reward model are iteratively updated. This is required because the inputs to the reward model change as the policy model is optimised.
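As a rough illustration of the policy update, here is a sketch of PPO's clipped surrogate objective for a single step, assuming log-probabilities from the current and old policies and an advantage estimate derived from the reward model's score (full RLHF pipelines add further terms, such as a KL penalty against the SFT model, which are omitted here):

```python
import torch

def ppo_clip_loss(logprob_new, logprob_old, advantage, clip_eps=0.2):
    ratio = torch.exp(logprob_new - logprob_old)   # probability ratio new/old policy
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
    # PPO maximises the minimum of the two terms; return the negated value
    # so it can be minimised with a standard optimiser.
    return -torch.min(unclipped, clipped).mean()
```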