
Building an LSTM from Scratch in PyTorch (LSTMs in Depth Part 1)

 5 years ago
source link: https://www.tuicool.com/articles/hit/ARnEfur

Despite being invented over 20 (!) years ago, LSTMs are still one of the most prevalent and effective architectures in deep learning. Multiple papers have claimed to have developed an architecture that outperforms LSTMs, only for someone else to come along afterwards and discover that well-tuned LSTMs were better all along.

In this series of posts, I’ll be covering LSTMs in depth: building, analyzing, and optimizing them. In this first post, I’ll be building an LSTM from scratch in PyTorch to gain a better understanding of their inner workings. As Richard Feynman said, “what I cannot create, I do not understand”.


I’ve uploaded the full code for this post in this notebook.

The Basics of the LSTM

Before we start coding, we’ll need to discuss how the LSTM works to begin with. The image below is from Wikipedia and represents how the LSTM cell works.

[Figure: the LSTM cell, from Wikipedia]

Seems pretty intimidating, doesn’t it? Don’t worry: we’ll dissect the LSTM piece by piece. The first step is to look at the equation for the LSTM:

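$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$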

$\sigma$ is the sigmoid function, and $\odot$ represents element-wise multiplication. There are a lot of steps here, but the core equation is just this line:

$$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$

Let’s pick this equation apart: $c_t$ is the new cell state, which is basically the memory of the LSTM.

$f_t$ is called the “forget gate”: it dictates how much of the previous cell state to retain, which makes the name slightly confusing.

$i_t$ is the “input gate” and dictates how much to update the cell state with new information.

Finally, $g_t$ is the information we use to update the cell state.

Basically, an LSTM chooses to keep a certain portion of its previous cell state and add a certain amount of new information. These proportions are controlled using gates.
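To make the update concrete, here is a scalar sketch of a single cell-state update in plain Python. The gate pre-activations are made-up numbers chosen for illustration, and `sigmoid` and `cell_update` are hypothetical helper names, not part of the post’s code.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def cell_update(c_prev: float, f_pre: float, i_pre: float, g_pre: float) -> float:
    """One scalar step of c_t = f_t * c_{t-1} + i_t * g_t (illustrative only)."""
    f_t = sigmoid(f_pre)    # forget gate: fraction of the old cell state to keep
    i_t = sigmoid(i_pre)    # input gate: how much new information to write
    g_t = math.tanh(g_pre)  # new information, squashed into (-1, 1)
    return f_t * c_prev + i_t * g_t


# Forget gate nearly open, input gate nearly closed: the state is almost copied through.
print(cell_update(2.0, f_pre=10.0, i_pre=-10.0, g_pre=1.0))  # close to 2.0

# Both gates half open and nothing new to add: half of the old state survives.
print(cell_update(2.0, f_pre=0.0, i_pre=0.0, g_pre=0.0))  # 1.0
```

The gates only ever scale things by values in $(0, 1)$, so they act as soft switches deciding how much to keep and how much to write.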

Let’s contrast this update rule with the update rule of a simpler RNN:

$$c_t = \tanh(W_{ih} x_t + W_{hh} c_{t-1} + b_h)$$

To make the contrast clearer, I’m representing the hidden state of the RNN as $c_t$. Note that the hidden state is more commonly referred to as $h_t$.

As you can see, there is a huge difference between the simple RNN’s update rule and the LSTM’s update rule. Whereas the RNN computes the new hidden state from scratch based on the previous hidden state and the input, the LSTM computes the new hidden state by choosing what to add to the current state. This is similar to how ResNets learn: they learn what to add to the current state/block instead of directly learning the new state. In other words, LSTMs are great primarily because they are additive. We’ll formalize this intuition later when we examine the gradient flow, but this is the basic idea behind the LSTM.
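The contrast shows up in a toy scalar experiment: store a value in the state, then run a few steps with zero input and nothing new to write. The RNN weight `w_h = 0.5` and the gate values (forget gate fully open, input gate closed) are assumptions picked for illustration, not learned values.

```python
import math


def rnn_update(c_prev: float, x: float, w_h: float = 0.5, w_x: float = 1.0) -> float:
    # Simple RNN: the new state is recomputed from scratch every step.
    return math.tanh(w_x * x + w_h * c_prev)


def lstm_cell_update(c_prev: float, f_t: float, i_t: float, g_t: float) -> float:
    # LSTM: keep a fraction f_t of the old state and add i_t * g_t on top.
    return f_t * c_prev + i_t * g_t


c_rnn = c_lstm = 0.8  # value "stored" at t = 0
for _ in range(10):   # ten steps with zero input
    c_rnn = rnn_update(c_rnn, x=0.0)
    c_lstm = lstm_cell_update(c_lstm, f_t=1.0, i_t=0.0, g_t=0.0)

print(c_rnn)   # has shrunk towards 0
print(c_lstm)  # still 0.8
```

With the forget gate open and the input gate closed, the LSTM carries its state forward unchanged, while the RNN’s state is rescaled by `w_h` and squashed by `tanh` at every step.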

Side Note: One thing that is slightly confusing about the LSTM is that it has two “hidden states”: $c_t$ and $h_t$. Intuitively, $c_t$ is the “internal” hidden state that retains important information for longer timesteps, whereas $h_t$ is the “external” hidden state that exposes that information to the outside world.

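[Figure: the internal cell state and the external hidden state of an LSTM cell (older variant without a forget gate)]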

The image above explains what I just wrote visually. One thing to be careful of is that the above image is a diagram representing an older version of the LSTM that did not have the forget gate. We’ll see the consequences of removing the forget gate later in this post.
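A small sketch of the difference between the two states, using idealized gate values held constant (an assumption for illustration): the internal state $c_t$ can grow well beyond 1, while the exposed state $h_t = o_t \odot \tanh(c_t)$ stays bounded by 1 in magnitude.

```python
import math

f_t, i_t, g_t, o_t = 1.0, 1.0, 0.9, 1.0  # idealized gate values, held constant

c_t = 0.0
for _ in range(10):
    c_t = f_t * c_t + i_t * g_t  # internal state: accumulates freely
h_t = o_t * math.tanh(c_t)       # external state: squashed by tanh

print(c_t)  # about 9.0
print(h_t)  # just below 1.0
```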

Now that we have a rudimentary understanding, let’s get our hands dirty and write some code!

Building an LSTM from Scratch

The LSTM code is really simple: you just need to translate the equations above into PyTorch operations. Here’s a code example for a naively implemented LSTM.

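from enum import IntEnum
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import Parameter
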
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

class NaiveLSTM(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        # input gate
        self.W_ii = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_hi = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_i = Parameter(torch.Tensor(hidden_sz))
        # forget gate
        self.W_if = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_hf = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_f = Parameter(torch.Tensor(hidden_sz))
        # cell gate (candidate values g_t)
        self.W_ig = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_hg = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_g = Parameter(torch.Tensor(hidden_sz))
        # output gate
        self.W_io = Parameter(torch.Tensor(input_sz, hidden_sz))
        self.W_ho = Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_o = Parameter(torch.Tensor(hidden_sz))
        
        self.init_weights()
    
    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)
        
    def forward(self, x: torch.Tensor,
                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            # zero initial states, one row per batch element
            h_t = torch.zeros(bs, self.hidden_size, device=x.device)
            c_t = torch.zeros(bs, self.hidden_size, device=x.device)
        else:
            h_t, c_t = init_states
        for t in range(seq_sz): # iterate over the time steps
            x_t = x[:, t, :]
            i_t = torch.sigmoid(x_t @ self.W_ii + h_t @ self.W_hi + self.b_i)
            f_t = torch.sigmoid(x_t @ self.W_if + h_t @ self.W_hf + self.b_f)
            g_t = torch.tanh(x_t @ self.W_ig + h_t @ self.W_hg + self.b_g)
            o_t = torch.sigmoid(x_t @ self.W_io + h_t @ self.W_ho + self.b_o)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(Dim.batch))
        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
        return hidden_seq, (h_t, c_t)

Side Note: If you look carefully at the equations of the LSTM, you’ll notice that each gate has two bias terms (one on the input side, one on the hidden side) that are simply added together, so one of them is redundant. They are there for compatibility with the cuDNN backend. Until we touch on cuDNN (which I plan to do in a future post), we’ll use a single bias term per gate.
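A quick way to convince yourself that the second bias is redundant, sketched here with PyTorch’s built-in `nn.LSTM` (which keeps both `bias_ih_l0` and `bias_hh_l0`): fold one bias vector into the other and check that the outputs do not change. The layer sizes and the random input are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=5, batch_first=True)
x = torch.randn(2, 4, 3)  # (batch, sequence, feature)

with torch.no_grad():
    out_ref, _ = lstm(x)
    # fold the hidden-side bias into the input-side bias, then zero it out
    lstm.bias_ih_l0.add_(lstm.bias_hh_l0)
    lstm.bias_hh_l0.zero_()
    out_folded, _ = lstm(x)

# only floating-point rounding differences remain
print((out_ref - out_folded).abs().max().item())
```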

As you can see, all we’re doing is applying the equations over and over, one timestep at a time. I recommend trying to replicate the code above without looking at the code I wrote (you can look at the equations, but try to implement them with your own hands!).

I won’t discuss testing the above code here, but you can confirm that the code runs successfully and can train in the full notebook I uploaded here.

Making our LSTM Faster

If you look at the code for our LSTM carefully, you’ll notice that there is a lot of shared processing that could be batched together. For instance, the input and forget gates are both computed based on a linear transformation of the input and the hidden states.
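The batching idea rests on a simple identity: concatenating the per-gate weight matrices along the output dimension and slicing the product gives the same numbers as the separate products. A small check, with arbitrary sizes and random weights standing in (hypothetically) for the `W_ii`, `W_if`, ... parameters of `NaiveLSTM`:

```python
import torch

torch.manual_seed(0)
input_sz, hidden_sz, bs = 3, 4, 2
x_t = torch.randn(bs, input_sz)
h_t = torch.randn(bs, hidden_sz)

# four per-gate weight pairs, in the order input, forget, cell, output
W_i = [torch.randn(input_sz, hidden_sz) for _ in range(4)]
W_h = [torch.randn(hidden_sz, hidden_sz) for _ in range(4)]
separate = [x_t @ w_i + h_t @ w_h for w_i, w_h in zip(W_i, W_h)]  # eight matmuls

# stack along the output dimension
weight_ih = torch.cat(W_i, dim=1)  # (input_sz, 4 * hidden_sz)
weight_hh = torch.cat(W_h, dim=1)  # (hidden_sz, 4 * hidden_sz)
gates = x_t @ weight_ih + h_t @ weight_hh  # two matmuls
batched = gates.chunk(4, dim=1)

print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(separate, batched)))  # True
```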

We can group these computations into just two matrix multiplications. The code now looks like this:

class OptimizedLSTM(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.weight_ih = Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.weight_hh = Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()
    
    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)
        
    def forward(self, x: torch.Tensor,
                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            # zero initial states, one row per batch element
            h_t = torch.zeros(bs, self.hidden_size, device=x.device)
            c_t = torch.zeros(bs, self.hidden_size, device=x.device)
        else:
            h_t, c_t = init_states
        
        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the gate computations into two matrix multiplications (one for x_t, one for h_t)
            gates = x_t @ self.weight_ih + h_t @ self.weight_hh + self.bias
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]), # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]), # cell (candidate values)
                torch.sigmoid(gates[:, HS*3:]), # output
            )
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(Dim.batch))
        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
        return hidden_seq, (h_t, c_t)

Our code is now much closer to the official LSTM implementation in PyTorch and is much faster, especially on the GPU.
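One way to check the claim about matching the official implementation is to copy the weights of `torch.nn.LSTMCell` into the batched update from `OptimizedLSTM` and compare a single step. This is a sketch: it relies on PyTorch storing `weight_ih` as `(4*hidden_size, input_size)` with the gates ordered input, forget, cell, output, which is why the weights are transposed and the two biases summed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_sz, hidden_sz, bs = 3, 5, 2
cell = nn.LSTMCell(input_sz, hidden_sz)
x_t = torch.randn(bs, input_sz)
h_t = torch.randn(bs, hidden_sz)
c_t = torch.randn(bs, hidden_sz)

with torch.no_grad():
    # transpose to match the x_t @ weight_ih convention used in OptimizedLSTM,
    # and collapse the two redundant bias vectors into one
    weight_ih = cell.weight_ih.T
    weight_hh = cell.weight_hh.T
    bias = cell.bias_ih + cell.bias_hh

    gates = x_t @ weight_ih + h_t @ weight_hh + bias
    i, f, g, o = gates.chunk(4, dim=1)  # input, forget, cell, output
    c_next = torch.sigmoid(f) * c_t + torch.sigmoid(i) * torch.tanh(g)
    h_next = torch.sigmoid(o) * torch.tanh(c_next)

    h_ref, c_ref = cell(x_t, (h_t, c_t))  # PyTorch's own LSTM step

print(torch.allclose(h_next, h_ref, atol=1e-6), torch.allclose(c_next, c_ref, atol=1e-6))
```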

Understanding why LSTMs Perform Well

Why exactly do LSTMs learn so well? LSTMs were originally invented to combat the vanishing gradient problem. The problem with traditional RNNs was that their gradients tended to decay as the number of timesteps increased, making it difficult for RNNs to learn long-term dependencies.
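Why the decay happens can be seen in a scalar version of the recurrence, $h_t = \tanh(w h_{t-1} + x)$: by the chain rule, $\partial h_t / \partial h_0$ is a product of $t$ factors $w(1 - h_k^2)$, and when each factor is below 1 in magnitude the product shrinks geometrically. The values `w = 0.9` and `x = 0.5` and the helper `rnn_grad_wrt_h0` are hypothetical choices for illustration.

```python
import math


def rnn_grad_wrt_h0(w: float, x: float, steps: int) -> list:
    """d h_t / d h_0 for t = 1..steps, for the scalar RNN h_t = tanh(w * h_{t-1} + x), h_0 = 0."""
    h, grad, grads = 0.0, 1.0, []
    for _ in range(steps):
        h = math.tanh(w * h + x)
        # chain rule: d tanh(u)/du = 1 - tanh(u)^2, and du/dh_{t-1} = w
        grad *= w * (1.0 - h * h)
        grads.append(grad)
    return grads


grads = rnn_grad_wrt_h0(w=0.9, x=0.5, steps=100)
print(grads[0], grads[9], grads[99])  # shrinks geometrically towards 0
```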

To see this in action, let’s analyze the dynamics of LSTM learning by checking how the gradients change and comparing them to the gradients of a simple RNN.

The gradient dynamics of simple RNNs

First, let’s see how the gradients change with a simple RNN. Here’s the code we’ll use:

rnn = SimpleRNN(50, 125)  # SimpleRNN and test_embeddings are defined in the full notebook
def rnn_step(x_t, h_t, weight_ih, weight_hh, bias_hh):
    return torch.tanh(x_t @ weight_ih + h_t @ weight_hh + bias_hh)

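# create h_0 directly on the target device: calling .to() on a GPU would return a
# non-leaf copy, and h_0.grad would then stay None
h_0 = torch.zeros(rnn.hidden_size, device=test_embeddings.device, requires_grad=True)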
h_t = h_0
grads = []

for t in range(100):
    h_t = rnn_step(
        test_embeddings[:, t, :], h_t,
        rnn.weight_ih, rnn.weight_hh, rnn.bias_hh,
    )
    loss = h_t.abs().sum() # we'll use the l1 norm of the current hidden state as the loss
    loss.backward(retain_graph=True)
    grads.append(torch.norm(h_0.grad).item())
    h_0.grad.zero_()

In this piece of code, we’re computing the gradient of the loss at timestep $t$ (the L1 norm of the hidden state) with respect to the initial hidden state, and recording the norm of that gradient. The larger the gradient, the larger the effect of the initial hidden state on the hidden state at timestep $t$.

Here’s how the magnitude of the gradient changes:

[Figure: Gradient decay for a simple RNN]

As you can see, the gradients rapidly decay as time progresses. This makes long-term dependencies difficult to learn: the initial state has almost no measurable effect on later hidden states, so the error signal from later timesteps barely reaches the earlier ones.
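As an aside, the same measurement can be written with `torch.autograd.grad`, which returns the gradient directly instead of accumulating it into `h_0.grad`, so no manual zeroing is needed. The sketch below is self-contained: the random input and the random weights are hypothetical stand-ins for the notebook’s `test_embeddings` and `SimpleRNN` parameters, with weight scales chosen (as an assumption) so that the recurrence is contractive, so the exact numbers will differ from the plot.

```python
import torch

torch.manual_seed(0)
input_sz, hidden_sz, seq_len = 50, 125, 100

# hypothetical stand-ins for test_embeddings and the SimpleRNN parameters
x = torch.randn(1, seq_len, input_sz)
weight_ih = torch.randn(input_sz, hidden_sz) * 0.1
weight_hh = torch.randn(hidden_sz, hidden_sz) * 0.01
bias_hh = torch.zeros(hidden_sz)

h_0 = torch.zeros(hidden_sz, requires_grad=True)
h_t = h_0
grads = []
for t in range(seq_len):
    h_t = torch.tanh(x[:, t, :] @ weight_ih + h_t @ weight_hh + bias_hh)
    loss = h_t.abs().sum()  # L1 norm of the current hidden state
    # returns d loss / d h_0 without touching h_0.grad;
    # retain_graph=True keeps the graph alive for the next timestep
    (grad_h0,) = torch.autograd.grad(loss, h_0, retain_graph=True)
    grads.append(grad_h0.norm().item())

print(grads[0], grads[-1])
```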

The gradient dynamics of LSTMs

Now, let’s compare this with LSTMs. For starters, we’ll try using an LSTM without a forget gate.

lstm = NaiveLSTM(50, 125)
hidden_size = lstm.hidden_size

def lstm_step(x_t, h_t, c_t, W_ii, W_hi, b_i, W_if, W_hf, b_f,
              W_ig, W_hg, b_g, W_io, W_ho, b_o, use_forget_gate=False):
    # use the biases passed in (not the lstm.* attributes) so the step only depends on its arguments
    i_t = torch.sigmoid(x_t @ W_ii + h_t @ W_hi + b_i)
    if use_forget_gate:
        f_t = torch.sigmoid(x_t @ W_if + h_t @ W_hf + b_f)
    g_t = torch.tanh(x_t @ W_ig + h_t @ W_hg + b_g)
    o_t = torch.sigmoid(x_t @ W_io + h_t @ W_ho + b_o)
    if use_forget_gate:
        c_t = f_t * c_t + i_t * g_t
    else:
        # no forget gate: the old cell state is carried over unchanged
        c_t = c_t + i_t * g_t
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t

# initial hidden and cell states; we track gradients with respect to h_0
h_0, c_0 = (torch.zeros(hidden_size, requires_grad=True), 
            torch.zeros(hidden_size, requires_grad=True))
grads = []
h_t, c_t = h_0, c_0
for t in range(100):
    h_t, c_t = lstm_step(
        test_embeddings[:, t, :], h_t, c_t,
        lstm.W_ii, lstm.W_hi, lstm.b_i,
        lstm.W_if, lstm.W_hf, lstm.b_f,
        lstm.W_ig, lstm.W_hg, lstm.b_g,
        lstm.W_io, lstm.W_ho, lstm.b_o,
        use_forget_gate=False,
    )
    loss = h_t.abs().sum()
    loss.backward(retain_graph=True)
    grads.append(torch.norm(h_0.grad).item())
    h_0.grad.zero_()
    lstm.zero_grad()

Here’s the plot:

[Figure: Gradient decay (accumulation) for an LSTM with no forget gate]

Now, instead of decaying, the gradient keeps on accumulating! The gradient behaves this way because of the update rule

$$c_t = c_{t-1} + i_t \odot g_t$$

If you’re familiar with gradient calculus, you’ll see that the gradients for $c_t$ propagate straight back to the gradients for $c_{t-1}$. Therefore, the gradient of the initial timestep keeps increasing: since $c_0$ influences $c_1$, which in turn influences $c_2$, and so on, the influence of the initial state never disappears.
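The intuition can be made slightly more precise. Following only the direct path through the cell state, and treating $i_t$ and $g_t$ as constants (which ignores the extra paths that run through $h_{t-1}$), the update rule without a forget gate gives an identity Jacobian at every step:

$$
\frac{\partial c_t}{\partial c_{t-1}} = I
\qquad\Longrightarrow\qquad
\frac{\partial c_T}{\partial c_0} = \prod_{t=1}^{T} \frac{\partial c_t}{\partial c_{t-1}} = I
$$

If a forget gate multiplies $c_{t-1}$, the same direct-path Jacobian becomes $\operatorname{diag}(f_t)$, so the gain over $T$ steps is $\prod_{t=1}^{T} \operatorname{diag}(f_t)$. By contrast, for a simple RNN written as $c_t = \tanh(W_{ih} x_t + W_{hh} c_{t-1} + b_h)$, each step contributes $\operatorname{diag}(1 - c_t \odot c_t)\, W_{hh}$, and a long product of such factors tends to shrink (or explode) rather than stay at $I$.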

Of course, this can be a mixed blessing: sometimes we don’t want the current timestep to influence the hidden state 200 steps into the future. Sometimes, we want to “forget” the information we learned earlier and overwrite it with what we have newly learned. This is where the forget gate comes into play.

Turning the forget gate on

The forget gate was originally proposed in the paper Learning to Forget: Continual Prediction with LSTM. Let’s see how the gradients change when we turn the forget gate on. Adhering to best practices, we’ll initialize the bias for the forget gate to 1.
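The post doesn’t show how the bias is set. One way to do it, sketched here for PyTorch’s `nn.LSTM` (an assumption, not necessarily what the notebook does), is to fill the forget-gate slice of the bias. PyTorch orders each bias vector as (input, forget, cell, output) and adds `bias_ih_l0` and `bias_hh_l0` together, so one slice is set to 1 and the other to 0.

```python
import torch
import torch.nn as nn

hidden_sz = 4
lstm = nn.LSTM(input_size=3, hidden_size=hidden_sz, batch_first=True)

with torch.no_grad():
    # the forget gate occupies [hidden_sz : 2 * hidden_sz] in each bias vector;
    # the two vectors are summed inside the cell, so 1 + 0 gives an effective bias of 1
    lstm.bias_ih_l0[hidden_sz:2 * hidden_sz].fill_(1.0)
    lstm.bias_hh_l0[hidden_sz:2 * hidden_sz].fill_(0.0)

effective_forget_bias = (lstm.bias_ih_l0 + lstm.bias_hh_l0)[hidden_sz:2 * hidden_sz]
print(effective_forget_bias)
```

For the `OptimizedLSTM` above, which uses a single `bias` with the same gate order, the analogous step would presumably be filling `self.bias[HS:2*HS]` with 1 inside a `torch.no_grad()` block after `init_weights()`.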

[Figure: Gradient decay of an LSTM with the forget gate]

Notice how the gradients decay much more slowly than in the case of the Simple RNN. On the other hand, when we don’t initialize the forget gate bias to 1…

[Figure: Gradient decay with the forget gate bias initialized to 0]

The gradient decays much more quickly now: this is why initializing the forget gate bias to 1 is a good idea, at least in the initial stages of training.

Now, let’s see what happens when we initialize the forget gate bias to -1.

[Figure: Gradient decay with the forget gate bias initialized to -1]

The gradients decay even faster now!
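A rough way to see why the bias matters: if we assume the forget gate’s pre-activation is dominated by its bias at initialization (a simplification that ignores the weight terms), the direct path through the cell state scales the gradient by roughly $\sigma(b_f)^T$ after $T$ steps. `direct_path_gain` is a hypothetical helper for this estimate.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def direct_path_gain(forget_bias: float, steps: int) -> float:
    # Crude estimate: assume f_t is roughly sigmoid(forget_bias) at every step
    # (weight terms ignored) and follow only the direct path c_{t-1} -> c_t.
    return sigmoid(forget_bias) ** steps


for b in (1.0, 0.0, -1.0):
    print(f"b_f = {b:+.0f}: f = {sigmoid(b):.3f}, gain after 20 steps = {direct_path_gain(b, 20):.2e}")
```

Each step multiplies by a number below 1, so lowering the bias from 1 to 0 to -1 shrinks the per-step factor from about 0.73 to 0.5 to about 0.27, and the gap compounds over many steps.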

We looked at a lot of charts, but the most important point is that the LSTM has control over how much of the gradient flows through each timestep. This is a large part of what makes LSTMs comparatively easy to train.

Conclusion and Further Readings

In this post, we implemented an LSTM cell from scratch and gained a basic understanding of what makes LSTMs effective. This only scratches the surface of LSTMs, though: there are still many best practices and implementation details we haven’t covered yet (such as AWD-LSTMs and cuDNN). In future posts, I’m planning to cover these topics in more detail. For now, you can refer to the original papers on the mechanics of LSTMs: you’ll be surprised at how much they got right and how many hints they contained regarding deep learning (you might even discover some great hints for your own research/projects!).

Further Readings

The original LSTM paper

Learning to Forget: Continual Prediction with LSTM

