core

Review from LLMs-from-scratch.

In this notebook, we will go through the code we learned from LLMs-from-scratch.

Ch. 2: Working with Text Data

We will use a simple sequence of numbers as the text data to train on. In the book, we used the short story "The Verdict".

To create the data, use this command:

seq -s ', ' 0 3000 > numbers.txt

This writes the numbers from 0 to 3000 to numbers.txt, separated by ', ' (a comma and a space). Every character in the file is therefore a digit, a comma, or a space, so a tiny 12-symbol vocabulary will be enough.
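
If the seq command is not available (for example on Windows), a rough Python equivalent is sketched below; the trailing "\n" mimics the final newline that seq appends to its output.

# Sketch of a Python alternative to the `seq` command above:
# join 0..3000 with ", " and end with a newline, like `seq -s ', ' 0 3000` does.
with open("numbers.txt", "w", encoding="utf-8") as f:
    f.write(", ".join(str(n) for n in range(3001)) + "\n")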

with open("numbers.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()
raw_text[:99]
'0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2'
vocab = {str(n):n for n in range(10)}
vocab.update({',': 10, ' ': 11})
vocab
{'0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 ',': 10,
 ' ': 11}
samp = raw_text[:30]
samp
'0, 1, 2, 3, 4, 5, 6, 7, 8, 9, '

Here is a simple tokenizer that encodes text into token IDs and decodes token IDs back into text. We have to turn the text into numbers because a neural network cannot be trained on raw characters directly; it only operates on numbers.

class SimpleTokenizer:
    def __init__(self, vocab):
        self.str_to_int = vocab
        self.int_to_str = {v:k for k,v in vocab.items()}

    def encode(self, text): return [self.str_to_int[ch] for ch in text]

    def decode(self, tokens): return ''.join([self.int_to_str[t] for t in tokens])
tokenizer = SimpleTokenizer(vocab)
tokenizer.encode(samp)
[0,
 10,
 11,
 1,
 10,
 11,
 2,
 10,
 11,
 3,
 10,
 11,
 4,
 10,
 11,
 5,
 10,
 11,
 6,
 10,
 11,
 7,
 10,
 11,
 8,
 10,
 11,
 9,
 10,
 11]
tokenizer.decode(tokenizer.encode(samp))
'0, 1, 2, 3, 4, 5, 6, 7, 8, 9, '
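
One caveat (an assumption about seq's output, worth checking on your system): the file likely ends with a trailing newline, which is not in our 12-symbol vocabulary, so encoding the full raw_text as-is would raise a KeyError. A minimal guard:

# Strip characters that are not in the vocabulary (here, presumably just the trailing newline)
clean_text = raw_text.strip()
token_ids = tokenizer.encode(clean_text)
assert tokenizer.decode(token_ids) == clean_text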

The tokenizer looks good. Let's move on to the Dataset. The dataset provides us with input and target pairs: the target is simply the input shifted by one token, because we want to predict the next token given the input. In this example, our dataset also tokenizes the text, but tokenization does not have to happen inside the dataset; in practice, text data is often pre-tokenized.

import torch
from torch.utils.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    def __init__(self, txt, tokenizer, max_length):
        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(txt)
        assert len(token_ids) > max_length, "Number of tokenized inputs must at least be equal to max_length+1"

        # Slide a window of max_length tokens over the text (step = max_length, so chunks do not overlap);
        # the target chunk is the input chunk shifted one position to the right
        for i in range(0, len(token_ids) - max_length, max_length):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self): return len(self.input_ids)

    def __getitem__(self, idx): return self.input_ids[idx], self.target_ids[idx]
ds = SimpleDataset(samp, tokenizer, 4)
ds[0]
(tensor([ 0, 10, 11,  1]), tensor([10, 11,  1, 10]))
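
Looking at a second sample makes the windowing clearer: because the loop advances by max_length, consecutive inputs do not overlap, while each target is just its own input shifted one token to the right.

x1, y1 = ds[1]
x1, y1   # should be the second window: (tensor([10, 11,  2, 10]), tensor([11,  2, 10, 11]))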

Next, we have a data loader. The data loader collects the samples from the dataset into batches. We also shuffle the training data, while validation and test data are left in their original order.

def mk_dataloader(txt, tokenizer, batch_size=4, max_length=256, 
                  shuffle=True, drop_last=True, num_workers=0):
    return DataLoader(
        dataset=SimpleDataset(txt, tokenizer, max_length),
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers
    )
dl = mk_dataloader(samp, tokenizer, batch_size=2, max_length=4)
xb, yb = next(iter(dl))
xb, yb
(tensor([[10, 11,  2, 10],
         [10, 11,  6, 10]]),
 tensor([[11,  2, 10, 11],
         [11,  6, 10, 11]]))
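
For validation or test data we would, as mentioned above, keep the samples in their original order; a minimal sketch (val_dl is just an illustrative name):

# Validation/test loader: no shuffling, and keep the last (possibly smaller) batch
val_dl = mk_dataloader(samp, tokenizer, batch_size=2, max_length=4,
                       shuffle=False, drop_last=False)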

Ch. 3: Coding Attention Mechanisms

By using attention mechanisms, the model learns which tokens to pay more attention to. For instance, in our numbers example, paying attention to the last digits of the sequence is very important for predicting what comes next.

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`,
        # this will result in errors in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec
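
As a quick smoke test of the attention module, we can push a random batch of "embeddings" through it; the dimensions below are arbitrary choices for illustration, not values used elsewhere in this notebook.

torch.manual_seed(123)
x = torch.randn(2, 4, 12)  # (batch, num_tokens, d_in) -- made-up sizes
mha = MultiHeadAttention(d_in=12, d_out=12, context_length=4, dropout=0.0, num_heads=2)
mha(x).shape  # torch.Size([2, 4, 12]): one d_out-dimensional context vector per token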