When a human-AI conversation involves many rounds of continuous dialogue, the powerful large-scale machine learning models that drive chatbots like ChatGPT sometimes begin to break down, leading to rapid deterioration in bot performance.
A team of researchers from MIT and elsewhere identified a surprising cause of this problem and developed a simple solution that allows a chatbot to maintain a non-stop conversation without crashing or slowing down.
Their method involves a modification of the key-value cache (which is like a conversation memory) at the heart of many large language models. In some methods, when this cache needs to hold more information than it has capacity for, the first few pieces of data are discarded. This can cause the model to fail.
By ensuring that these early data points remain in memory, the researchers' method allows a chatbot to continue chatting regardless of the length of the conversation.
The method, called StreamingLLM, allows a model to remain effective even when a conversation spans more than 4 million words. Compared to another method that avoids crashes by constantly recalculating part of past conversations, StreamingLLM performed more than 22 times faster.
This could allow a chatbot to conduct long conversations throughout the workday without needing to be continually restarted, enabling effective AI assistants for tasks such as writing, editing, or generating code.
“Now, using this method, we can persistently deploy these large language models. By creating a chatbot that we can always chat with, and that can always respond to us based on our recent conversations, we could use these chatbots in new applications,” says Guangxuan Xiao, a graduate student in electrical engineering and computer science (EECS) and lead author of a paper on StreamingLLM.
Xiao's co-authors include his advisor, Song Han, an associate professor in EECS, a member of the MIT-IBM Watson AI Lab, and a distinguished scientist at NVIDIA; as well as Yuandong Tian, a research scientist at Meta AI; Beidi Chen, an assistant professor at Carnegie Mellon University; and senior author Mike Lewis, a research scientist at Meta AI. The work will be presented at the International Conference on Learning Representations.
A puzzling phenomenon
Large language models encode data, such as words in a user query, into representations called tokens. Many models use what is known as an attention mechanism, which uses these tokens to generate new text.
Typically, an AI chatbot writes new text based on the text it has just seen, so it stores recent tokens in memory, called a KV cache, to use later. The attention mechanism builds a grid that includes all the tokens in the cache, an “attention map” that charts how strongly each token, or word, relates to every other token.
Understanding these relationships is a feature that allows large language models to generate human-like text.
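To make this concrete, here is a minimal sketch, using toy sizes and random numbers of our own choosing rather than the researchers' code: the keys of previously seen tokens are compared against the new token's query, and a softmax turns those scores into one row of the attention map.

```python
# A minimal sketch (our own illustration, not the researchers' code) of one
# attention step reading from a KV cache with toy sizes.
import numpy as np

d = 8                                       # toy embedding size
rng = np.random.default_rng(0)

cached_keys = rng.normal(size=(5, d))       # keys for 5 previously seen tokens
cached_values = rng.normal(size=(5, d))     # values for those same tokens
query = rng.normal(size=(d,))               # query for the token being generated

scores = cached_keys @ query / np.sqrt(d)          # similarity with each cached token
weights = np.exp(scores) / np.exp(scores).sum()    # softmax: one row of the attention map
context = weights @ cached_values                  # weighted mix of cached values

print(weights)         # how strongly the new token attends to each cached token
print(context.shape)   # vector the model uses to help predict the next token
```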
But when the cache becomes very large, the attention map can become even more massive, which slows down computation.
Additionally, if content encoding requires more tokens than the cache can hold, model performance decreases. For example, a popular model can store 4,096 tokens, while there are around 10,000 tokens in an academic paper.
To get around these problems, researchers use a “sliding cache” that deletes the oldest tokens to add new ones. However, model performance often drops as soon as the first token is evicted, quickly reducing the quality of newly generated words.
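A plain sliding-window cache of this kind can be sketched in a few lines; this is a hypothetical illustration, not the paper's implementation, with token text standing in for the key-value tensors a real model would store. Note that the very first token is the first one evicted once the cache fills up.

```python
# Hypothetical sliding-window KV cache: once full, the oldest entry is evicted,
# including the very first token, which is where quality was seen to collapse.
from collections import deque

CACHE_SIZE = 4
kv_cache = deque(maxlen=CACHE_SIZE)   # deque drops the oldest entry automatically

for token in ["The", "cat", "sat", "on", "the", "mat"]:
    kv_cache.append(token)            # a real model would store (key, value) tensors
    print(list(kv_cache))

# By the fifth token, "The" (the first token) has been evicted from the cache.
```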
In this new paper, the researchers realized that if they kept the first token in the sliding cache, the model would maintain its performance even when the cache size was exceeded.
But that didn't make any sense. The first word of a novel probably has nothing to do with the last word, so why would the first word be so important for the model to generate the newest word?
In their new paper, the researchers also discovered the cause of this phenomenon.
Attention sinks
Some models use a Softmax operation in their attention mechanism, which assigns a score to each token representing how strongly it relates to every other token. The Softmax operation requires all attention scores to sum to 1. Since most tokens are not strongly related, their attention scores are very low. The model dumps any remaining attention score into the first token.
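A toy calculation, with made-up scores rather than real model output, shows why such a sink is convenient: softmax forces every row of attention weights to sum to exactly 1, so when no cached token is strongly relevant, the leftover probability mass has to pool somewhere, and models learn to pool it on the always-visible first token.

```python
# Illustrative scores only (not taken from a real model): the first token
# absorbs the attention mass that the weakly related tokens do not need.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

scores = np.array([4.0, 0.1, 0.2, 0.1, 0.3])   # first token acts as the dumping ground
weights = softmax(scores)

print(weights)        # most of the mass pools on token 0
print(weights.sum())  # always 1, by construction of the softmax
```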
Researchers call this first token an “attention sink.”
“We need an attention sink, and the model decides to use the first token as the attention sink because it is globally visible – every other token can see it. We found that we must always keep the attention sink in the cache to maintain the model's dynamics,” says Han.
In creating StreamingLLM, the researchers discovered that keeping four attention sink tokens at the start of the sliding cache leads to optimal performance.
They also found that each token's positional encoding should be based on its position within the cache rather than its original position in the text, even as new tokens are added and others are evicted. If token 5 is evicted, token 6 is then encoded with position 5, since it is now the fifth token in the cache.
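Putting the two ideas together might look roughly like the sketch below. It is our own simplification with an illustrative window size, not the released StreamingLLM code: the first few sink tokens are never evicted, the oldest non-sink token is dropped when the cache is full, and positional encodings follow each token's slot in the cache.

```python
# Simplified sketch of an attention-sink cache (illustrative sizes, not the
# paper's implementation): sinks are kept forever, eviction happens just after
# them, and positions are assigned by cache slot rather than original index.
NUM_SINKS = 4          # the paper reports four sink tokens work best
WINDOW = 6             # recent-token window size, chosen here for illustration

cache = []             # token ids; a real model would hold (key, value) tensors

def add_token(token_id):
    cache.append(token_id)
    if len(cache) > NUM_SINKS + WINDOW:
        del cache[NUM_SINKS]          # evict the oldest non-sink token

def cache_positions():
    # Positions stay contiguous (0, 1, 2, ...) inside the cache after evictions.
    return list(range(len(cache)))

for t in range(15):
    add_token(t)

print(cache)              # [0, 1, 2, 3, 9, 10, 11, 12, 13, 14]
print(cache_positions())  # [0, 1, 2, ..., 9], not the tokens' original indices
```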
By combining these two ideas, they allowed StreamingLLM to maintain a continuous conversation while outperforming a popular method using recomputation.
For example, when the cache contains 256 tokens, the recomputation method takes 63 milliseconds to decode a new token, while StreamingLLM takes 31 milliseconds. However, if the cache size grows to 4,096 tokens, recomputation requires 1,411 milliseconds for each new token, while StreamingLLM needs only 65 milliseconds.
“StreamingLLM's innovative approach, centered on the attention sink mechanism, ensures stable memory usage and performance even when processing texts up to 4 million tokens long,” says Yang You, a Presidential Young Professor of Computer Science at the National University of Singapore, who did not participate in this work. “This ability is not only impressive; it is transformative, allowing StreamingLLM to be applied to a wide range of AI applications. The performance and versatility of StreamingLLM make it a very promising technology, poised to revolutionize the way we approach AI-driven generation applications.”
Tianqi Chen, an assistant professor in the machine learning and computer science departments at Carnegie Mellon University who was also not involved in this research, agrees, saying, “StreamingLLM enables the smooth extension of the conversation length of large language models. We use it to enable the deployment of Mistral models on iPhones with great success.”
The researchers also explored the use of attention sinks during model training by prepending several placeholder tokens to all training samples.
They found that training with attention sinks allowed a model to maintain performance with just one attention sink in its cache, rather than the four typically needed to stabilize a pretrained model's performance.
But even though StreamingLLM allows a model to conduct a continuous conversation, the model cannot remember words that are not stored in the cache. In the future, the researchers plan to target this limitation by investigating methods to recover tokens that have been evicted or allow the model to remember previous conversations.
StreamingLLM has been integrated with NVIDIA's large language model optimization library, TensorRT-LLM.
This work is funded in part by the MIT-IBM Watson AI Lab, the MIT Science Hub, and the U.S. National Science Foundation.