aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSIPB2024-11-21 01:37:45 +0000
committerSIPB2024-11-21 01:37:45 +0000
commit072c87a342e07b6498e8568a9be5e6a068feadde (patch)
tree5ff1e5d6714362f32a3ff089a12c5f802b6c6a34
parent68887a78efee447a389275d898f69cc9a796ab85 (diff)
New embeddings and readout
-rw-r--r--transformer_shortest_paths.ipynb997
1 files changed, 250 insertions, 747 deletions
diff --git a/transformer_shortest_paths.ipynb b/transformer_shortest_paths.ipynb
index b2cea24..a158fcb 100644
--- a/transformer_shortest_paths.ipynb
+++ b/transformer_shortest_paths.ipynb
@@ -16,12 +16,13 @@
"Question: \n",
"\n",
"- Do the attention heads learn to attend to the same positional encodings\n",
- "- do interp -- what is it doing? can we figure out?"
+ "- do interp -- what is it doing? can we figure out?\n",
+ "- update: I think we should do interp once the model is bigger."
]
},
{
"cell_type": "code",
- "execution_count": 75,
+ "execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -31,18 +32,18 @@
},
"outputs": [],
"source": [
- "from collections import deque\n",
"# using tqdm.auto glitches out collaborative editing\n",
"from tqdm import tqdm\n",
- "\n",
"import torch\n",
"import torch.nn as nn\n",
- "from math import sqrt\n",
"import matplotlib.pyplot as plt\n",
- "torch.manual_seed(42)\n",
"\n",
+ "from math import sqrt\n",
+ "from collections import deque\n",
"import os\n",
"import random\n",
+ "\n",
+ "torch.manual_seed(42)\n",
"random.seed(42)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
@@ -51,18 +52,18 @@
},
{
"cell_type": "code",
- "execution_count": 91,
+ "execution_count": 2,
"metadata": {
"id": "lylOX2POPwFL"
},
"outputs": [],
"source": [
"# VTXS numbers here are inclusive\n",
- "MIN_VTXS = 3 # 3\n",
- "MAX_VTXS = 3 # 8\n",
- "MAX_TUNE_VTXS = 3 # 15\n",
+ "MIN_VTXS = 3\n",
+ "MAX_VTXS = 31\n",
+ "MAX_TUNE_VTXS = 15 # 15\n",
"AVG_DEG = 2\n",
- "SEQ_LEN = MAX_VTXS * AVG_DEG + 1 # means 32 edges, final token is the target vertex\n",
+ "SEQ_LEN = MAX_VTXS + 1 # MAX_VTXS (u,v) edge slots, plus one final slot for the target vertex\n",
"PAD_TOKEN = 0\n",
"# vertices are labelled 1,2,...,63\n",
"# we also have a padding token which is 0."
@@ -79,7 +80,7 @@
},
{
"cell_type": "code",
- "execution_count": 76,
+ "execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -102,16 +103,15 @@
"def random_graph(n):\n",
" edge_list = []\n",
" adjacencies = [set() for _ in range(n+1)]\n",
- " indices = [random.randint(1, n) for _ in range(AVG_DEG * (n-1))]\n",
+ " indices = [random.randint(1, n) for _ in range(AVG_DEG * n)]\n",
" for i in range(0, len(indices), 2):\n",
" u = indices[i]\n",
" v = indices[i + 1]\n",
" if u != v:\n",
- " edge_list += [u,v]\n",
+ " edge_list += [min(u,v),max(u,v)]\n",
" adjacencies[u].add(v)\n",
" adjacencies[v].add(u)\n",
- "\n",
- " edge_list += [PAD_TOKEN]*(SEQ_LEN-len(edge_list))\n",
+ " edge_list += [PAD_TOKEN]*(2*SEQ_LEN-1-len(edge_list))\n",
" return edge_list, adjacencies\n",
"\n",
"\"\"\"\n",
@@ -200,27 +200,71 @@
},
{
"cell_type": "code",
- "execution_count": 86,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "(tensor([[4, 3, 2, 3, 1, 2, 0, 0, 2],\n",
- " [3, 4, 3, 1, 2, 3, 0, 0, 2],\n",
- " [4, 2, 3, 2, 0, 0, 0, 0, 2],\n",
- " [3, 4, 4, 3, 0, 0, 0, 0, 2],\n",
- " [1, 2, 2, 4, 0, 0, 0, 0, 2]], device='cuda:0'),\n",
- " tensor([1., 2., 4., 4., 1.], device='cuda:0', dtype=torch.bfloat16),\n",
- " tensor([[False, False, False, False, False, False, True, True, False],\n",
- " [False, False, False, False, False, False, True, True, False],\n",
- " [False, False, False, False, True, True, True, True, False],\n",
- " [False, False, False, False, True, True, True, True, False],\n",
- " [False, False, False, False, True, True, True, True, False]],\n",
- " device='cuda:0'))"
+ "(tensor([[ 1, 4, 8, 9, 5, 8, 4, 22, 3, 18, 14, 19, 1, 2, 3, 7, 8, 17,\n",
+ " 1, 20, 7, 18, 21, 23, 18, 23, 8, 14, 15, 19, 1, 9, 6, 23, 11, 14,\n",
+ " 5, 9, 7, 11, 3, 4, 4, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
+ " [ 2, 9, 15, 18, 4, 13, 3, 18, 10, 21, 12, 20, 7, 19, 2, 3, 8, 22,\n",
+ " 3, 10, 4, 8, 9, 13, 15, 21, 6, 12, 7, 22, 9, 22, 3, 21, 20, 21,\n",
+ " 6, 18, 6, 8, 13, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
+ " [ 9, 11, 4, 11, 1, 6, 1, 4, 6, 7, 2, 5, 4, 10, 4, 6, 8, 11,\n",
+ " 7, 11, 3, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
+ " [ 3, 4, 5, 10, 7, 10, 6, 7, 3, 4, 8, 9, 1, 2, 2, 3, 3, 11,\n",
+ " 7, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
+ " [ 3, 13, 13, 20, 15, 17, 9, 18, 1, 22, 4, 22, 9, 18, 11, 21, 4, 10,\n",
+ " 6, 14, 1, 15, 9, 17, 6, 17, 4, 21, 10, 21, 17, 20, 5, 7, 6, 12,\n",
+ " 17, 18, 1, 20, 11, 16, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 2]], device='cuda:0'),\n",
+ " tensor([ 1., 22., 11., 1., 22.], device='cuda:0', dtype=torch.bfloat16),\n",
+ " tensor([[False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, False],\n",
+ " [False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, True, True, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, False],\n",
+ " [False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, True, True, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, False],\n",
+ " [False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, False],\n",
+ " [False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, True, True, True, True, True, True,\n",
+ " True, True, True, True, True, True, True, True, True, True,\n",
+ " True, True, False]], device='cuda:0'))"
]
},
- "execution_count": 86,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -231,41 +275,39 @@
},
{
"cell_type": "code",
- "execution_count": 87,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "(array([ 663., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 284., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 39., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 1062.]),\n",
- " array([1. , 1.046875, 1.09375 , 1.140625, 1.1875 , 1.234375,\n",
- " 1.28125 , 1.328125, 1.375 , 1.421875, 1.46875 , 1.515625,\n",
- " 1.5625 , 1.609375, 1.65625 , 1.703125, 1.75 , 1.796875,\n",
- " 1.84375 , 1.890625, 1.9375 , 1.984375, 2.03125 , 2.078125,\n",
- " 2.125 , 2.171875, 2.21875 , 2.265625, 2.3125 , 2.359375,\n",
- " 2.40625 , 2.453125, 2.5 , 2.546875, 2.59375 , 2.640625,\n",
- " 2.6875 , 2.734375, 2.78125 , 2.828125, 2.875 , 2.921875,\n",
- " 2.96875 , 3.015625, 3.0625 , 3.109375, 3.15625 , 3.203125,\n",
- " 3.25 , 3.296875, 3.34375 , 3.390625, 3.4375 , 3.484375,\n",
- " 3.53125 , 3.578125, 3.625 , 3.671875, 3.71875 , 3.765625,\n",
- " 3.8125 , 3.859375, 3.90625 , 3.953125, 4. ]),\n",
+ "(array([320., 0., 310., 0., 264., 0., 178., 0., 119., 0., 90.,\n",
+ " 0., 69., 0., 42., 0., 0., 39., 0., 30., 0., 30.,\n",
+ " 0., 31., 0., 26., 0., 28., 0., 27., 0., 0., 20.,\n",
+ " 0., 30., 0., 29., 0., 26., 0., 20., 0., 30., 0.,\n",
+ " 34., 0., 23., 0., 0., 35., 0., 29., 0., 17., 0.,\n",
+ " 26., 0., 33., 0., 35., 0., 32., 0., 26.]),\n",
+ " array([ 1. , 1.46875, 1.9375 , 2.40625, 2.875 , 3.34375,\n",
+ " 3.8125 , 4.28125, 4.75 , 5.21875, 5.6875 , 6.15625,\n",
+ " 6.625 , 7.09375, 7.5625 , 8.03125, 8.5 , 8.96875,\n",
+ " 9.4375 , 9.90625, 10.375 , 10.84375, 11.3125 , 11.78125,\n",
+ " 12.25 , 12.71875, 13.1875 , 13.65625, 14.125 , 14.59375,\n",
+ " 15.0625 , 15.53125, 16. , 16.46875, 16.9375 , 17.40625,\n",
+ " 17.875 , 18.34375, 18.8125 , 19.28125, 19.75 , 20.21875,\n",
+ " 20.6875 , 21.15625, 21.625 , 22.09375, 22.5625 , 23.03125,\n",
+ " 23.5 , 23.96875, 24.4375 , 24.90625, 25.375 , 25.84375,\n",
+ " 26.3125 , 26.78125, 27.25 , 27.71875, 28.1875 , 28.65625,\n",
+ " 29.125 , 29.59375, 30.0625 , 30.53125, 31. ]),\n",
" <BarContainer object of 64 artists>)"
]
},
- "execution_count": 87,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
@@ -280,40 +322,40 @@
},
{
"cell_type": "code",
- "execution_count": 90,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "(array([1714., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 305., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+ "(array([1289., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+ " 0., 477., 0., 0., 0., 0., 0., 0., 0.,\n",
+ " 0., 0., 0., 192., 0., 0., 0., 0., 0.,\n",
+ " 0., 0., 0., 0., 0., 55., 0., 0., 0.,\n",
+ " 0., 0., 0., 0., 0., 0., 28., 0., 0.,\n",
+ " 0., 0., 0., 0., 0., 0., 0., 0., 5.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 29.]),\n",
- " array([1. , 1.03125, 1.0625 , 1.09375, 1.125 , 1.15625, 1.1875 ,\n",
- " 1.21875, 1.25 , 1.28125, 1.3125 , 1.34375, 1.375 , 1.40625,\n",
- " 1.4375 , 1.46875, 1.5 , 1.53125, 1.5625 , 1.59375, 1.625 ,\n",
- " 1.65625, 1.6875 , 1.71875, 1.75 , 1.78125, 1.8125 , 1.84375,\n",
- " 1.875 , 1.90625, 1.9375 , 1.96875, 2. , 2.03125, 2.0625 ,\n",
- " 2.09375, 2.125 , 2.15625, 2.1875 , 2.21875, 2.25 , 2.28125,\n",
- " 2.3125 , 2.34375, 2.375 , 2.40625, 2.4375 , 2.46875, 2.5 ,\n",
- " 2.53125, 2.5625 , 2.59375, 2.625 , 2.65625, 2.6875 , 2.71875,\n",
- " 2.75 , 2.78125, 2.8125 , 2.84375, 2.875 , 2.90625, 2.9375 ,\n",
- " 2.96875, 3. ]),\n",
+ " 2.]),\n",
+ " array([1. , 1.09375, 1.1875 , 1.28125, 1.375 , 1.46875, 1.5625 ,\n",
+ " 1.65625, 1.75 , 1.84375, 1.9375 , 2.03125, 2.125 , 2.21875,\n",
+ " 2.3125 , 2.40625, 2.5 , 2.59375, 2.6875 , 2.78125, 2.875 ,\n",
+ " 2.96875, 3.0625 , 3.15625, 3.25 , 3.34375, 3.4375 , 3.53125,\n",
+ " 3.625 , 3.71875, 3.8125 , 3.90625, 4. , 4.09375, 4.1875 ,\n",
+ " 4.28125, 4.375 , 4.46875, 4.5625 , 4.65625, 4.75 , 4.84375,\n",
+ " 4.9375 , 5.03125, 5.125 , 5.21875, 5.3125 , 5.40625, 5.5 ,\n",
+ " 5.59375, 5.6875 , 5.78125, 5.875 , 5.96875, 6.0625 , 6.15625,\n",
+ " 6.25 , 6.34375, 6.4375 , 6.53125, 6.625 , 6.71875, 6.8125 ,\n",
+ " 6.90625, 7. ]),\n",
" <BarContainer object of 64 artists>)"
]
},
- "execution_count": 90,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
@@ -328,21 +370,21 @@
},
{
"cell_type": "code",
- "execution_count": 97,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
- "import pickle\n",
+ "# import pickle\n",
"\n",
- "graphs, labels, _ = mkbatch(3*10**5)\n",
+ "# graphs, labels, _ = mkbatch(3*10**5)\n",
"\n",
- "data = {\n",
- " \"data\": graphs,\n",
- " \"labels\": labels\n",
- "}\n",
+ "# data = {\n",
+ "# \"data\": graphs,\n",
+ "# \"labels\": labels\n",
+ "# }\n",
"\n",
- "with open('data.pkl', 'wb') as file:\n",
- " pickle.dump(data, file)"
+ "# with open('data.pkl', 'wb') as file:\n",
+ "# pickle.dump(data, file)"
]
},
{
@@ -356,43 +398,34 @@
},
{
"cell_type": "code",
- "execution_count": 135,
+ "execution_count": 10,
"metadata": {
"id": "tLOWhg_CeWzH"
},
"outputs": [],
"source": [
"class TransformerModel(nn.Module):\n",
- " def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, device, dropout):\n",
+ " def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, dropout):\n",
" super().__init__()\n",
- " self.embedding = nn.Embedding(input_dim, model_dim//2, dtype=torch.bfloat16)\n",
- " # seq_len is odd\n",
- " self.fancy_encoding = torch.repeat_interleave(torch.rand((1, seq_len // 2 + 1, model_dim // 2), device=device, dtype=torch.bfloat16), 2, dim=1)\n",
- " # cut off last element since the target vertex is not repeated\n",
- " self.fancy_encoding = self.fancy_encoding[:, :seq_len, :]\n",
- " \n",
" self.model_dim = model_dim\n",
- " self.seq_len = seq_len\n",
- " self.device = device\n",
- "\n",
+ " self.embedding = nn.Embedding(input_dim, model_dim // 2, dtype=torch.bfloat16)\n",
+ " # # seq_len is odd\n",
+ " # self.fancy_encoding = torch.repeat_interleave(torch.rand((1, seq_len // 2 + 1, model_dim // 2), device=device, dtype=torch.bfloat16), 2, dim=1)\n",
+ " # # cut off last element since the target vertex is not repeated\n",
+ " # self.fancy_encoding = self.fancy_encoding[:, :seq_len, :]\n",
" encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,\n",
" dim_feedforward=model_dim*4,\n",
" dropout=dropout, batch_first=True, dtype=torch.bfloat16)\n",
" self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n",
+ " self.fc_out = nn.Linear(model_dim, output_dim, dtype=torch.bfloat16)\n",
"\n",
- " self.fc_out = nn.Linear(model_dim*seq_len, output_dim, dtype=torch.bfloat16)\n",
- "\n",
- " def full_embedding(self, src):\n",
- " batch_size, src_len = src.size(0), src.size(1)\n",
- " return torch.cat((self.embedding(src) * sqrt(self.model_dim), self.fancy_encoding.repeat((batch_size, 1, 1))), dim=2)\n",
- " \n",
" def forward(self, src, key_padding_mask):\n",
- " embed = self.full_embedding(src)\n",
- " output = self.transformer_encoder(embed, src_key_padding_mask=key_padding_mask)\n",
- " output[key_padding_mask] = 0 # Hack to stop no_grad problem\n",
- " flat_output = torch.flatten(output, start_dim=1, end_dim=2)\n",
- " output = self.fc_out(flat_output)\n",
- " return output"
+ " batch_sz = src.size(0)\n",
+ " embed = torch.cat((self.embedding(src[:,:-1:2]), self.embedding(src[:,1::2])), dim=2)\n",
+ " last_dude = torch.cat((self.embedding(src[:,-1:]), torch.ones((batch_sz, 1, self.model_dim // 2), dtype=torch.bfloat16, device=device)), dim=2)\n",
+ " final_embed = torch.cat((embed, last_dude), dim=1)\n",
+ " output = self.transformer_encoder(final_embed, src_key_padding_mask=key_padding_mask[:, ::2])\n",
+ " return self.fc_out(output[:, -1, :])"
]
},
{
@@ -406,7 +439,7 @@
},
{
"cell_type": "code",
- "execution_count": 136,
+ "execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -419,8 +452,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Training data: 1049M\n",
- "Trainable parameters in the model: 50K\n"
+ "Training data: 524M\n",
+ "Trainable parameters in the model: 800K\n"
]
}
],
@@ -429,17 +462,17 @@
"VOCAB_SIZE = 1 + MAX_VTXS # one more than the max number of vertices\n",
"MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n",
"NEPOCHS = 1000\n",
- "BSZ = 8196 * 4 # Batch size\n",
+ "BSZ = 2**14 # Batch size\n",
"BPE = 32 # Batches per epoch\n",
- "LR = 5e-3\n",
- "WD = 2e-3\n",
- "NHEADS = 1 #4\n",
- "NLAYERS = 1 #16\n",
- "DROPOUT = 0.2\n",
+ "LR = 5e-5\n",
+ "WD = 1e-5\n",
+ "NHEADS = 4\n",
+ "NLAYERS = 16\n",
+ "DROPOUT = 0 # 0.2\n",
"model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
" output_dim=1, num_heads=NHEADS,\n",
" num_layers=NLAYERS, seq_len=SEQ_LEN,\n",
- " dropout=DROPOUT, device=device).to(device)\n",
+ " dropout=DROPOUT).to(device)\n",
"# model = torch.compile(model)\n",
"\n",
"criterion = nn.MSELoss()\n",
@@ -452,21 +485,38 @@
},
{
"cell_type": "code",
- "execution_count": 137,
+ "execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
- "from torch.utils.data import DataLoader, TensorDataset\n",
+ "# from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
- "with open(\"data.pkl\", \"rb\") as f:\n",
- " pickled_stuff = pickle.load(f)\n",
+ "# with open(\"data.pkl\", \"rb\") as f:\n",
+ "# pickled_stuff = pickle.load(f)\n",
"\n",
- "data = pickled_stuff[\"data\"].to(device)\n",
- "label = pickled_stuff[\"labels\"].to(device)\n",
- "padding_mask = (data == PAD_TOKEN).bool().to(device)\n",
- "dataset = TensorDataset(data, label, padding_mask)\n",
- "# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [.9, .1])\n",
- "train_loader = DataLoader(dataset, batch_size=BSZ, shuffle=True)"
+ "# data = pickled_stuff[\"data\"].to(device)\n",
+ "# label = pickled_stuff[\"labels\"].to(device)\n",
+ "# padding_mask = (data == PAD_TOKEN).bool().to(device)\n",
+ "# dataset = TensorDataset(data, label, padding_mask)\n",
+ "# # train_dataset, test_dataset = torch.utils.data.random_split(dataset, [.9, .1])\n",
+ "# train_loader = DataLoader(dataset, batch_size=BSZ, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# measure the loss on only the examples whose label is 1, to see whether the model does better on short paths\n",
+ "def evaluate_short():\n",
+ " model.eval()\n",
+ " test_loss = 0\n",
+ " with torch.no_grad():\n",
+ " batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
+ " output = model(batch_src, batch_padding_mask)\n",
+ " loss = criterion(output[batch_labels == 1].squeeze(1), batch_labels[batch_labels==1])\n",
+ " return loss.item()"
]
},
{
@@ -480,7 +530,7 @@
},
{
"cell_type": "code",
- "execution_count": 138,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -496,7 +546,7 @@
},
{
"cell_type": "code",
- "execution_count": 139,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -510,580 +560,119 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.01it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 1/1000 \t Train Err: 0.5381 \t Test Err: 0.1865\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.75it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 2/1000 \t Train Err: 0.1227 \t Test Err: 0.1128\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.03it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 3/1000 \t Train Err: 0.1071 \t Test Err: 0.1118\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.01it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 4/1000 \t Train Err: 0.1008 \t Test Err: 0.1035\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.75it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 5/1000 \t Train Err: 0.0972 \t Test Err: 0.1021\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.03it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 6/1000 \t Train Err: 0.0949 \t Test Err: 0.0981\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.76it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 7/1000 \t Train Err: 0.0929 \t Test Err: 0.1021\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.00it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 8/1000 \t Train Err: 0.0908 \t Test Err: 0.0977\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.01it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 9/1000 \t Train Err: 0.0886 \t Test Err: 0.0952\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.72it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 10/1000 \t Train Err: 0.0910 \t Test Err: 0.0962\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.02it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 11/1000 \t Train Err: 0.0851 \t Test Err: 0.0898\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.03it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 12/1000 \t Train Err: 0.0849 \t Test Err: 0.0864\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.74it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 13/1000 \t Train Err: 0.0795 \t Test Err: 0.0684\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.02it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 14/1000 \t Train Err: 0.0691 \t Test Err: 0.0293\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.74it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 15/1000 \t Train Err: 0.0455 \t Test Err: 0.0271\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 6.99it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 16/1000 \t Train Err: 0.0421 \t Test Err: 0.0210\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 6.99it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 17/1000 \t Train Err: 0.0817 \t Test Err: 0.0505\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.75it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 18/1000 \t Train Err: 0.0456 \t Test Err: 0.0176\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.03it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 19/1000 \t Train Err: 0.0370 \t Test Err: 0.0165\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.03it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 20/1000 \t Train Err: 0.0374 \t Test Err: 0.0205\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.73it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 21/1000 \t Train Err: 0.0372 \t Test Err: 0.0142\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.01it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 22/1000 \t Train Err: 0.0343 \t Test Err: 0.0132\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.75it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 23/1000 \t Train Err: 0.0337 \t Test Err: 0.0119\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.03it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 24/1000 \t Train Err: 0.0713 \t Test Err: 0.0259\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.04it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 25/1000 \t Train Err: 0.0522 \t Test Err: 0.0143\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.75it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 26/1000 \t Train Err: 0.0342 \t Test Err: 0.0117\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.02it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 27/1000 \t Train Err: 0.0864 \t Test Err: 0.0728\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.04it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 28/1000 \t Train Err: 0.0701 \t Test Err: 0.0510\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.72it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 29/1000 \t Train Err: 0.0598 \t Test Err: 0.0369\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 6.99it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 30/1000 \t Train Err: 0.0462 \t Test Err: 0.0231\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.73it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 31/1000 \t Train Err: 0.0387 \t Test Err: 0.0181\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.02it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 32/1000 \t Train Err: 0.0351 \t Test Err: 0.0142\n"
+ "Epoch 1/1000 \t Train Err: 149.6562 \t Test Err: 137.0000, Test short loss: 0.3164\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.04it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.41it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 33/1000 \t Train Err: 0.0337 \t Test Err: 0.0123\n"
+ "Epoch 2/1000 \t Train Err: 129.4688 \t Test Err: 120.0000, Test short loss: 0.4512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.74it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.39it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 34/1000 \t Train Err: 0.0331 \t Test Err: 0.0117\n"
+ "Epoch 3/1000 \t Train Err: 116.1719 \t Test Err: 110.5000, Test short loss: 2.2500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.02it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.38it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 35/1000 \t Train Err: 0.0329 \t Test Err: 0.0119\n"
+ "Epoch 4/1000 \t Train Err: 107.9375 \t Test Err: 104.0000, Test short loss: 4.0625\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 6.98it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 36/1000 \t Train Err: 0.0327 \t Test Err: 0.0109\n"
+ "Epoch 5/1000 \t Train Err: 103.0938 \t Test Err: 102.0000, Test short loss: 5.5625\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.74it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.38it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 37/1000 \t Train Err: 0.0319 \t Test Err: 0.0101\n"
+ "Epoch 6/1000 \t Train Err: 99.9531 \t Test Err: 99.0000, Test short loss: 6.6250\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.02it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.41it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 38/1000 \t Train Err: 0.0317 \t Test Err: 0.0099\n"
+ "Epoch 7/1000 \t Train Err: 99.0469 \t Test Err: 101.0000, Test short loss: 7.2188\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.71it/s]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00, 2.39it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 39/1000 \t Train Err: 0.0313 \t Test Err: 0.0095\n"
+ "Epoch 8/1000 \t Train Err: 97.8594 \t Test Err: 97.5000, Test short loss: 7.6250\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00, 7.01it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 40/1000 \t Train Err: 0.0310 \t Test Err: 0.0097\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " 28%|██████████████████████████████████████████▊ | 9/32 [00:01<00:03, 6.50it/s]\n"
- ]
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[139], line 11\u001b[0m\n\u001b[1;32m 9\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(BPE)):\n\u001b[0;32m---> 11\u001b[0m batch_src, batch_labels, batch_padding_mask \u001b[38;5;241m=\u001b[39m \u001b[43mmkbatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mBSZ\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# for batch_src, batch_labels, batch_padding_mask in tqdm(train_loader):\u001b[39;00m\n\u001b[1;32m 13\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
- "Cell \u001b[0;32mIn[76], line 55\u001b[0m, in \u001b[0;36mmkbatch\u001b[0;34m(size)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(size):\n\u001b[1;32m 54\u001b[0m n \u001b[38;5;241m=\u001b[39m random\u001b[38;5;241m.\u001b[39mrandint(MIN_VTXS, MAX_VTXS)\n\u001b[0;32m---> 55\u001b[0m edge_list, adj_list \u001b[38;5;241m=\u001b[39m \u001b[43mrandom_graph\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m dist \u001b[38;5;241m=\u001b[39m SSSP(n, adj_list)\n\u001b[1;32m 57\u001b[0m edge_list[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;66;03m# target token\u001b[39;00m\n",
- "Cell \u001b[0;32mIn[76], line 15\u001b[0m, in \u001b[0;36mrandom_graph\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 13\u001b[0m adjacencies \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mset\u001b[39m() \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m)]\n\u001b[1;32m 14\u001b[0m indices \u001b[38;5;241m=\u001b[39m [random\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m1\u001b[39m, n) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(AVG_DEG \u001b[38;5;241m*\u001b[39m (n\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m))]\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mindices\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 16\u001b[0m u \u001b[38;5;241m=\u001b[39m indices[i]\n\u001b[1;32m 17\u001b[0m v \u001b[38;5;241m=\u001b[39m indices[i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m]\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ " 44%|██████████████████████████████████████████████████████████████████ | 14/32 [00:05<00:07, 2.41it/s]"
]
}
],
@@ -1108,12 +697,13 @@
" optimizer.step()\n",
"\n",
" test_loss = evaluate()\n",
+ " test_short_loss = evaluate_short()\n",
" \n",
" test_err.append(test_loss)\n",
" train_err.append(train_loss)\n",
" with open('loss', 'a') as f:\n",
" f.write(f\"{train_loss} {test_loss}\\n\")\n",
- " print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
+ " print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}, Test short loss: {test_short_loss:.4f}\")\n",
" \n",
" if epoch % 100 == 99:\n",
" torch.save(model.state_dict(), f\"model_weights_{epoch}.pth\")"
@@ -1121,55 +711,38 @@
},
{
"cell_type": "code",
- "execution_count": 163,
+ "execution_count": 125,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(torch.Size([1, 7, 64]), torch.Size([64, 64]))"
- ]
- },
- "execution_count": 163,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "\"\"\"\n",
- "Now let's figure out what it's doing. \n",
+ "# \"\"\"\n",
+ "# Now let's figure out what it's doing. \n",
"\n",
- "step 1: figure out what people are attending to \n",
- "\"\"\"\n",
+ "# step 1: figure out what people are attending to \n",
+ "# \"\"\"\n",
+ "\n",
+ "# example_graph, answer, padding = mkbatch(1)\n",
+ "# sentance_embeddings = model.full_embedding(example_graph)[0,:,:][example_graph.flatten() != 0]\n",
+ "# WQ,WK,WV = torch.split(model.transformer_encoder.layers[0].self_attn.in_proj_weight, (MODEL_DIM, MODEL_DIM, MODEL_DIM))\n",
"\n",
- "example_graph, answer, padding = mkbatch(1)\n",
- "sentance_embeddings = model.full_embedding(example_graph)\n",
- "Q,K,V = torch.split(model.transformer_encoder.layers[0].self_attn.in_proj_weight, (MODEL_DIM, MODEL_DIM, MODEL_DIM))\n",
+ "# Q = sentance_embeddings@WQ\n",
+ "# K = sentance_embeddings@WK\n",
"\n",
- "sentance_embeddings.shape, Q.shape\n",
- "sentance_embeddings@Q.T\n",
+ "# raw_scores = Q @ K.T / sqrt(MODEL_DIM)\n",
+ "# soft = torch.softmax(raw_scores, dim=-1).detach().cpu().to(float).numpy()\n",
+ "# plt.imshow(soft)\n",
+ "# plt.show()\n",
"\n",
- "# (sentance_embeddings @ Q).shape\n",
- "# sentance_embeddings.shape\n",
- "# K @ sentance_embeddings"
+ "# print(example_graph)\n",
+ "\n",
+ "# print(Q)"
]
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x702d2d2eed20>"
- ]
- },
- "execution_count": 34,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"plt.suptitle('MSE vs Epochs')\n",
"plt.plot(train_err, label='Train', color='blue')\n",
@@ -1182,83 +755,17 @@
},
{
"cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "24.625"
- ]
- },
- "execution_count": 35,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "evaluate()"
- ]
- },
- {
- "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LoGEmM5lH7_A"
},
"outputs": [],
"source": [
- "batch_src, batch_labels, batch_padding_mask = next(iter(train_loader))\n",
+ "batch_src, batch_labels, batch_padding_mask = mkbatch(4096)\n",
"output = model(batch_src, batch_padding_mask)\n",
- "batch_src[0], batch_labels[0], output[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.hist(output.detach().cpu().numpy().flatten(),bins=32)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.hist(label.detach().cpu().numpy().flatten(),bins=32)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.scatter(batch_labels.detach().cpu().numpy().flatten(),output.detach().cpu().numpy().flatten())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "batch_src2, batch_labels2, batch_padding_mask2 = next(iter(test_loader))\n",
- "output2 = model(batch_src2, batch_padding_mask2)\n",
- "loss = criterion(output2.squeeze(1), batch_labels2)\n",
- "batch_src2[0], batch_labels2[0], output2[0], loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.scatter(batch_labels2.detach().cpu().numpy().flatten(),output2.detach().cpu().numpy().flatten())"
+ "batch_src[0], batch_labels[0], output[0]\n",
+ "plt.scatter(batch_labels.detach().to(torch.float16).cpu().numpy().flatten(), output.detach().to(torch.float16).cpu().numpy().flatten())\n",
+ "plt.show()"
]
},
{
@@ -1277,8 +784,8 @@
"outputs": [],
"source": [
"N_TUNE_EPOCHS = 100\n",
- "TUNE_LR = 0.003\n",
- "TUNE_WD = 0.002\n",
+ "TUNE_LR = 1e-5\n",
+ "TUNE_WD = 1e-5\n",
"\n",
"tune_criterion = nn.MSELoss()\n",
"tune_optimizer = torch.optim.Adam(model.parameters(), lr=TUNE_LR, weight_decay=TUNE_WD)"
@@ -1306,32 +813,6 @@
"metadata": {},
"outputs": [],
"source": [
- "# This has to be in a separate cell for some weird event loop reasons\n",
- "%matplotlib widget\n",
- "fig,ax = plt.subplots()\n",
- "fig.suptitle('MSE vs Epochs')\n",
- "plt.show()\n",
- "%matplotlib inline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "ename": "NameError",
- "evalue": "name 'N_TUNE_EPOCHS' is not defined",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[13], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m tune_train_err \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 2\u001b[0m tune_test_err \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[43mN_TUNE_EPOCHS\u001b[49m):\n\u001b[1;32m 5\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m 6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
- "\u001b[0;31mNameError\u001b[0m: name 'N_TUNE_EPOCHS' is not defined"
- ]
- }
- ],
- "source": [
"tune_train_err = []\n",
"tune_test_err = []\n",
"\n",
@@ -1346,7 +827,7 @@
" optimizer.zero_grad()\n",
" output = model(batch_src, batch_padding_mask)\n",
" loss = criterion(output.squeeze(1), batch_labels)\n",
- " train_loss += loss.item()/BPE\n",
+ " train_loss += loss.item() / BPE\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
@@ -1356,11 +837,6 @@
" tune_train_err.append(train_loss)\n",
" with open('tune_loss', 'a') as f:\n",
" f.write(f\"{train_loss} {test_loss}\\n\")\n",
- " ax.plot(tune_train_err, label='Train', color='blue')\n",
- " ax.plot(tune_test_err, label='Test', color='red')\n",
- " ax.set_xlabel('Epochs')\n",
- " ax.set_ylabel('MSE')\n",
- " fig.canvas.draw()\n",
" print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
"\n",
" if epoch % 10 == 9:\n",
@@ -1368,6 +844,33 @@
]
},
{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.suptitle('MSE vs Epochs')\n",
+ "plt.plot(tune_train_err, label='Train', color='blue')\n",
+ "plt.plot(tune_test_err, label='Test', color='red')\n",
+ "plt.xlabel('Epochs')\n",
+ "plt.ylabel('MSE')\n",
+ "plt.legend()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_src, batch_labels, batch_padding_mask = mktunebatch(2048)\n",
+ "output = model(batch_src, batch_padding_mask)\n",
+ "batch_src[0], batch_labels[0], output[0]\n",
+ "plt.scatter(batch_labels.detach().to(torch.float16).cpu().numpy().flatten(), output.detach().to(torch.float16).cpu().numpy().flatten())"
+ ]
+ },
+ {
"cell_type": "markdown",
"metadata": {
"id": "JtTLXn4zC1z_"