bigrams annotated

This is an annotated Python program that trains a bigram model (a single weight matrix followed by a softmax) to predict the next letter in a name from the current letter, and then samples new names from it.

This program is based on Andrej Karpathy's video "The spelled-out intro to language modeling: building makemore".

import torch
import torch.nn.functional as F
# Load the training names, one per line.  The context manager closes the file
# handle as soon as its contents have been read.
with open('names.txt', 'r') as names_file:
    list_of_names = names_file.read().splitlines()

# Vocabulary: every distinct character that occurs in the names, sorted.
chars = sorted(set(''.join(list_of_names)))

# Character <-> integer lookup tables.  Index 0 is reserved for '.', the
# marker placed before the first and after the last letter of every name.
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}

# Build the bigram training set: xs[n] is the index of a character and ys[n]
# is the index of the character that immediately follows it.
xs, ys = [], []
for w in list_of_names:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        xs.append(stoi[ch1])
        ys.append(stoi[ch2])

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()  # number of (current, next) character pairs


# Number of distinct symbol indices, derived from the lookup table instead of
# being hard-coded so the model follows whatever characters the data contains.
# Indices run from 0 up to the largest value stored in stoi.
vocab_size = max(stoi.values()) + 1

# Encode the xs training data as one hot vectors
xenc = F.one_hot(xs, num_classes=vocab_size).float()
print('xenc shape:', xenc.shape)
print(xenc)

# Create a square (vocab_size by vocab_size) weight matrix (tensor) with
# randomly initialized values: one row per input character and one column per
# candidate next character.  The fixed seed makes the initialization
# reproducible.
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((vocab_size, vocab_size), generator=g, requires_grad=True)
print('W matrix shape:', W.shape)
print(W)

# Perform gradient descent to train the values in W
num_steps = 10       # number of full-batch gradient descent updates
learning_rate = 50   # step size applied to the gradient
reg_strength = 0.01  # weight of the mean-squared-weight penalty in the loss
example_idx = torch.arange(num)  # row index of every training pair (loop-invariant)

for k in range(num_steps):

    # Forward pass.
    logits = xenc @ W                             # predict log-counts

    counts = logits.exp()                         # calculate counts

    probs = counts / counts.sum(1, keepdims=True) # calculates probabilities for next character
                                                  # Note: probs = softmax(logits)

    # Average negative log-likelihood of the observed next characters, plus a
    # penalty on the mean squared weight.
    loss = -probs[example_idx, ys].log().mean() + reg_strength * (W**2).mean()

    # Backward pass.
    W.grad = None   # discard the previous gradient (None, not a zero tensor)
    loss.backward() # calculate the gradient

    # Update W based on the gradient.  The in-place step runs under no_grad so
    # autograd does not record it, instead of reaching around autograd via
    # W.data.
    with torch.no_grad():
        W += -learning_rate * W.grad

# Generate sample names using the trained W matrix
g = torch.Generator().manual_seed(2147483647)

# Sampling needs no gradients: use a detached view of W so autograd does not
# build a graph for every generated character.
sampling_W = W.detach()
# One-hot width must match the number of rows in W; read it from the matrix
# rather than hard-coding it.
n_classes = sampling_W.shape[0]

for i in range(5):
    # NOTE(review): `out` accumulates the sampled characters (including the
    # terminating '.'), but nothing in the visible lines prints or stores it.
    # Confirm the sample is consumed after the while loop, e.g. with
    # print(''.join(out)).
    out = []
    ix = 0  # start from index 0, the same value that terminates a name below
    while True:
        xenc = F.one_hot(torch.tensor([ix]), num_classes=n_classes).float()
        logits = xenc @ sampling_W                # predict log-counts
        counts = logits.exp()                     # counts, equivalent to N
        p = counts / counts.sum(1, keepdims=True) # probabilities for next character
        # Draw the next character index from the predicted distribution.
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:
            # Index 0 again marks the end of the generated name.
            break