How to constrain your model to produce defined formats
In this article, I will explain and demonstrate the concept of “structured generative AI”: generative AI constrained to defined formats. By the end of the article, you will understand where and when it can be used and how to implement it, whether you are creating a transformer model from scratch or using Hugging Face's models. Additionally, we'll cover an important tip for tokenization that is particularly relevant for structured languages.
One of the many uses of generative AI is as a translation tool. This often involves translation between two human languages, but can also include computer languages or formats. For example, your application may need to translate natural (human) language into SQL:
Natural language: “Get customer names and emails of customers from the US”

SQL: "SELECT name, email FROM customers WHERE country = 'USA'"
Or to convert text data to JSON format:
Natural language: “I am John Doe, phone number is 555-123-4567, my friends are Anna and Sara”

JSON: {"name": "John Doe",
       "phone_number": "555-123-4567",
       "friends": ["Anna", "Sara"]}
Naturally, many other applications are possible for other structured languages. The training process for such tasks involves feeding pairs of natural language examples and their structured equivalents to an encoder-decoder model. Alternatively, leveraging a pre-trained large language model (LLM) may be sufficient.
Although it is impossible to achieve 100% accuracy, there is one class of errors that we can eliminate: syntax errors. These are violations of the language format, such as replacing commas with periods, using table names that are not present in the SQL schema, or omitting square brackets, any of which makes the SQL or JSON non-executable.
The fact that we are translating into a structured language means that the list of legitimate tokens at each generation step is limited and predetermined. If we could integrate this knowledge into the generative AI process, we could avoid a wide range of incorrect results. This is the idea behind structured generative AI: constraining it to a list of legitimate tokens.
A quick reminder on how tokens are generated
Whether it is an encoder-decoder or GPT architecture, token generation works sequentially. Each token is selected based on both the input tokens and the previously generated tokens, continuing until an end-of-sequence token is generated.
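To make this concrete, here is a minimal sketch of that loop for a decoder-only (GPT-style) model. Hugging Face's generate() wraps this logic for us, so you will rarely write it yourself:

import torch

# A minimal greedy-decoding sketch (decoder-only model assumed): at each
# step the model scores the whole vocabulary given everything generated so
# far, we append the highest-scoring token, and we stop at end-of-sequence.
def greedy_decode(model, input_ids, eos_token_id, max_new_tokens=20):
    generated = input_ids                          # shape (1, seq_len)
    for _ in range(max_new_tokens):
        logits = model(generated).logits           # (1, seq_len, vocab_size)
        next_id = logits[0, -1].argmax()           # most likely next token
        generated = torch.cat([generated, next_id.view(1, 1)], dim=-1)
        if next_id.item() == eos_token_id:
            break
    return generated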
Limit token generation
To limit token generation, we integrate knowledge of the output language's structure: illegitimate tokens have their logits set to -inf, ensuring their exclusion from selection. For example, if only a comma or “FROM” is valid after “SELECT name”, all other token logits are set to -inf.
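Here is the core idea in isolation, with a toy vocabulary (the size and the IDs are made up for illustration):

import torch

# Toy illustration of logit masking: start from all -inf and copy back only
# the scores of the tokens that are legal at this step. Tokens left at -inf
# get probability zero after softmax, so they can never be selected.
vocab_size = 8
scores = torch.randn(vocab_size)            # logits from the model
legit_ids = [2, 5]                          # e.g., the IDs of ',' and 'FROM'
masked = torch.full((vocab_size,), float('-inf'))
masked[legit_ids] = scores[legit_ids]
probs = torch.softmax(masked, dim=-1)       # nonzero only at legit_ids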
If you are using Hugging Face, this can be implemented using a “logits processor”. To use it, you must implement a class with a __call__ method, which will be called after calculating the logits, but before sampling. This method receives all generated token logits and input IDs, returning the modified logits for all tokens.
I will demonstrate the code with a simplified example. First, we initialize the model. We will use BART in this case, but the approach works with any model.
from transformers import BartForConditionalGeneration, BartTokenizerFast, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessorList, LogitsProcessor
import torch

name = 'facebook/bart-large'
tokenizer = BartTokenizerFast.from_pretrained(name, add_prefix_space=True)
pretrained_model = BartForConditionalGeneration.from_pretrained(name)
If we want to generate a natural language to SQL translation, we can run:
to_translate = 'customers emails from the us'
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True)

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
)
print(tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(
        out[0], skip_special_tokens=True)))
Giving back:
'More emails from the us'
Since we did not fine-tune the model for text-to-SQL tasks, the result does not look like SQL. We will not train the model in this tutorial, but we will guide it to generate an SQL query. We achieve this by employing a function that maps each generated token to a list of allowed next tokens. For simplicity, we will consider only the last generated token, but more complicated mechanisms are easy to implement. We will use a dictionary defining, for each token, which tokens are allowed to follow it. For example, the query must start with “SELECT” or “DELETE”, and after “SELECT”, only “name”, “email”, or “id” are allowed, since these are the columns in our schema.
rules = {'<s>': ['SELECT', 'DELETE'],  # beginning of the generation
         'SELECT': ['name', 'email', 'id'],  # names of columns in our schema
         'DELETE': ['name', 'email', 'id'],
         'name': [',', 'FROM'],
         'email': [',', 'FROM'],
         'id': [',', 'FROM'],
         ',': ['name', 'email', 'id'],
         'FROM': ['customers', 'vendors'],  # names of tables in our schema
         'customers': ['</s>'],
         'vendors': ['</s>'],  # end of the generation
         }
We now need to convert these tokens into identifiers used by the model. This will happen in a class inheriting from LogitsProcessor.
def convert_token_to_id(token):
    return tokenizer(token, add_special_tokens=False)['input_ids'][0]

class SQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v]
                      for k, v in rules.items()}
Finally, we will implement the __call__ function, which is called after calculating the logits. The function creates a new tensor of -infs, checks which identifiers are legitimate according to the rules (the dictionary), and places their scores in the new tensor. The result is a tensor that only has valid values for valid tokens.
class SQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v]
                      for k, v in rules.items()}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        if not (input_ids == self.tokenizer.bos_token_id).any():
            # we must allow the start token to appear before we start processing
            return scores
        # create a new tensor of -inf
        new_scores = torch.full((1, self.tokenizer.vocab_size), float('-inf'))
        # ids of the legitimate tokens, given the last generated token
        legit_ids = self.rules[int(input_ids[0, -1])]
        # place their original scores in the new tensor
        new_scores[:, legit_ids] = scores[0, legit_ids]
        return new_scores
And that's all! We can now launch a generation with the logits processor:
to_translate = 'customers emails from the us'
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True, return_offsets_mapping=True)

logits_processor = LogitsProcessorList([SQLLogitsProcessor(tokenizer)])

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
    logits_processor=logits_processor
)
print(tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(
        out[0], skip_special_tokens=True)))
Giving back:
SELECT email , email , id , email FROM customers
The result is a bit strange, but remember: we haven't even trained the model! We only enforced token generation based on specific rules. Notably, constraining generation does not interfere with training; constraints only apply during post-training generation. So, when implemented appropriately, these constraints can only improve generation accuracy.
Our simplistic implementation does not cover all SQL syntax. A real implementation should support more syntax, potentially considering not just the last token but multiple ones, and allow for batch generation. With these enhancements in place, our trained model can reliably generate executable SQL queries constrained to valid table and column names in the schema. A similar approach can impose constraints in JSON generation, ensuring keys are present and brackets are closed.
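As one example of such an enhancement, here is a hedged sketch of a batch-aware variant of the processor above. It reuses the rules dictionary and convert_token_to_id from earlier and keeps the other simplifications; the key change is that scores arrives as a (batch_size, vocab_size) tensor, so each row is masked according to its own last token:

class BatchedSQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v]
                      for k, v in rules.items()}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        if not (input_ids == self.tokenizer.bos_token_id).any():
            return scores
        # mask each row of the batch according to that row's last token
        new_scores = torch.full_like(scores, float('-inf'))
        for row, ids in enumerate(input_ids):
            legit_ids = self.rules[int(ids[-1])]
            new_scores[row, legit_ids] = scores[row, legit_ids]
        return new_scores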
Beware of tokenization
Tokenization is often overlooked, but correct tokenization is crucial when using generative AI for structured output. Under the hood, tokenization can impact the training of your model. For example, suppose you fine-tune a model to translate text to JSON. As part of the fine-tuning process, you provide the model with text-JSON pairs, which it tokenizes. What will this tokenization look like?
While you read “[[” as two square brackets, the tokenizer converts them into a single identifier, which will be treated as a completely separate class from the single bracket by the token classifier. This makes all the logic that the model needs to learn more complicated (e.g., remembering how many brackets to close). Likewise, adding a space before a word can change its tokenization and its class identifier. Once again, this complicates the logic the model has to learn, since the weights linked to each of these identifiers must be learned separately, for slightly different cases.
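You can see this for yourself with the BART tokenizer we loaded above. The exact IDs depend on the tokenizer's vocabulary, so run it rather than trust any specific numbers:

# How bracket sequences and leading spaces change tokenization. The point
# is that the ID sequences differ, not the specific values.
for text in ['[[', ' [ [', 'name', ' name']:
    ids = tokenizer(text, add_special_tokens=False)['input_ids']
    print(repr(text), '->', ids)
# '[[' typically maps to a single ID, while ' [ [' yields one ID per
# bracket; 'name' and ' name' may likewise map to different IDs.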
For easier learning, make sure every concept and punctuation mark is consistently converted into the same token, by adding spaces before words and characters.
Entering spaced examples during fine-tuning simplifies the patterns the model must learn, thereby improving model accuracy. When predicting, the model will output the JSON with spaces, which you can then remove before parsing.
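As a minimal sketch of that post-processing step (a naive regex, assuming the values themselves need no protection):

import re, json

# Collapse the spaces we inserted around JSON punctuation before parsing.
# Caution: punctuation inside string values (e.g., a comma in "Doe, John")
# would also lose its surrounding spaces; a real implementation should
# protect quoted spans.
spaced = '{ "name" : "John Doe" , "friends" : [ "Anna" , "Sara" ] }'
compact = re.sub(r'\s*([{}\[\]:,])\s*', r'\1', spaced)
print(json.loads(compact))  # {'name': 'John Doe', 'friends': ['Anna', 'Sara']}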
Summary
Generative AI offers a valuable approach to translating into formatted language. By leveraging knowledge of the output structure, we can limit the generation process, eliminating a whole class of errors and ensuring that queries are executable and data structures are parsable.
Additionally, these formats may use punctuation marks and keywords to signify certain meanings. Ensuring that the tokenization of these keywords is consistent can significantly reduce the complexity of the patterns the model must learn, thereby reducing the required model size and its training time, while increasing its accuracy.
Structured generative AI can efficiently translate natural language into any structured format. These translations allow you to extract information from text or generate queries, which is a powerful tool for many applications.