author     Kevin Zhao  2024-05-05 10:33:51 -0400
committer  Kevin Zhao  2024-05-05 10:33:51 -0400
commit     3d4862e725d29ab910f4a958717749abe12aad7a
tree       34b10255e05e48c05360c5cb7b9daa35b214b65e
parent     49e4564566bbfd8435f390bba602eb97cc214502
Add CNN arch and training code
-rw-r--r--   .gitignore                           160
-rw-r--r--   corner_training/coarse_training.py   126
-rw-r--r--   corner_training/fine_training.py     127
-rw-r--r--   corner_training/models.py            153
-rw-r--r--   corner_training/utils.py             224
5 files changed, 790 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c883aa9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# Custom
+.DS_Store
+.idea/
+Report_Stuff/
+data/
+*.mkv
diff --git a/corner_training/coarse_training.py b/corner_training/coarse_training.py
new file mode 100644
index 0000000..83f2f27
--- /dev/null
+++ b/corner_training/coarse_training.py
@@ -0,0 +1,126 @@
+import argparse
+import json
+import math
+import os
+import random
+import re
+
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+
+import torchvision
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as transforms_f
+# import torchvision.transforms.v2 as transforms
+# import torchvision.transforms.v2.functional as transforms_f
+from tqdm.auto import tqdm
+
+from models import QuantizedV2
+from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("data_root_path")
+ parser.add_argument("output_dir")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=False)
+
+ np.random.seed(42)
+ torch.manual_seed(42)
+ random.seed(42)
+
+ lr = 3e-2 # TODO: argparse
+ log_steps = 10
+ train_batch_size = 512
+ num_train_epochs = 10
+
+ # leave 1 shard for eval
+ shard_paths = sorted([path for path in os.listdir(args.data_root_path)
+ if re.fullmatch(r"shard_\d+", path) is not None])
+
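+    # Ground-truth targets: a Gaussian bump with sigma=4 centered on each corner point
+    # (see get_gtruth_wrapper in utils.py).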
+ get_gtruth = get_gtruth_wrapper(4)
+
+ train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1])
+ eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ # train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True,
+ train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=False
+ )
+
+ model = QuantizedV2()
+
+ print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+ scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)
+
+ # loss_func = nn.MSELoss()
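+    # Weighted MSE over the target heatmap: pixels inside a Gaussian bump get weight 100,
+    # background pixels 0.1, so the sparse corner targets are not swamped by the empty background.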
+ def loss_func(output, labels):
+ # weights = (labels != 0) * 49.9 + 0.1 # arbitrary nums (good DeepStage1v0 64x64)
+ weights = (labels != 0) * 99.9 + 0.1 # 4 all pts
+ # weights = (labels != 0) * 199.9 + 0.1 # 1 pt 1 layer
+ return (weights * (output - labels) ** 2).mean()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
+ cum_loss = 0
+ model.train()
+ for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
+ output = model(batch["imgs"].to(device))
+ loss = loss_func(output, batch["labels"].to(device))
+ loss.backward()
+            optimizer.step()
+            scheduler.step()  # advance the Noam warmup/decay schedule
+            optimizer.zero_grad()
+
+ cum_loss += loss.item()
+ if (i + 1) % log_steps == 0:
+ print(f"{cum_loss / log_steps=}")
+ cum_loss = 0
+
+        if (i + 1) % log_steps != 0:  # log any leftover partial window
+            print(f"{cum_loss / ((i + 1) % log_steps)=}")
+
+ if epoch > 3:
+ # Freeze quantizer parameters
+ model.apply(torch.ao.quantization.disable_observer)
+ if epoch > 2:
+ # Freeze batch norm mean and variance estimates
+ model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+
+ # model.eval()
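+    # convert(inplace=False) returns a separate copy for CPU evaluation;
+    # the float model is moved back to `device` below to keep training.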
+ quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
+ model.to(device)
+ quantized_model.eval() # not sure if this is needed
+ eval_device = "cpu"
+ with torch.no_grad():
+ num_eval_exs = 5
+ fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
+
+ eval_loss = 0
+ for i in range(num_eval_exs):
+ eval_ex = eval_dataset[i] # slicing doesn't work yet...
+ axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
+ preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
+ eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
+ axs[i, 1].imshow(eval_ex["labels"])
+ axs[i, 1].set_title("Ground Truth")
+ axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
+ axs[i, 2].set_title("Prediction")
+ for ax in axs[i]:
+ ax.axis("off")
+
+ print(f"{eval_loss / num_eval_exs=}")
+ plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
+ torch.save(quantized_model.state_dict(),
+ os.path.join(args.output_dir, f"QuantizedV3_Stage1_128_{epoch}.pt"))
diff --git a/corner_training/fine_training.py b/corner_training/fine_training.py
new file mode 100644
index 0000000..6acd1f6
--- /dev/null
+++ b/corner_training/fine_training.py
@@ -0,0 +1,127 @@
+import argparse
+import json
+import math
+import os
+import random
+import re
+
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+
+import torchvision
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as transforms_f
+# import torchvision.transforms.v2 as transforms
+# import torchvision.transforms.v2.functional as transforms_f
+from tqdm.auto import tqdm
+
+from models import QuantizedV3, QuantizedV2
+from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("data_root_path")
+ parser.add_argument("output_dir")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=False)
+
+ np.random.seed(42)
+ torch.manual_seed(42)
+ random.seed(42)
+
+ lr = 3e-2 # TODO: argparse
+ log_steps = 10
+ train_batch_size = 512
+ num_train_epochs = 10
+
+ # leave 1 shard for eval
+ shard_paths = sorted([path for path in os.listdir(args.data_root_path)
+ if re.fullmatch(r"shard_\d+", path) is not None])
+
+ get_gtruth = get_gtruth_wrapper(4)
+ train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1], eval=True) # note: not flipping
+ eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True
+ )
+
+ # torch.backends.quantized.engine = "qnnpack"
+ model = QuantizedV2()
+ # model.fuse_modules(is_qat=True) # Added for QAT
+
+ print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+ scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)
+
+ # model.qconfig = torch.ao.quantization.default_qconfig # Added for QAT
+ # torch.ao.quantization.prepare_qat(model, inplace=True) # Added for QAT
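+    # Note: with the QAT preparation above left commented out, convert() below
+    # just returns an unquantized copy of the float model.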
+
+ def loss_func(output, labels):
+ weights = (labels != 0) * 49.9 + 0.1
+ # weights = (labels != 0) * 99.9 + 0.1
+ return (weights * (output - labels) ** 2).mean()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
+ cum_loss = 0
+ model.train()
+ for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
+ output = model(batch["imgs"].to(device))
+ loss = loss_func(output, batch["labels"].to(device))
+ loss.backward()
+            optimizer.step()
+            scheduler.step()  # advance the Noam warmup/decay schedule
+            optimizer.zero_grad()
+
+ cum_loss += loss.item()
+ if (i + 1) % log_steps == 0:
+ print(f"{cum_loss / log_steps=}")
+ cum_loss = 0
+
+        if (i + 1) % log_steps != 0:  # log any leftover partial window
+            print(f"{cum_loss / ((i + 1) % log_steps)=}")
+
+ if epoch > 3:
+ # Freeze quantizer parameters
+ model.apply(torch.ao.quantization.disable_observer)
+ if epoch > 2:
+ # Freeze batch norm mean and variance estimates
+ model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+
+ # model.eval()
+ quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
+ model.to(device)
+ quantized_model.eval() # not sure if this is needed
+ eval_device = "cpu"
+ with torch.no_grad():
+ num_eval_exs = 5
+ fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
+
+ eval_loss = 0
+ for i in range(num_eval_exs):
+ eval_ex = eval_dataset[i] # slicing doesn't work yet...
+ axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
+ preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
+ eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
+ axs[i, 1].imshow(eval_ex["labels"])
+ axs[i, 1].set_title("Ground Truth")
+ axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
+ axs[i, 2].set_title("Prediction")
+ for ax in axs[i]:
+ ax.axis("off")
+
+ print(f"{eval_loss / num_eval_exs=}")
+ plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
+ torch.save(quantized_model.state_dict(),
+ os.path.join(args.output_dir, f"QuantizedV3_Stage2_128_{epoch}.pt"))
diff --git a/corner_training/models.py b/corner_training/models.py
new file mode 100644
index 0000000..dff0ac2
--- /dev/null
+++ b/corner_training/models.py
@@ -0,0 +1,153 @@
+import torch
+import torch.nn as nn
+
+
+class UNetStage1v1(nn.Module):
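+    """Small U-Net: two 2x-downsampling encoder blocks, two skip-connected decoder blocks,
+    then a channel-wise max and sigmoid producing a per-pixel heatmap."""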
+ def __init__(self, in_channels=3):
+ super().__init__()
+
+ self.block0 = nn.Sequential(
+ nn.BatchNorm2d(in_channels),
+ nn.Conv2d(in_channels, 32, kernel_size=3, padding="same"),
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.Conv2d(32, 32, kernel_size=3, padding="same")
+ )
+
+ self.pool0 = nn.MaxPool2d(kernel_size=2)
+
+ self.block1 = nn.Sequential(
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.Conv2d(32, 64, kernel_size=3, padding="same"),
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ nn.Conv2d(64, 64, kernel_size=3, padding="same"),
+ )
+
+ self.pool1 = nn.MaxPool2d(kernel_size=2)
+
+ self.block2 = nn.Sequential(
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ nn.Conv2d(64, 128, kernel_size=3, padding="same"),
+ nn.BatchNorm2d(128),
+ nn.ReLU(),
+ nn.Conv2d(128, 128, kernel_size=3, padding="same"),
+ )
+
+ self.upsample0 = nn.Sequential(
+ nn.BatchNorm2d(128),
+ nn.ReLU(),
+ nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
+ )
+
+ self.block3 = nn.Sequential(
+ nn.BatchNorm2d(128),
+ nn.ReLU(),
+ nn.Conv2d(128, 64, kernel_size=3, padding="same"),
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ nn.Conv2d(64, 64, kernel_size=3, padding="same"),
+ )
+
+ self.upsample1 = nn.Sequential(
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
+ )
+
+ self.block4 = nn.Sequential(
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ nn.Conv2d(64, 32, kernel_size=3, padding="same"),
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.Conv2d(32, 32, kernel_size=3, padding="same"),
+ )
+
+ def forward(self, imgs):
+ assert imgs.dim() == 4, imgs.size()
+ x = self.block0(imgs)
+ downsampled_x0 = self.block1(self.pool0(x))
+ downsampled_x1 = self.block2(self.pool1(downsampled_x0))
+ upsampled_x0 = self.block3(torch.cat([self.upsample0(downsampled_x1), downsampled_x0], dim=1))
+ x = self.block4(torch.cat([self.upsample1(upsampled_x0), x], dim=1))
+ x = torch.max(x, dim=-3)[0] # max over channel dim
+ return torch.sigmoid(x)
+
+
+class QuantizedV2(nn.Module):
+ """ Normal convolutions with quantization """
+ def __init__(self, in_channels=3):
+ super().__init__()
+
+ self.layers = nn.Sequential(
+ torch.quantization.QuantStub(),
+ nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2),
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.Conv2d(32, 8, kernel_size=3, padding=1),
+ torch.quantization.DeQuantStub(),
+ )
+
+ def forward(self, imgs):
+ x = self.layers(imgs)
+ x = torch.max(x, dim=-3)[0] # max over channel dim
+ return torch.sigmoid(x)
+
+ def fuse_modules(self, is_qat=False):
+ fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
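+        # Fuse the (Conv, BN, ReLU) triples at layers (1, 2, 3) and (4, 5, 6) so each becomes a single fused module.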
+ fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)
+
+
+class QuantizedV3(nn.Module):
+ """ Depthwise convs except input layer; no inverted bottleneck """
+
+ def __init__(self, in_channels=3):
+ super().__init__()
+
+ hidden_size = 32
+ self.quant = torch.quantization.QuantStub()
+ self.dequant = torch.quantization.DeQuantStub()
+ self.input_conv = nn.Sequential(
+ nn.Conv2d(in_channels, hidden_size, kernel_size=3, padding=1),
+ nn.BatchNorm2d(hidden_size),
+ )
+
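+        # block1/block2: a depthwise 3x3 conv (groups=hidden_size) followed by a pointwise 1x1 conv,
+        # i.e. depthwise-separable convolutions.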
+ self.block1 = nn.Sequential(
+ nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=2, groups=hidden_size, padding=2),
+ nn.BatchNorm2d(hidden_size),
+ nn.ReLU(),
+ nn.Conv2d(hidden_size, hidden_size, kernel_size=1, padding=0),
+ nn.BatchNorm2d(hidden_size),
+ nn.ReLU(),
+ )
+
+ self.block2 = nn.Sequential(
+ nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=1, groups=hidden_size, padding=1),
+ nn.BatchNorm2d(hidden_size),
+ nn.ReLU(),
+ nn.Conv2d(hidden_size, 8, kernel_size=1, padding=0),
+ )
+
+ def forward(self, imgs):
+ x = self.quant(imgs)
+ x = self.input_conv(x)
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.dequant(x)
+ x = torch.max(x, dim=-3)[0] # max over channel dim
+ return torch.sigmoid(x)
+
+ def fuse_modules(self, is_qat=False):
+ fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
+ fuse_modules(self, [
+ ["input_conv.0", "input_conv.1"],
+ ["block1.0", "block1.1", "block1.2"],
+ ["block1.3", "block1.4", "block1.5"],
+ ["block2.0", "block2.1", "block2.2"],
+ ], inplace=True)
diff --git a/corner_training/utils.py b/corner_training/utils.py
new file mode 100644
index 0000000..40e09d6
--- /dev/null
+++ b/corner_training/utils.py
@@ -0,0 +1,224 @@
+import json
+import math
+import os
+import random
+import re
+from typing import Callable
+
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+
+import torchvision
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as transforms_f
+# import torchvision.transforms.v2 as transforms
+# import torchvision.transforms.v2.functional as transforms_f
+from tqdm.auto import tqdm
+
+
+def get_gaussian_filter(sigma, half_window_size) -> torch.Tensor:
+ full_window_size = half_window_size * 2 + 1
+ x_sq_offsets = torch.square(torch.arange(full_window_size) - half_window_size) \
+ .expand((full_window_size, full_window_size))
+ y_sq_offsets = x_sq_offsets.T
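+    # The divisor is `sigma` itself (not 2 * sigma**2) and there is no normalizing constant,
+    # so the value at the center is exactly 1.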
+ gaussian_filter = torch.exp(-(x_sq_offsets + y_sq_offsets) / sigma) # "normalized" so that the peak is 1
+ return gaussian_filter
+
+
+def get_bounded_slices(size_a: torch.Size, size_b: torch.Size, i: int, j: int) \
+ -> tuple[tuple[slice, slice], tuple[slice, slice]]:
+ """
+ Args:
+        size_a: size of the first 2D tensor (`tensor_a`)
+ size_b: size of the second 2D tensor (`tensor_b`),
+ should be at most `size_a`, and have odd height and width (so that center is clearly defined)
+ i: the row of `tensor_a` to center `tensor_b` on
+ j: the column of `tensor_a` to center `tensor_b` on
+
+ Returns:
+ ((row_slice_a, col_slice_a), (row_slice_b, col_slice_b))
+ which are as large as possible (at most the size of `size_b`) while remaining in bounds.
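+
+    Example:
+        get_bounded_slices(torch.Size([8, 8]), torch.Size([3, 3]), 0, 5)
+        returns ((slice(0, 2), slice(4, 7)), (slice(1, 3), slice(0, 3))):
+        the 3x3 window is clipped at the top edge of the 8x8 tensor.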
+ """
+ assert len(size_a) == 2 and len(size_b) == 2 \
+ and size_b[0] % 2 == size_b[1] % 2 == 1
+ half_mask_width, half_mask_height = size_b[1] // 2, size_b[0] // 2
+ left_offset = max(0, j - half_mask_width)
+ right_offset = min(size_a[1], j + half_mask_width + 1) # exclusive
+ top_offset = max(0, i - half_mask_height)
+ bottom_offset = min(size_a[0], i + half_mask_height + 1)
+
+ return (slice(top_offset, bottom_offset), slice(left_offset, right_offset)), \
+ (slice(half_mask_height + top_offset - i, half_mask_height + bottom_offset - i),
+ slice(half_mask_width + left_offset - j, half_mask_width + right_offset - j))
+
+
+def get_gtruth_wrapper(sigma: int, display_pts_inds: list[int] = None):
+ """
+ Args:
+        sigma: spread of the Gaussian smoothing (used directly as the divisor in `get_gaussian_filter`).
+        display_pts_inds: If specified, only the points at these indices of `transformed_pts` are drawn
+            into the heatmap. Otherwise, all points are drawn.
+
+ One main reason for using a closure is to reuse the Gaussian mask for efficiency.
+ """
+
+ half_window_size = sigma
+ # half_window_size = 2 * sigma
+ # half_window_size = 3 * sigma
+ # half_window_size = int(3 * sigma / math.sqrt(2))
+
+ gaussian_filter = get_gaussian_filter(sigma, half_window_size)
+
+ # def get_gtruth(transformed_pts: torchvision.tv_tensors.BoundingBoxes, img_size: tuple) -> torch.Tensor:
+ def get_gtruth(transformed_pts: torch.Tensor, img_size: tuple) -> torch.Tensor:
+ """
+ Converts coordinates of points to a heatmap with Gaussians centered at those coordinates.
+
+        Args:
+            transformed_pts: (N, 2) tensor of (x, y) point coordinates in the (possibly augmented) image.
+            img_size: (height, width) of the heatmap to generate.
+ """
+ smoothed_gtruth = torch.zeros(img_size)
+
+ for ind in range(len(transformed_pts)) if display_pts_inds is None else display_pts_inds:
+ # x, y, _, _ = transformed_pts[ind]
+ x, y = transformed_pts[ind]
+
+ gtruth_slice, gaussian_slice = get_bounded_slices(smoothed_gtruth.size(), gaussian_filter.size(), y, x)
+ smoothed_gtruth[gtruth_slice] = torch.maximum( # should modify smoothed_gtruth in-place
+ smoothed_gtruth[gtruth_slice],
+ gaussian_filter[gaussian_slice]
+ )
+
+ return smoothed_gtruth
+
+ return get_gtruth
+
+
+class FlyingFramesDataset(torch.utils.data.Dataset):
+ """
+ Different from the `FlyingFramesDataset` in Simulate_Data, which
+ randomly generates each frame. Here we read the saved examples.
+ """
+ def __init__(self, data_root_path: str, get_gtruth: Callable, shard_paths: list[str] = None, eval: bool = False):
+ self.data_root_path, self.get_gtruth = data_root_path, get_gtruth
+ with open(os.path.join(self.data_root_path, "configs.json"), "r") as f:
+ self.configs = json.load(f)
+
+ assert self.configs["dataset_size"] % self.configs["shard_size"] == 0 # simplifying assumption
+ self.shard_size = self.configs["shard_size"]
+
+ if shard_paths is None: # use all shards
+ self.shard_paths = sorted([path for path in os.listdir(data_root_path)
+ if re.fullmatch(r"shard_\d+", path) is not None])
+ else:
+ self.shard_paths = shard_paths
+
+ assert len(self.shard_paths) * self.shard_size <= self.configs["dataset_size"], (len(self.shard_paths), self.shard_size, self.configs["dataset_size"])
+
+ identity_transform = transforms.Lambda(lambda x: x) # for eval=True
+ # Some last-mile stuff
+ # self.img_transforms = transforms.Compose([
+ # transforms.PILToTensor(),
+ # # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
+ # # if not eval else identity_transform,
+ # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)), # color jitter even in eval
+ # # transforms.Grayscale(),
+ # # normalize?
+ # transforms.ToDtype(torch.float32, scale=True),
+ # # transforms.ToPILImage(), # for visualization
+ # ])
+
+ # self.unified_transforms = transforms.Compose([
+ # # transforms.Resize(self.img_size),
+ # transforms.RandomHorizontalFlip(),
+ # transforms.RandomVerticalFlip(),
+ # transforms.RandomApply([
+ # transforms.GaussianBlur(kernel_size=(3, 3))
+ # ], p=0.33), # TODO: tune
+ # ]) if not eval else identity_transform
+
+ self.img_transforms = transforms.Compose([
+ transforms.PILToTensor(),
+ # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
+ # if not eval else identity_transform,
+ transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)), # color jitter even in eval
+ # transforms.Grayscale(),
+ # normalize?
+ transforms.ConvertImageDtype(torch.float32),
+ # transforms.ToPILImage(), # for visualization
+ ])
+
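+        # The unified transforms act on (img, corner_pts) pairs so that flips update the corner
+        # coordinates together with the image (see unified_hflip / unified_vflip below).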
+ self.unified_transforms = transforms.Compose([
+ # transforms.Resize(self.img_size),
+ unified_hflip,
+ unified_vflip,
+ lambda img_and_pts: (transforms.RandomApply([
+ transforms.GaussianBlur(kernel_size=(3, 3))
+ ], p=0.33)(img_and_pts[0]), img_and_pts[1]), # TODO: tune
+ ]) if not eval else identity_transform
+
+ def __getitem__(self, idx):
+ shard_path = self.shard_paths[idx // self.shard_size]
+ shard_ind = int(re.fullmatch(r"shard_(?P<shard_ind>\d+)", shard_path).group("shard_ind"))
+ img_ind = shard_ind * self.shard_size + idx % self.shard_size
+ img_path = os.path.join(self.data_root_path, shard_path, f"{img_ind}.jpg")
+ bbox_path = os.path.join(self.data_root_path, shard_path, f"{img_ind}_tensor.pt")
+
+ assert os.path.isfile(img_path) and os.path.isfile(bbox_path), (img_path, bbox_path)
+ img = self.img_transforms(PIL.Image.open(img_path))
+ bbox = torch.load(bbox_path)
+
+ img, bbox = self.unified_transforms((img, bbox))
+
+ return {
+ "imgs": img,
+ "labels": self.get_gtruth(bbox, self.configs["img_size"]),
+ }
+
+ def __len__(self):
+ return len(self.shard_paths) * self.shard_size
+
+
+class NoamLR(torch.optim.lr_scheduler._LRScheduler):
+ """
+ Taken from https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py
+ """
+
+ def __init__(self, optimizer, warmup_steps): # steps, not ratio
+ self.warmup_steps = warmup_steps
+ super().__init__(optimizer)
+
+ def get_lr(self):
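+        # Scale ramps linearly from 0 to 1 over `warmup_steps` steps, then decays as sqrt(warmup_steps / step).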
+ last_epoch = max(1, self.last_epoch)
+ scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
+ return [base_lr * scale for base_lr in self.base_lrs]
+
+
+def unified_hflip(img_and_pts):
+ img, corner_pts = img_and_pts
+ if random.random() < 0.5:
+ return img, corner_pts
+
+ _, h, w = img.size()
+ img = transforms_f.hflip(img)
+
+    # (x, y) -> (w - 1 - x, y), so mirrored pixel indices stay within [0, w)
+    corner_pts = torch.stack([w - 1 - corner_pts[:, 0], corner_pts[:, 1]], dim=1)
+ return img, corner_pts
+
+
+def unified_vflip(img_and_pts):
+ img, corner_pts = img_and_pts
+ if random.random() < 0.5:
+ return img, corner_pts
+
+ _, h, w = img.size()
+ img = transforms_f.vflip(img)
+
+    # (x, y) -> (x, h - 1 - y), so mirrored pixel indices stay within [0, h)
+    corner_pts = torch.stack([corner_pts[:, 0], h - 1 - corner_pts[:, 1]], dim=1)
+ return img, corner_pts