{ "cells": [ { "cell_type": "markdown", "source": [ "### Character-Level Language Model\n", "\n", "This notebook contains a generative model working at the level of characters.\n" ], "metadata": { "id": "sZ6I5Yr8QUtD" } }, { "cell_type": "code", "execution_count": 48, "metadata": { "id": "RqPkh2KTGUGp" }, "outputs": [], "source": [ "import numpy as np\n", "from numpy.random import randint,rand,seed,normal,permutation,choice\n", "\n", "import string\n", "import math\n", "\n", "import matplotlib.pyplot as plt\n", "from copy import deepcopy\n", "from tqdm import tqdm\n", "\n", "import torch\n", "from torch import nn, optim\n", "import torch.nn.functional as F\n", "from torch.utils.data import random_split,Dataset,DataLoader\n", "\n", "\n", "\n", "# from torchsummary import summary # must install using pip install torchsummary" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RcDYxJGhGUGr", "outputId": "9be52dba-3725-4654-f53b-46316d0de266" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ] } ], "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')\n", "\n", "data_dir = 'drive/MyDrive/CS505 Datasets/'" ] }, { "cell_type": "markdown", "source": [ "Load a text file. We chose a poem to see how the model handles line breaks. 
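Because the model works directly on characters, a line break is not formatting but simply one more symbol, the newline character, which the model must learn to predict like any letter or punctuation mark. The minimal sketch below illustrates this on two lines of the poem; the helper `char_counts` is a hypothetical name used only here.

```python
from collections import Counter

NEWLINE = chr(10)  # chr(10) is the newline character


def char_counts(text):
    # Frequency of every distinct character; newlines are counted like letters
    return Counter(text)


snippet = NEWLINE.join([
    'Of that forbidden tree whose mortal taste',
    'Brought death',
])
counts = char_counts(snippet)
print(counts[NEWLINE])               # 1
print(sorted(counts)[0] == NEWLINE)  # True
```

Sorting by code point puts the newline before every printable character, which is also why it heads the character set that the notebook prints further down.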
" ], "metadata": { "id": "FbYPF_0RQw1j" } }, { "cell_type": "code", "execution_count": 50, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "2_zGTaCKGUGs", "outputId": "46973f9f-1845-40ff-8287-b68b4eb8556a" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "\"Of Man's first disobedience, and the fruit\\nOf that forbidden tree whose mortal taste\\nBrought death i\"" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 50 } ], "source": [ "\n", "with open(data_dir+\"Milton_Paradise_Lost.txt\", \"r\") as text_file:\n", " text = text_file.read()\n", "\n", "text[:100]" ] }, { "cell_type": "markdown", "source": [ "No normalization will be performed. However,\n", "we will run out of RAM if we attempt to\n", "use the entire poem as data, so we have chosen\n", "here to use only the first 10K characters; the total\n", "length of the text is printed below." ], "metadata": { "id": "jvdwU6PeQugU" } }, { "cell_type": "code", "execution_count": 51, "metadata": { "id": "lVfP44ZUGUGs", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "4f393b80-cd76-4f02-a3e7-1c9c5e21ae21" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Text is 456475 characters long.\n" ] } ], "source": [ "print(f\"Text is {len(text)} characters long.\")\n", "\n", "size = 10000\n", "\n", "text = text[:size]" ] }, { "cell_type": "markdown", "source": [ "Next we find the set of distinct characters in the text; this is the vocabulary\n", "from which the model chooses one character at each generation step." 
], "metadata": { "id": "rwgjqUe5R0eX" } }, { "cell_type": "code", "execution_count": 52, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aYeHZcdzGUGt", "outputId": "ef380ca5-16ff-467a-fc86-5f3719f42b61" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "There are 62 characters in the text.\n", "Character set: ['\\n', ' ', '!', '\"', \"'\", '(', ')', ',', '-', '.', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'].\n" ] } ], "source": [ "chars_in_text = sorted(list(set(text)))\n", "\n", "num_chars = len(chars_in_text)\n", "\n", "print(f'There are {num_chars} characters in the text.')\n", "\n", "\n", "print(f'Character set: {chars_in_text}.')\n" ] }, { "cell_type": "code", "source": [ "# Create functions mapping characters to integers and back\n", "\n", "def char2int(c):\n", " return chars_in_text.index(c)\n", "\n", "def int2char(i):\n", " return chars_in_text[i]" ], "metadata": { "id": "gwALsQhiStNT" }, "execution_count": 53, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "68JWvJxdGUGt" }, "source": [ "As we're going to predict the next character in the sequence at each time step, we'll divide each sample into\n", "\n", "- Input data\n", "  - The last character is excluded, because there is no next character for the model to predict from it\n", "- Target/Ground Truth Label\n", "  - Shifted one time step ahead of the input data, since it holds the \"correct answer\" for the model at each corresponding input position\n", "\n", "The sample length (`sample_len`) is a critical parameter: it sets how many characters of text each training example covers. You might want to play around with this as one of the hyperparameters." 
] }, { "cell_type": "code", "execution_count": 54, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lYyi7B_IGUGu", "outputId": "ee9dab6f-b29d-41e5-a2be-78a32eb6f011" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Input sequence:\n", "Of Man's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death \n", "Target sequence:\n", "f Man's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death i\n", "\n", "Input sequence:\n", "f Man's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death i\n", "Target sequence:\n", " Man's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death in\n", "\n", "Input sequence:\n", " Man's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death in\n", "Target sequence:\n", "Man's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death int\n", "\n", "Input sequence:\n", "Man's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death int\n", "Target sequence:\n", "an's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death into\n", "\n", "Input sequence:\n", "an's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death into\n", "Target sequence:\n", "n's first disobedience, and the fruit\n", "Of that forbidden tree whose mortal taste\n", "Brought death into \n", "\n" ] } ], "source": [ "sample_len = 100\n", "\n", "# Creating lists that will hold our input and target sample sequences\n", "\n", "input_seq_chars = []\n", "target_seq_chars = []\n", "\n", "for k in range(len(text)-sample_len+1):\n", "\n", " # Remove last character for input sequence\n", " input_seq_chars.append(text[k:k+sample_len-1])\n", 
"\n", " # Remove first character for target sequence\n", " target_seq_chars.append(text[k+1:k+sample_len])\n", "\n", "for i in range(5):\n", " print(f'Input sequence:\\n{input_seq_chars[i]}')\n", " print(f'Target sequence:\\n{target_seq_chars[i]}')\n", " print()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "DYT6SCJ9GUGu" }, "source": [ "Now we can convert our input and target sequences to sequences of integers instead of characters by mapping them using the functions we created above. This will allow us to one-hot encode both sequences later." ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "id": "WJXQF5NAGUGu", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "74125302-b170-4b2c-d9f8-97076382e5b2" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[27, 41, 1, 25, 36, 49, 4, 54, 1, 41, 44, 53, 54, 55, 1, 39, 44, 54, 50, 37, 40, 39, 44, 40, 49, 38, 40, 7, 1, 36, 49, 39, 1, 55, 43, 40, 1, 41, 53, 56, 44, 55, 0, 27, 41, 1, 55, 43, 36, 55, 1, 41, 50, 53, 37, 44, 39, 39, 40, 49, 1, 55, 53, 40, 40, 1, 58, 43, 50, 54, 40, 1, 48, 50, 53, 55, 36, 47, 1, 55, 36, 54, 55, 40, 0, 14, 53, 50, 56, 42, 43, 55, 1, 39, 40, 36, 55, 43, 1]\n" ] } ], "source": [ "input_seq = []\n", "target_seq = []\n", "\n", "for i in range(len(input_seq_chars)):\n", " input_seq.append( [char2int(ch) for ch in input_seq_chars[i]])\n", " target_seq.append([char2int(ch) for ch in target_seq_chars[i]])\n", "\n", "print(input_seq[0])" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3Q5xPFM0GUGu", "outputId": "cd14643b-f18b-4397-9541-1fceb596c4f9" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", 
"
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])" ] }, "metadata": {}, "execution_count": 56 } ], "source": [ "# convert a 1-D array of integers into a 2-D array of one-hot rows of the given size (= number of characters)\n", "def int2OneHot(X,size):\n", "\n", " def int2OneHot1(x,size=10):\n", " tmp = np.zeros(size)\n", " tmp[int(x)] = 1.0\n", " return tmp\n", "\n", " return np.array([ int2OneHot1(x, size) for x in X ]).astype('double')\n", "\n", "int2OneHot( np.array([ 2,3,1,2,3,4 ]),10)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0v3RGM8AGUGv", "outputId": "55840113-66ca-4b5f-e6b5-64d8c49c16cf" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]])" ] }, "metadata": {}, "execution_count": 57 } ], "source": [ "# do the same thing for a list of integer sequences, giving a 3-D array (sequence, position, one-hot)\n", "\n", "def seq2OneHot(seq,size):\n", " return np.array([ int2OneHot(x, size) for x in seq ])\n", "\n", "seq2OneHot( np.array([[ 2,3,1,2,3,4 ],[ 2,3,1,2,3,4 ],[ 2,3,1,2,3,4 ]]),10)" ] }, { "cell_type": 
"code", "execution_count": 58, "metadata": { "colab": { 
"base_uri": "https://localhost:8080/" }, "id": "nbTrt7AjGUGv", "outputId": "3d839dbb-597a-44c8-c81f-e89a9b55a93b" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(9901, 99, 62)" ] }, "metadata": {}, "execution_count": 58 } ], "source": [ "# Convert our input sequences to one-hot form\n", "\n", "input_seq = seq2OneHot(input_seq,size=num_chars)\n", "input_seq.shape" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "B1KM2AaXGUGv", "outputId": "b7532041-b4a6-4a14-b026-39b090383606" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(9901, 99, 62)" ] }, "metadata": {}, "execution_count": 59 } ], "source": [ "# Convert our target sequences to one-hot form\n", "\n", "target_seq = seq2OneHot(target_seq,size=num_chars)\n", "target_seq.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "YEYJ0NE0GUGv" }, "source": [ "Since we're done with all the data pre-processing, we can now move the data from numpy arrays to tensors." ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "id": "CcrO5N7RGUGv" }, "outputs": [], "source": [ "input_seq = torch.Tensor(input_seq).type(torch.DoubleTensor)\n", "target_seq = torch.Tensor(target_seq).type(torch.DoubleTensor)" ] }, { "cell_type": "markdown", "metadata": { "id": "D5NEt539GUGv" }, "source": [ "Now we will build a data loader to manage the batching." 
] }, { "cell_type": "code", "source": [ "class Basic_Dataset(Dataset):\n", "\n", " def __init__(self, X,Y):\n", " self.X = X\n", " self.Y = Y\n", "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", " # return a pair x,y at the index idx in the data set\n", " def __getitem__(self, idx):\n", " return self.X[idx], self.Y[idx]\n", "\n", "ds = Basic_Dataset(input_seq,target_seq)\n", "\n", "len(ds)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hhkkAfJVVViF", "outputId": "165c3008-fcc8-42b6-901a-805f35fb46c7" }, "execution_count": 61, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "9901" ] }, "metadata": {}, "execution_count": 61 } ] }, { "cell_type": "markdown", "source": [ "Batch size is a hyperparameter that largely determines how efficiently the data can be processed on a GPU. It also affects training itself, since each optimizer step averages the gradient over one batch." ], "metadata": { "id": "8FPjzvUbVrAd" } }, { "cell_type": "code", "execution_count": 62, "metadata": { "id": "v4vgBd0jGUGv" }, "outputs": [], "source": [ "batch_size = 128\n", "\n", "data_loader = DataLoader(ds, batch_size=batch_size, shuffle=True)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "_wgOsxHYGUGv" }, "source": [ "Check if a GPU is available and use it if it is." ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BFwBQZ1uGUGv", "outputId": "3242cf1c-56d9-4600-922c-91e3248a86ac" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "GPU is available\n" ] } ], "source": [ "# torch.cuda.is_available() checks and returns a Boolean True if a GPU is available, else it'll return False\n", "is_cuda = torch.cuda.is_available()\n", "\n", "# If we have a GPU available, we'll set our device to GPU.
We'll use this device variable later in our code.\n", "if is_cuda:\n", " device = torch.device(\"cuda\")\n", " print(\"GPU is available\")\n", "else:\n", " device = torch.device(\"cpu\")\n", " print(\"GPU not available, CPU used\")" ] }, { "cell_type": "markdown", "metadata": { "id": "R6u-fhYkGUGv" }, "source": [ "The model will use an LSTM layer and a single linear layer that produces a score (logit)\n", "for each possible next character; the softmax is applied later, inside the loss function and when sampling.\n", "Various hyperparameters can be chosen to modify this\n", "model. A messy detail is that two tensors, h0 (the initial hidden state) and c0 (the initial cell state), have to be created for the LSTM layer (these correspond to the two connections\n", "shown in lecture for an LSTM neuron to send to itself in the next time step). " ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "id": "dzDdSvGAGUGw" }, "outputs": [], "source": [ "class Model(nn.Module):\n", " def __init__(self, input_size, output_size, hidden_dim, n_layers,dropout):\n", " super(Model, self).__init__()\n", "\n", " # Defining some parameters\n", " self.hidden_dim = hidden_dim\n", " self.n_layers = n_layers\n", "\n", " # Defining the layers\n", " self.lstm = nn.LSTM(input_size, hidden_dim, n_layers,dropout=dropout,batch_first=True)\n", " # Fully connected layer\n", " self.fc1 = nn.Linear(hidden_dim, output_size)\n", "\n", " def forward(self, x):\n", "\n", " # x has shape (batch, seq_len, input_size) because batch_first=True\n", " hidden_state_size = x.size(0)\n", "\n", " x = x.to(torch.double)\n", "\n", " # Initial hidden and cell states, created with the same dtype and device as the input\n", " h0 = torch.zeros(self.n_layers,hidden_state_size,self.hidden_dim, dtype=torch.double, device=x.device)\n", " c0 = torch.zeros(self.n_layers,hidden_state_size,self.hidden_dim, dtype=torch.double, device=x.device)\n", "\n", " # Passing in the input and hidden state into the model and obtaining outputs\n", " out, 
(hx,cx) = self.lstm(x, (h0,c0))\n", "\n", " # Flatten to (batch * seq_len, hidden_dim) so each time step gets its own prediction, matching the flattened targets in the loss\n", " out = out.contiguous().view(-1, self.hidden_dim)\n", " out
= self.fc1(out)\n", "\n", " return out\n", "\n" ] }, { "cell_type": "markdown", "source": [ "Next, we instantiate the model with its hyperparameters, all of which can be\n", "changed." ], "metadata": { "id": "xr571AKQXDQH" } }, { "cell_type": "code", "execution_count": 65, "metadata": { "id": "zNUwEsRDGUGw", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "6331a61c-88fe-441c-ed46-1ce08d06d502" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model(\n", " (lstm): LSTM(62, 256, batch_first=True)\n", " (fc1): Linear(in_features=256, out_features=62, bias=True)\n", ")\n" ] } ], "source": [ "# Instantiate the model with hyperparameters\n", "\n", "model = Model(input_size=num_chars, output_size=num_chars, hidden_dim=256, n_layers=1,dropout=0.0)\n", "\n", "print(model)\n", "\n", "model = model.double().to(device)\n", "\n", "# Define Loss, Optimizer\n", "loss_fn = nn.CrossEntropyLoss()\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=0.0)\n", "\n" ] }, { "cell_type": "markdown", "source": [ "The following is a minimal training loop. We track only the loss, since accuracy\n", "is not a meaningful goal for a generative model.\n", "\n", "However, overfitting is very much a problem. One sign of it: if you give the model a prefix of the text (say, the first line) as a prompt, generation simply reproduces the rest of the text, which the model has memorized." 
], "metadata": { "id": "rDrJKscMXaU0" } }, { "cell_type": "code", "execution_count": 66, "metadata": { "scrolled": false, "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "HVClRHtXGUGw", "outputId": "9b729a48-5cbf-46cd-f76a-149401ac695a" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 10/10 [01:54<00:00, 11.45s/it]\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[]" ] }, "metadata": {}, "execution_count": 66 }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ], "source": [ "num_epochs = 10\n", "\n", "losses = []\n", "\n", "model.train()\n", "\n", "for epoch in tqdm(range(num_epochs)):\n", "\n", " for input_seq_batch,target_seq_batch in data_loader:\n", " input_seq_batch = input_seq_batch.to(device)\n", " target_seq_batch = target_seq_batch.to(device)\n", " optimizer.zero_grad()\n", " target_seq_hat = model(input_seq_batch)\n", " loss = loss_fn(target_seq_hat,target_seq_batch.view(-1,num_chars))\n", " loss.backward()\n", " optimizer.step()\n", "\n", " losses.append(loss.item())\n", "\n", "\n", "plt.title('Loss')\n", "plt.plot(losses)" ] }, { "cell_type": "markdown", "source": [ "The temperature of a softmax controls how strongly it favors the largest values:\n", "- As the temperature approaches 0, the distribution approaches a one-hot vector with 1 at the maximum\n", "- As the temperature increases, the distribution approaches a uniform distribution\n", "\n", "Generally we want to emphasize the higher probabilities, so we choose\n", "a reasonably low temperature." 
], "metadata": { "id": "uuyAwhLPXtGd" } }, { "cell_type": "code", "source": [ "\n", "def softmax_with_temperature(vec, temperature):\n", " sum_exp = sum(math.exp(x/temperature) for x in vec)\n", " return [math.exp(x/temperature)/sum_exp for x in vec]\n", "\n", "print(\"Example of softmax with temperature.\")\n", "dist = [0.1, 0.3, 0.6]\n", "print('distribution:',dist)\n", "print(softmax_with_temperature(dist,0.01))\n", "print(softmax_with_temperature(dist,0.1))\n", "print(softmax_with_temperature(dist,0.2))\n", "print(softmax_with_temperature(dist,0.3))\n", "print(softmax_with_temperature(dist,1))\n", "print(softmax_with_temperature(dist,10))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Sg7j3CzGQDkf", "outputId": "35c12c65-d238-46c7-a812-4478cd0191fe" }, "execution_count": 67, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Example of softmax with temperature.\n", "distribution: [0.1, 0.3, 0.6]\n", "[1.9287498479637375e-22, 9.3576229688393e-14, 0.9999999999999064]\n", "[0.006377460922442302, 0.04712341652466416, 0.9464991225528936]\n", "[0.06289001324586753, 0.1709527801977903, 0.7661572065563421]\n", "[0.12132647558421489, 0.23631170657656433, 0.6423618178392208]\n", "[0.2583896517379799, 0.3155978333128144, 0.4260125149492058]\n", "[0.3255767455856355, 0.3321538321280155, 0.3422694222863489]\n" ] } ] }, { "cell_type": "markdown", "source": [ "Choose a temperature and predict the next character, given a prompt of arbitrary length." 
], "metadata": { "id": "hLlygE7WYigs" } }, { "cell_type": "code", "execution_count": 68, "metadata": { "scrolled": false, "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "kLJX_vFSGUGw", "outputId": "91ba9e04-6bbf-4b80-c302-db6fb13dbf55" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'n'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 68 } ], "source": [ "temperature = 0.3\n", "\n", "def predict(model, ch):\n", "\n", " # only look at last sample_len - 1 characters\n", "\n", " ch = ch[-(sample_len - 1):]\n", "\n", " # One-hot encoding our input to fit into the model\n", " ch = np.array([char2int(c) for c in ch])\n", " ch = np.array([int2OneHot(ch, num_chars)])\n", " ch = torch.from_numpy(ch).to(device)\n", "\n", " out = model(ch)\n", "\n", " # take the scores (logits) for the last position in the sequence as a plain list of floats,\n", " # and turn them into a probability distribution with temperature\n", " prob = softmax_with_temperature(out[-1].detach().cpu().tolist(),temperature)\n", "\n", " # Choosing a character based on the probability distribution\n", " char_ind = choice(list(range(num_chars)), p=prob)\n", "\n", " return int2char(char_ind)\n", "\n", "predict(model,\"Of man's first disobedience, and the fruit o\")" ] }, { "cell_type": "markdown", "source": [ "Now take a prompt and iterate the previous prediction a specified number of times.\n", "\n", "The prompt is usually a long sequence selected at random from the text. You can also try a sequence of words similar to those in the text, but not an exact sequence. The prompt need not match the length of the training sequences, though very short prompts tend not to work as well." 
], "metadata": { "id": "wThVXlXhYgqt" } }, { "cell_type": "code", "execution_count": 69, "metadata": { "id": "eFcGfQ8VGUGw" }, "outputs": [], "source": [ "def sample(model, out_len, start):\n", " model.eval() # eval mode\n", " # Start from the characters of the prompt\n", " chars = [ch for ch in start]\n", " # out_len is the total length of the result, prompt included\n", " size = out_len - len(chars)\n", " # Now pass in the previous characters and get a new one\n", " for ii in range(size):\n", " char = predict(model, chars)\n", " chars.append(char)\n", "\n", " return ''.join(chars)" ] }, { "cell_type": "markdown", "source": [ "Now we run the model. With the parameters we have chosen and only\n", "10 epochs of training, it is getting some idea of words and lines, but\n", "the output doesn't yet look like an English poem!\n", "\n", "If you train for another 100 epochs, you should find that by then\n", "the network has simply memorized the poem!" ], "metadata": { "id": "LONowWPkg6bj" } }, { "cell_type": "code", "execution_count": 71, "metadata": { "scrolled": false, "colab": { "base_uri": "https://localhost:8080/" }, "id": "ra7VOpilGUGw", "outputId": "df558ad6-30d1-4840-932a-95e364358acf" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Of Man's first disobedience, and the fruit ros ore thall sus ingeres the store he preat hes pare,\n", "He veint of the gerat on the flaid the vinged of Heaven the wath of the gorthall in served on the gromithe firce sedpent his whale, and will his grong of Heaven on and will his from the prot the tho ders and cengen,\n", "With hese and rought force his fire,\n", "And simed tout for sowe,\n", "That whot ender st of the geall the ing hinger and ranger\n", "Th the his the farl se and rus ind and sure seall\n", "The force of the sore that for hised\n", "And with upmighto formte of the gore ther th of the seaven on the grom the farl seand sulled the the the his dering of Hian, whather th of beis and cous reaven he pering of the dere thers his fired\n", 
"That with of reaven
on the derich and righty stored the the the flor the vence the the sing the bort of Heaven on and compinise, and his pire,\n", "And his om the Alill nised the serce on the derat of Heaven se pint,\n", "And cous ffre stound he pires\n", "And th ur poreid de and sure dain the s\n" ] } ], "source": [ "print(sample(model, 1000, \"Of Man's first disobedience, and the fruit\"))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "colab": { "provenance": [], "gpuType": "T4" }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }