author    | Kevin Zhao | 2024-05-05 10:33:51 -0400
committer | Kevin Zhao | 2024-05-05 10:33:51 -0400
commit    | 3d4862e725d29ab910f4a958717749abe12aad7a (patch)
tree      | 34b10255e05e48c05360c5cb7b9daa35b214b65e
parent    | 49e4564566bbfd8435f390bba602eb97cc214502 (diff)
Add CNN arch and training code
-rw-r--r-- | .gitignore                          | 160
-rw-r--r-- | corner_training/coarse_training.py  | 126
-rw-r--r-- | corner_training/fine_training.py    | 127
-rw-r--r-- | corner_training/models.py           | 153
-rw-r--r-- | corner_training/utils.py            | 224
5 files changed, 790 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c883aa9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# Custom
+.DS_Store
+.idea/
+Report_Stuff/
+data/
+*.mkv
diff --git a/corner_training/coarse_training.py b/corner_training/coarse_training.py
new file mode 100644
index 0000000..83f2f27
--- /dev/null
+++ b/corner_training/coarse_training.py
@@ -0,0 +1,126 @@
+import argparse
+import json
+import math
+import os
+import random
+import re
+
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+
+import torchvision
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as transforms_f
+# import torchvision.transforms.v2 as transforms
+# import torchvision.transforms.v2.functional as transforms_f
+from tqdm.auto import tqdm
+
+from models import QuantizedV2
+from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("data_root_path")
+    parser.add_argument("output_dir")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    os.makedirs(args.output_dir, exist_ok=False)
+
+    np.random.seed(42)
+    torch.manual_seed(42)
+    random.seed(42)
+
+    lr = 3e-2  # TODO: argparse
+    log_steps = 10
+    train_batch_size = 512
+    num_train_epochs = 10
+
+    # leave 1 shard for eval
+    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
+                          if re.fullmatch(r"shard_\d+", path) is not None])
+
+    get_gtruth = get_gtruth_wrapper(4)
+
+    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1])
+    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        # train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True,
+        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=False
+    )
+
+    model = QuantizedV2()
+
+    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)
+
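+    # Weighted MSE over the predicted heatmap: pixels where the target Gaussian is
+    # non-zero get a much larger weight, presumably to offset the imbalance between
+    # the few corner pixels and the mostly-empty background.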
+    # loss_func = nn.MSELoss()
+    def loss_func(output, labels):
+        # weights = (labels != 0) * 49.9 + 0.1  # arbitrary nums (good DeepStage1v0 64x64)
+        weights = (labels != 0) * 99.9 + 0.1  # 4 all pts
+        # weights = (labels != 0) * 199.9 + 0.1  # 1 pt 1 layer
+        return (weights * (output - labels) ** 2).mean()
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+
+    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
+        cum_loss = 0
+        model.train()
+        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
+            output = model(batch["imgs"].to(device))
+            loss = loss_func(output, batch["labels"].to(device))
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+            cum_loss += loss.item()
+            if (i + 1) % log_steps == 0:
+                print(f"{cum_loss / log_steps=}")
+                cum_loss = 0
+
+        if i % log_steps != 0:
+            print(f"{cum_loss / (i % log_steps)=}")
+
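+        # Standard QAT schedule: late in training, stop updating the fake-quant
+        # observers and the BatchNorm running stats so the final epochs fine-tune
+        # the weights against fixed quantization parameters.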
+        if epoch > 3:
+            # Freeze quantizer parameters
+            model.apply(torch.ao.quantization.disable_observer)
+        if epoch > 2:
+            # Freeze batch norm mean and variance estimates
+            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+
+        # model.eval()
+        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
+        model.to(device)
+        quantized_model.eval()  # not sure if this is needed
+        eval_device = "cpu"
+        with torch.no_grad():
+            num_eval_exs = 5
+            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
+
+            eval_loss = 0
+            for i in range(num_eval_exs):
+                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
+                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
+                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
+                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
+                axs[i, 1].imshow(eval_ex["labels"])
+                axs[i, 1].set_title("Ground Truth")
+                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
+                axs[i, 2].set_title("Prediction")
+                for ax in axs[i]:
+                    ax.axis("off")
+
+            print(f"{eval_loss / num_eval_exs=}")
+            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
+        torch.save(quantized_model.state_dict(),
+                   os.path.join(args.output_dir, f"QuantizedV3_Stage1_128_{epoch}.pt"))
diff --git a/corner_training/fine_training.py b/corner_training/fine_training.py
new file mode 100644
index 0000000..6acd1f6
--- /dev/null
+++ b/corner_training/fine_training.py
@@ -0,0 +1,127 @@
+import argparse
+import json
+import math
+import os
+import random
+import re
+
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+
+import torchvision
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as transforms_f
+# import torchvision.transforms.v2 as transforms
+# import torchvision.transforms.v2.functional as transforms_f
+from tqdm.auto import tqdm
+
+from models import QuantizedV3, QuantizedV2
+from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("data_root_path")
+    parser.add_argument("output_dir")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    os.makedirs(args.output_dir, exist_ok=False)
+
+    np.random.seed(42)
+    torch.manual_seed(42)
+    random.seed(42)
+
+    lr = 3e-2  # TODO: argparse
+    log_steps = 10
+    train_batch_size = 512
+    num_train_epochs = 10
+
+    # leave 1 shard for eval
+    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
+                          if re.fullmatch(r"shard_\d+", path) is not None])
+
+    get_gtruth = get_gtruth_wrapper(4)
+    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1], eval=True)  # note: not flipping
+    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True
+    )
+
+    # torch.backends.quantized.engine = "qnnpack"
+    model = QuantizedV2()
+    # model.fuse_modules(is_qat=True)  # Added for QAT
+
+    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)
+
+    # model.qconfig = torch.ao.quantization.default_qconfig  # Added for QAT
+    # torch.ao.quantization.prepare_qat(model, inplace=True)  # Added for QAT
+
+    def loss_func(output, labels):
+        weights = (labels != 0) * 49.9 + 0.1
+        # weights = (labels != 0) * 99.9 + 0.1
+        return (weights * (output - labels) ** 2).mean()
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+
+    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
+        cum_loss = 0
+        model.train()
+        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
+            output = model(batch["imgs"].to(device))
+            loss = loss_func(output, batch["labels"].to(device))
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+            cum_loss += loss.item()
+            if (i + 1) % log_steps == 0:
+                print(f"{cum_loss / log_steps=}")
+                cum_loss = 0
+
+        if i % log_steps != 0:
+            print(f"{cum_loss / (i % log_steps)=}")
+
+        if epoch > 3:
+            # Freeze quantizer parameters
+            model.apply(torch.ao.quantization.disable_observer)
+        if epoch > 2:
+            # Freeze batch norm mean and variance estimates
+            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+
+        # model.eval()
+        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
+        model.to(device)
+        quantized_model.eval()  # not sure if this is needed
+        eval_device = "cpu"
+        with torch.no_grad():
+            num_eval_exs = 5
+            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
+
+            eval_loss = 0
+            for i in range(num_eval_exs):
+                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
+                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
+                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
+                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
+                axs[i, 1].imshow(eval_ex["labels"])
+                axs[i, 1].set_title("Ground Truth")
+                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
+                axs[i, 2].set_title("Prediction")
+                for ax in axs[i]:
+                    ax.axis("off")
+
+            print(f"{eval_loss / num_eval_exs=}")
+            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
+        torch.save(quantized_model.state_dict(),
+                   os.path.join(args.output_dir, f"QuantizedV3_Stage2_128_{epoch}.pt"))
diff --git a/corner_training/models.py b/corner_training/models.py
new file mode 100644
index 0000000..dff0ac2
--- /dev/null
+++ b/corner_training/models.py
@@ -0,0 +1,153 @@
+import torch
+import torch.nn as nn
+
+
+class UNetStage1v1(nn.Module):
+    def __init__(self, in_channels=3):
+        super().__init__()
+
+        self.block0 = nn.Sequential(
+            nn.BatchNorm2d(in_channels),
+            nn.Conv2d(in_channels, 32, kernel_size=3, padding="same"),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, kernel_size=3, padding="same")
+        )
+
+        self.pool0 = nn.MaxPool2d(kernel_size=2)
+
+        self.block1 = nn.Sequential(
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=3, padding="same"),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=3, padding="same"),
+        )
+
+        self.pool1 = nn.MaxPool2d(kernel_size=2)
+
+        self.block2 = nn.Sequential(
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Conv2d(64, 128, kernel_size=3, padding="same"),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.Conv2d(128, 128, kernel_size=3, padding="same"),
+        )
+
+        self.upsample0 = nn.Sequential(
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
+        )
+
+        self.block3 = nn.Sequential(
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.Conv2d(128, 64, kernel_size=3, padding="same"),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=3, padding="same"),
+        )
+
+        self.upsample1 = nn.Sequential(
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
+        )
+
+        self.block4 = nn.Sequential(
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Conv2d(64, 32, kernel_size=3, padding="same"),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, kernel_size=3, padding="same"),
+        )
+
+    def forward(self, imgs):
+        assert imgs.dim() == 4, imgs.size()
+        x = self.block0(imgs)
+        downsampled_x0 = self.block1(self.pool0(x))
+        downsampled_x1 = self.block2(self.pool1(downsampled_x0))
+        upsampled_x0 = self.block3(torch.cat([self.upsample0(downsampled_x1), downsampled_x0], dim=1))
+        x = self.block4(torch.cat([self.upsample1(upsampled_x0), x], dim=1))
+        x = torch.max(x, dim=-3)[0]  # max over channel dim
+        return torch.sigmoid(x)
+
+
+class QuantizedV2(nn.Module):
+    """ Normal convolutions with quantization """
+    def __init__(self, in_channels=3):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            torch.quantization.QuantStub(),
+            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 8, kernel_size=3, padding=1),
+            torch.quantization.DeQuantStub(),
+        )
+
+    def forward(self, imgs):
+        x = self.layers(imgs)
+        x = torch.max(x, dim=-3)[0]  # max over channel dim
+        return torch.sigmoid(x)
+
+    def fuse_modules(self, is_qat=False):
+        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
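+        # Fuses each (Conv2d, BatchNorm2d, ReLU) triple in self.layers, i.e. indices
+        # (1, 2, 3) and (4, 5, 6) of the Sequential, into single modules before quantization.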
+        fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)
+
+
+class QuantizedV3(nn.Module):
+    """ Depthwise convs except input layer; no inverted bottleneck """
+
+    def __init__(self, in_channels=3):
+        super().__init__()
+
+        hidden_size = 32
+        self.quant = torch.quantization.QuantStub()
+        self.dequant = torch.quantization.DeQuantStub()
+        self.input_conv = nn.Sequential(
+            nn.Conv2d(in_channels, hidden_size, kernel_size=3, padding=1),
+            nn.BatchNorm2d(hidden_size),
+        )
+
+        self.block1 = nn.Sequential(
+            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=2, groups=hidden_size, padding=2),
+            nn.BatchNorm2d(hidden_size),
+            nn.ReLU(),
+            nn.Conv2d(hidden_size, hidden_size, kernel_size=1, padding=0),
+            nn.BatchNorm2d(hidden_size),
+            nn.ReLU(),
+        )
+
+        self.block2 = nn.Sequential(
+            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=1, groups=hidden_size, padding=1),
+            nn.BatchNorm2d(hidden_size),
+            nn.ReLU(),
+            nn.Conv2d(hidden_size, 8, kernel_size=1, padding=0),
+        )
+
+    def forward(self, imgs):
+        x = self.quant(imgs)
+        x = self.input_conv(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.dequant(x)
+        x = torch.max(x, dim=-3)[0]  # max over channel dim
+        return torch.sigmoid(x)
+
+    def fuse_modules(self, is_qat=False):
+        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
+        fuse_modules(self, [
+            ["input_conv.0", "input_conv.1"],
+            ["block1.0", "block1.1", "block1.2"],
+            ["block1.3", "block1.4", "block1.5"],
+            ["block2.0", "block2.1", "block2.2"],
+        ], inplace=True)
diff --git a/corner_training/utils.py b/corner_training/utils.py
new file mode 100644
index 0000000..40e09d6
--- /dev/null
+++ b/corner_training/utils.py
@@ -0,0 +1,224 @@
+import json
+import math
+import os
+import random
+import re
+from typing import Callable
+
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+
+import torchvision
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as transforms_f
+# import torchvision.transforms.v2 as transforms
+# import torchvision.transforms.v2.functional as transforms_f
+from tqdm.auto import tqdm
+
+
+def get_gaussian_filter(sigma, half_window_size) -> torch.Tensor:
+    full_window_size = half_window_size * 2 + 1
+    x_sq_offsets = torch.square(torch.arange(full_window_size) - half_window_size) \
+        .expand((full_window_size, full_window_size))
+    y_sq_offsets = x_sq_offsets.T
+    gaussian_filter = torch.exp(-(x_sq_offsets + y_sq_offsets) / sigma)  # "normalized" so that the peak is 1
+    return gaussian_filter
+
+
+def get_bounded_slices(size_a: torch.Size, size_b: torch.Size, i: int, j: int) \
+        -> tuple[tuple[slice, slice], tuple[slice, slice]]:
+    """
+    Args:
+        size_a: size of the first 2D tensor (`a_tensor`)
+        size_b: size of the second 2D tensor (`tensor_b`),
+            should be at most `size_a`, and have odd height and width (so that center is clearly defined)
+        i: the row of `tensor_a` to center `tensor_b` on
+        j: the column of `tensor_a` to center `tensor_b` on
+
+    Returns:
+        ((row_slice_a, col_slice_a), (row_slice_b, col_slice_b))
+        which are as large as possible (at most the size of `size_b`) while remaining in bounds.
+    """
+    assert len(size_a) == 2 and len(size_b) == 2 \
+        and size_b[0] % 2 == size_b[1] % 2 == 1
+    half_mask_width, half_mask_height = size_b[1] // 2, size_b[0] // 2
+    left_offset = max(0, j - half_mask_width)
+    right_offset = min(size_a[1], j + half_mask_width + 1)  # exclusive
+    top_offset = max(0, i - half_mask_height)
+    bottom_offset = min(size_a[0], i + half_mask_height + 1)
+
+    return (slice(top_offset, bottom_offset), slice(left_offset, right_offset)), \
+        (slice(half_mask_height + top_offset - i, half_mask_height + bottom_offset - i),
+         slice(half_mask_width + left_offset - j, half_mask_width + right_offset - j))
+
+
+def get_gtruth_wrapper(sigma: int, display_pts_inds: list[int] = None):
+    """
+    Args:
+        sigma: variance for the Gaussian smoothing.
+        display_pts_inds: If specified, will only display the points corresponding to these indices in `transformed_pts`.
+            Otherwise, will display all points by default.
+
+    One main reason for using a closure is to reuse the Gaussian mask for efficiency.
+    """
+
+    half_window_size = sigma
+    # half_window_size = 2 * sigma
+    # half_window_size = 3 * sigma
+    # half_window_size = int(3 * sigma / math.sqrt(2))
+
+    gaussian_filter = get_gaussian_filter(sigma, half_window_size)
+
+    # def get_gtruth(transformed_pts: torchvision.tv_tensors.BoundingBoxes, img_size: tuple) -> torch.Tensor:
+    def get_gtruth(transformed_pts: torch.Tensor, img_size: tuple) -> torch.Tensor:
+        """
+        Converts coordinates of points to a heatmap with Gaussians centered at those coordinates.
+
+        Args:  # TODO doc
+            transformed_pts:
+            img_size:
+        """
+        smoothed_gtruth = torch.zeros(img_size)
+
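+        # Paste a fixed Gaussian blob around each point, cropped at the image border;
+        # torch.maximum keeps overlapping blobs from summing above 1.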
+        for ind in range(len(transformed_pts)) if display_pts_inds is None else display_pts_inds:
+            # x, y, _, _ = transformed_pts[ind]
+            x, y = transformed_pts[ind]
+
+            gtruth_slice, gaussian_slice = get_bounded_slices(smoothed_gtruth.size(), gaussian_filter.size(), y, x)
+            smoothed_gtruth[gtruth_slice] = torch.maximum(  # should modify smoothed_gtruth in-place
+                smoothed_gtruth[gtruth_slice],
+                gaussian_filter[gaussian_slice]
+            )
+
+        return smoothed_gtruth
+
+    return get_gtruth
+
+
+class FlyingFramesDataset(torch.utils.data.Dataset):
+    """
+    Different from the `FlyingFramesDataset` in Simulate_Data, which
+    randomly generates each frame. Here we read the saved examples.
+    """
+    def __init__(self, data_root_path: str, get_gtruth: Callable, shard_paths: list[str] = None, eval: bool = False):
+        self.data_root_path, self.get_gtruth = data_root_path, get_gtruth
+        with open(os.path.join(self.data_root_path, "configs.json"), "r") as f:
+            self.configs = json.load(f)
+
+        assert self.configs["dataset_size"] % self.configs["shard_size"] == 0  # simplifying assumption
+        self.shard_size = self.configs["shard_size"]
+
+        if shard_paths is None:  # use all shards
+            self.shard_paths = sorted([path for path in os.listdir(data_root_path)
+                                       if re.fullmatch(r"shard_\d+", path) is not None])
+        else:
+            self.shard_paths = shard_paths
+
+        assert len(self.shard_paths) * self.shard_size <= self.configs["dataset_size"], (len(self.shard_paths), self.shard_size, self.configs["dataset_size"])
+
+        identity_transform = transforms.Lambda(lambda x: x)  # for eval=True
+        # Some last-mile stuff
+        # self.img_transforms = transforms.Compose([
+        #     transforms.PILToTensor(),
+        #     # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
+        #     # if not eval else identity_transform,
+        #     transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)),  # color jitter even in eval
+        #     # transforms.Grayscale(),
+        #     # normalize?
+        #     transforms.ToDtype(torch.float32, scale=True),
+        #     # transforms.ToPILImage(),  # for visualization
+        # ])
+
+        # self.unified_transforms = transforms.Compose([
+        #     # transforms.Resize(self.img_size),
+        #     transforms.RandomHorizontalFlip(),
+        #     transforms.RandomVerticalFlip(),
+        #     transforms.RandomApply([
+        #         transforms.GaussianBlur(kernel_size=(3, 3))
+        #     ], p=0.33),  # TODO: tune
+        # ]) if not eval else identity_transform
+
+        self.img_transforms = transforms.Compose([
+            transforms.PILToTensor(),
+            # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
+            # if not eval else identity_transform,
+            transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)),  # color jitter even in eval
+            # transforms.Grayscale(),
+            # normalize?
+            transforms.ConvertImageDtype(torch.float32),
+            # transforms.ToPILImage(),  # for visualization
+        ])
+
+        self.unified_transforms = transforms.Compose([
+            # transforms.Resize(self.img_size),
+            unified_hflip,
+            unified_vflip,
+            lambda img_and_pts: (transforms.RandomApply([
+                transforms.GaussianBlur(kernel_size=(3, 3))
+            ], p=0.33)(img_and_pts[0]), img_and_pts[1]),  # TODO: tune
+        ]) if not eval else identity_transform
+
+    def __getitem__(self, idx):
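+        # Examples are stored as shard_<k>/<global_index>.jpg with a matching
+        # <global_index>_tensor.pt of target points; the flat dataset index is
+        # mapped back to (shard, image) here.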
+ """ + def __init__(self, data_root_path: str, get_gtruth: Callable, shard_paths: list[str] = None, eval: bool = False): + self.data_root_path, self.get_gtruth = data_root_path, get_gtruth + with open(os.path.join(self.data_root_path, "configs.json"), "r") as f: + self.configs = json.load(f) + + assert self.configs["dataset_size"] % self.configs["shard_size"] == 0 # simplifying assumption + self.shard_size = self.configs["shard_size"] + + if shard_paths is None: # use all shards + self.shard_paths = sorted([path for path in os.listdir(data_root_path) + if re.fullmatch(r"shard_\d+", path) is not None]) + else: + self.shard_paths = shard_paths + + assert len(self.shard_paths) * self.shard_size <= self.configs["dataset_size"], (len(self.shard_paths), self.shard_size, self.configs["dataset_size"]) + + identity_transform = transforms.Lambda(lambda x: x) # for eval=True + # Some last-mile stuff + # self.img_transforms = transforms.Compose([ + # transforms.PILToTensor(), + # # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)) + # # if not eval else identity_transform, + # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)), # color jitter even in eval + # # transforms.Grayscale(), + # # normalize? + # transforms.ToDtype(torch.float32, scale=True), + # # transforms.ToPILImage(), # for visualization + # ]) + + # self.unified_transforms = transforms.Compose([ + # # transforms.Resize(self.img_size), + # transforms.RandomHorizontalFlip(), + # transforms.RandomVerticalFlip(), + # transforms.RandomApply([ + # transforms.GaussianBlur(kernel_size=(3, 3)) + # ], p=0.33), # TODO: tune + # ]) if not eval else identity_transform + + self.img_transforms = transforms.Compose([ + transforms.PILToTensor(), + # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)) + # if not eval else identity_transform, + transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)), # color jitter even in eval + # transforms.Grayscale(), + # normalize? 
+        last_epoch = max(1, self.last_epoch)
+        scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
+        return [base_lr * scale for base_lr in self.base_lrs]
+
+
+def unified_hflip(img_and_pts):
+    img, corner_pts = img_and_pts
+    if random.random() < 0.5:
+        return img, corner_pts
+
+    _, h, w = img.size()
+    img = transforms_f.hflip(img)
+
+    # (x, y) -> (w - x, y)
+    corner_pts = torch.stack([w - corner_pts[:, 0], corner_pts[:, 1]], dim=1)
+    return img, corner_pts
+
+
+def unified_vflip(img_and_pts):
+    img, corner_pts = img_and_pts
+    if random.random() < 0.5:
+        return img, corner_pts
+
+    _, h, w = img.size()
+    img = transforms_f.vflip(img)
+
+    # (x, y) -> (x, h - y)
+    corner_pts = torch.stack([corner_pts[:, 0], h - corner_pts[:, 1]], dim=1)
+    return img, corner_pts