Creating an LLM

After custom-coding my own simple classification neural net, I decided to continue coding AI models, with a language model next on my bucket list. My journey to the LLM that concluded this chapter of my AI journey was divided into three distinct steps: an RNN, an LSTM, and finally a small LLM.

The main principle of an RNN (Recurrent Neural Network) is that it uses a hidden state, a value passed on from earlier time steps of the model to the current one. In the context of a language processor, this means that data or meaning from the earlier tokens in a text is passed along as the model reads, allowing it to encode some overarching trends across larger writing samples.

The formula for the hidden state is given below. It uses the tanh() function to squash the output into the range -1 to 1, which keeps the hidden state bounded across many time steps.
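In its standard textbook form, the update looks like this. The symbol names W_{xh}, W_{hh}, and b_h are assumed notation for the input weights, recurrent weights, and bias, and x_t is the input at step t:

```latex
h_t = \tanh\left( W_{xh}\, x_t + W_{hh}\, h_{t-1} + b_h \right)
```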

As you can see, the calculation for hidden state, or h_t, utilizes the hidden state from the last step, h_{t-1}, which carries meaning across steps.
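To make the recurrence concrete, here is a minimal NumPy sketch of a single RNN step applied over a short sequence of one-hot tokens. The function name `rnn_step`, the weight names, and the tiny sizes are all illustrative assumptions, not Benvolio's actual code or dimensions.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One vanilla RNN step: h_t = tanh(W_xh x_t + W_hh h_prev + b_h)."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

rng = np.random.default_rng(0)
vocab_size, hidden_size = 5, 8            # illustrative sizes only

# Randomly initialised parameters (a real model would learn these).
W_xh = rng.normal(0.0, 0.1, (hidden_size, vocab_size))
W_hh = rng.normal(0.0, 0.1, (hidden_size, hidden_size))
b_h = np.zeros(hidden_size)

h = np.zeros(hidden_size)                 # hidden state before any input
for token_id in [0, 3, 1, 4]:             # a made-up token sequence
    x = np.zeros(vocab_size)
    x[token_id] = 1.0                     # one-hot encode the token
    h = rnn_step(x, h, W_xh, W_hh, b_h)   # h carries context forward
```

The same `h` is both read and overwritten at every step, which is the detail the rest of the discussion hinges on.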

However, there is a built-in flaw in the RNN architecture described above, one that LSTMs were created to fix. Because the single hidden state is overwritten and squashed through tanh() at every step, the encoded meaning from time steps a decent way back shrinks to almost zero, and during training the gradients that would teach the model long-range patterns shrink in the same way (the vanishing gradient problem). This makes RNNs much worse at creating long, coherent sentences than other language models; they drift between topics easily because, after a few steps, they can no longer access significant meaning from the starting tokens.
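A rough way to see this fading is to start the same recurrence from two different hidden states and watch how quickly they become indistinguishable. The sketch below is a toy setup, not a trained model: it drops the input term and rescales the recurrent weights to a spectral norm of 0.5 (an assumption chosen so the shrinkage is guaranteed; trained RNNs will not match this exactly).

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size = 16

# Recurrent weights rescaled so their largest singular value is 0.5 (assumed).
W_hh = rng.normal(size=(hidden_size, hidden_size))
W_hh *= 0.5 / np.linalg.norm(W_hh, 2)

# Two different "memories" of the early text.
h_a = rng.uniform(-1.0, 1.0, hidden_size)
h_b = rng.uniform(-1.0, 1.0, hidden_size)

gaps = [np.linalg.norm(h_a - h_b)]        # distance before any step
for _ in range(20):
    h_a = np.tanh(W_hh @ h_a)
    h_b = np.tanh(W_hh @ h_b)
    gaps.append(np.linalg.norm(h_a - h_b))

# Each step at least halves the gap here, so whatever distinguished the
# two starting states is essentially gone after 20 steps.
print(f"initial gap: {gaps[0]:.3f}, gap after 20 steps: {gaps[-1]:.2e}")
```

Because tanh() never stretches distances and the weights here shrink them, the difference between the two histories decays geometrically.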

I trained my RNN, called Benvolio (after a character from Romeo and Juliet), with character-by-character tokenization on the TinyShakespeare dataset, which contains the full text of several of Shakespeare’s works and is perfect for teaching a basic language model how to speak in Shakespearean English. At its best, it reached a loss of ~1.33 with about 450,000 parameters, a decent result for such a small model. You can see some of its output below or on my projects page.
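Character-by-character tokenization simply means every distinct character gets its own integer ID. Here is a minimal sketch, using a short stand-in string instead of the real dataset; the names `stoi`, `itos`, `encode`, and `decode` are illustrative and not necessarily what Benvolio's code uses.

```python
text = "To be, or not to be"              # stand-in for the TinyShakespeare text

# Vocabulary: every distinct character, in a stable sorted order.
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}   # character -> integer ID
itos = {i: ch for ch, i in stoi.items()}       # integer ID -> character

def encode(s):
    """Turn a string into a list of token IDs, one per character."""
    return [stoi[ch] for ch in s]

def decode(ids):
    """Turn a list of token IDs back into a string."""
    return "".join(itos[i] for i in ids)

ids = encode("to be")
print(ids, "->", decode(ids))
```

The vocabulary stays tiny (a few dozen symbols for English text), at the cost of the model having to learn spelling from scratch.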

Looking at the model’s response, we can see it emulates many aspects of Shakespearean style, including the names of real Shakespeare characters like Romeo, Petruchio, and Coriolanus. It also copies the format of character speech broken up by headings. However, looking at the seed text: “Julius Caesar jumped out of the door and shouted “, we see that the model immediately drifts off topic, talking about horses. Reading further, we can also see that, although the formatting is correct, the actual text resembles what I would call Shakespeare-flavored gibberish. It is this problem of rambling text and nonsensical, yet almost correct, words that I tried to fix by upgrading to an LSTM.

The main improvement an LSTM (Long Short-Term Memory) makes over an RNN is the addition of a second memory state, along with 4 gates to update it. In an RNN, memory is passed across time steps using a single hidden state, which is overwritten at every step, causing meaning from the start of the text to be diluted and lost very quickly. An LSTM, on the other hand, uses a separate cell state, which is only ever updated through element-wise multiplication and addition (no tanh() applied to the state itself!) to avoid squashing meaning and eventually losing it. The updates to the cell state are controlled by 4 gates, each with its own parameters: the forget gate, the input gate, the candidate gate, and the output gate.
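In the standard formulation, the gates and states are computed as follows. The subscripted weight and bias names are assumed notation, [h_{t-1}, x_t] is the previous hidden state concatenated with the current input, \tilde{C}_t is the candidate, and \odot is element-wise multiplication:

```latex
\begin{aligned}
f_t         &= \sigma\left( W_f\,[h_{t-1}, x_t] + b_f \right) \\
i_t         &= \sigma\left( W_i\,[h_{t-1}, x_t] + b_i \right) \\
\tilde{C}_t &= \tanh\left( W_C\,[h_{t-1}, x_t] + b_C \right) \\
o_t         &= \sigma\left( W_o\,[h_{t-1}, x_t] + b_o \right) \\
C_t         &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
h_t         &= o_t \odot \tanh\left( C_t \right)
\end{aligned}
```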

This set of equations probably seems very overwhelming at first, so let’s walk through it slowly together. Starting off, we have the sigmoid function, denoted by a lowercase sigma. It works very similarly to tanh(), but it squashes values to [0,1] instead of [-1,1]. Then, the 4 gates (forget, input, candidate, and output) all act in the same way: each multiplies its own weights with the previous hidden state and the current input, adds its own bias (all weights and biases are tuned in training), and wraps the whole thing in a squashing function.
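The difference between the two squashing functions is easy to check numerically. In this small sketch, `z` stands in for a gate's pre-activation (weights times inputs plus bias); the values are made up for illustration.

```python
import numpy as np

def sigmoid(z):
    """Logistic sigmoid: squashes any real number into the range 0 to 1."""
    return 1.0 / (1.0 + np.exp(-z))

z = np.linspace(-6.0, 6.0, 7)             # made-up pre-activation values

gate = sigmoid(z)                          # always positive: a "how much" dial
candidate = np.tanh(z)                     # signed: can push up or down

print("sigmoid:", np.round(gate, 3))
print("tanh:   ", np.round(candidate, 3))
```

Sigmoid outputs behave like dials between "off" and "fully on", while tanh outputs carry a direction as well as a size.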

Individually, the forget gate controls what is forgotten, or removed from the previous cell state. The input and candidate gates work together to add new things to the state, with input identifying which dimensions to change and candidate identifying how much to change them. The reason the candidate gate uses tanh() while the other three use sigmoid is that tanh() allows candidate to propose both increases to some values in the cell state and decreases to others, which sigmoid could not (think about the different ranges), making it essential for properly updating values. Without it, every addition to the cell state would be positive, so values could only be pushed upward, with forgetting as the only way to shrink them.

If you are wondering why forget doesn’t use tanh() when it needs to decrease values to “forget” them, take a look at the equation for C_t, or cell state. Since forget is multiplied by the previous cell state, allowing negative values could flip the sign of some cell state values, turning a strongly weighted value into its exact opposite and causing problems down the road. Because multiplication is used, forget’s output being in the range [0,1] can naturally decrease values by multiplying by a number < 1, all without problematically flipping a sign.

Moving on to the cell state calculation, or C_t, we can see that the forget gate output is multiplied by the previous cell state to remove unnecessary values. The input and candidate outputs are multiplied together, and that product is added to the result to create the cell state for this time step. Notice the distinct lack of a squashing function, which allows the cell state to keep values from much earlier in the text than an RNN can.
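A tiny worked example of the cell state update, with hand-picked gate outputs (all numbers are invented for illustration), shows the three behaviors described above: keeping a value untouched, shrinking a value while adding a negative candidate, and wiping a value to write a new one.

```python
import numpy as np

C_prev  = np.array([2.0, -3.0, 0.5])      # previous cell state
f_t     = np.array([1.0,  0.5, 0.0])      # forget gate: keep, halve, erase
i_t     = np.array([0.0,  1.0, 1.0])      # input gate: which slots to write
c_tilde = np.array([0.9, -0.8, 0.4])      # candidate: signed proposed change

# Element-wise update, with no squashing applied to the state itself.
C_t = f_t * C_prev + i_t * c_tilde

print(C_t)   # approximately [ 2.  -2.3  0.4]
```

The first slot passes through unchanged at 2.0, well outside the [-1,1] range a tanh()-squashed hidden state could hold, and no slot had its old value flipped in sign by the forget gate.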

Finally, the hidden state for this time step is extracted from the cell state by applying tanh() and multiplying the result by the output gate, which has exactly the same structure as forget and input. The only difference is that its parameters are tuned for the purpose of extracting the necessary values from the cell state.
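Putting the pieces together, one full LSTM time step can be sketched in NumPy as below. This follows the standard formulation, with each gate reading the previous hidden state concatenated with the current input; the function name `lstm_step`, the `params` dictionary layout, and the sizes are assumptions for illustration, not Mercutio's actual code.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM step; returns the new hidden state and cell state."""
    hx = np.concatenate([h_prev, x_t])                       # [h_{t-1}, x_t]
    f_t = sigmoid(params["W_f"] @ hx + params["b_f"])        # forget gate
    i_t = sigmoid(params["W_i"] @ hx + params["b_i"])        # input gate
    c_tilde = np.tanh(params["W_c"] @ hx + params["b_c"])    # candidate
    o_t = sigmoid(params["W_o"] @ hx + params["b_o"])        # output gate
    C_t = f_t * C_prev + i_t * c_tilde                       # update memory
    h_t = o_t * np.tanh(C_t)                                 # read memory out
    return h_t, C_t

rng = np.random.default_rng(0)
input_size, hidden_size = 5, 8            # illustrative sizes only

# Random parameters for the four gates (a real model would learn these).
params = {}
for name in "fico":
    params[f"W_{name}"] = rng.normal(0.0, 0.1, (hidden_size, hidden_size + input_size))
    params[f"b_{name}"] = np.zeros(hidden_size)

h = np.zeros(hidden_size)
C = np.zeros(hidden_size)
for token_id in [0, 3, 1, 4]:             # a made-up token sequence
    x = np.zeros(input_size)
    x[token_id] = 1.0                     # one-hot encode the token
    h, C = lstm_step(x, h, C, params)
```

Note that `C` is never passed through a squashing function on its way to the next step; only the copy used to produce `h` is.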

The main difference in hidden state usage between an RNN and an LSTM is as follows: an RNN hidden state serves two purposes, carrying data and determining output; an LSTM hidden state is created fresh at each time step and is used mainly for the latter, with a separate cell state handling the long-term carrying of data. This allows an LSTM to outperform a vanilla RNN in long-term output coherence and in staying related to the seed string.

My LSTM, Mercutio (can you see the theme?), outperformed my RNN Benvolio in training, reaching a minimum loss of ~0.85 compared to the latter’s ~1.33. You can see some of its output in my interactive project demo, along with that of Benvolio and their older cousin Romeo (a Transformer model).

This article has already gone on for quite a while, so I will continue my journey to creating a Transformer LLM in my next post. I hope you all learned something today, and see you next time!

