commit    3a40a8fb033b41362d6e84dfeff976f85db2df23
tree      b2da763a2958e0afe51e17aa80e6a5176815183f
parent    9d9c891d74737f363029164b4e431bb620406a2a
author    Kevin Zhao  2024-05-07 17:28:56 -0400
committer Kevin Zhao  2024-05-07 17:28:56 -0400

    Separate localize_corners but still different cropping
 .gitignore                           |   1
 checkpts/QuantizedV2_Stage1_128_9.pt | Bin 0 -> 18078 bytes
 checkpts/QuantizedV5_Stage2_128_9.pt | Bin 0 -> 18078 bytes
 decoder.py                           | 139
 decoder_cnn.py                       | 151
 decoding_utils.py                    | 341
 6 files changed, 428 insertions(+), 204 deletions(-)
diff --git a/.gitignore b/.gitignore
index c883aa9..e3cbb6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,3 +158,4 @@ cython_debug/
Report_Stuff/
data/
*.mkv
+*.slurm
diff --git a/checkpts/QuantizedV2_Stage1_128_9.pt b/checkpts/QuantizedV2_Stage1_128_9.pt
new file mode 100644
index 0000000..0c495ea
--- /dev/null
+++ b/checkpts/QuantizedV2_Stage1_128_9.pt
Binary files differ
diff --git a/checkpts/QuantizedV5_Stage2_128_9.pt b/checkpts/QuantizedV5_Stage2_128_9.pt
new file mode 100644
index 0000000..f17556e
--- /dev/null
+++ b/checkpts/QuantizedV5_Stage2_128_9.pt
Binary files differ
diff --git a/decoder.py b/decoder.py
index 43d06a5..280cd36 100644
--- a/decoder.py
+++ b/decoder.py
@@ -2,9 +2,13 @@ import argparse
import time
import cv2
import numpy as np
+import torch
from creedsolo import RSCodec
from raptorq import Decoder
+from corner_training.models import QuantizedV2, QuantizedV5
+from decoding_utils import localize_corners_wrapper
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
@@ -13,9 +17,19 @@ parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
parser.add_argument("-p", "--psize", help="packet size", type=int)
+parser.add_argument("-v", "--version", help="0 - original; 1 - CNN", default=0, choices=[0, 1], type=int)
+
args = parser.parse_args()
-cheight = cwidth = max(args.height // 10, args.width // 10)
+if args.version == 0:
+ cheight = cwidth = max(args.height // 10, args.width // 10)
+elif args.version == 1:
+ assert args.height * 3 % 80 == args.width * 3 % 80 == 0
+ cheight = int(args.height * 0.15)
+ cwidth = int(args.width * 0.15)
+else:
+ raise NotImplementedError
+
frame_size = args.height * args.width - 4 * cheight * cwidth
frame_bytes = frame_size * 3 // 8
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
@@ -24,24 +38,60 @@ rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4
rsc = RSCodec(int(args.level * 255))
decoder = Decoder.with_defaults(args.size, rs_bytes)
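+# e.g. with the default 100x100 grid at level 0.1 (version 0): cheight = cwidth = 10,
+# frame_size = 10000 - 4*10*10 = 9600 cells, frame_bytes = 9600*3//8 = 3600, and
+# rs_bytes = 3600 - ceil(3600/255)*int(0.1*255) - 4 = 3600 - 15*25 - 4 = 3221.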
+input_crop_size = 1024
+
+if args.version == 0:
+ def find_corner(A, f):
+ cx, cy = A.shape[:2]
+ # Resize so smaller dim is 8
+ scale = min(cx // 8, cy // 8)
+ B = cv2.resize(A, (cy // scale, cx // scale), interpolation=cv2.INTER_AREA)
+ guess = np.array(np.unravel_index(np.argmax(f(B.astype(np.float64))), B.shape[:2])) * scale + scale // 2
+ mask = cv2.floodFill(
+ A,
+ np.empty(0),
+ tuple(np.flip(guess)),
+ 0,
+ (100, 100, 100),
+ (100, 100, 100),
+ cv2.FLOODFILL_MASK_ONLY + cv2.FLOODFILL_FIXED_RANGE,
+ )[2][1:-1, 1:-1].astype(bool)
+ return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)
+
+ def localize_corners(cropped_frame):
+ """
+ Returns (reconstructed grid, (wcol, rcol, gcol, bcol))
+ """
+ X, Y = cropped_frame.shape[:2]
+ cx, cy = X // 3, Y // 3
+ widx, wcol = find_corner(cropped_frame[:cx, :cy], lambda B: np.sum(B, axis=2) - 2 * np.std(B, axis=2))
+ ridx, rcol = find_corner(cropped_frame[:cx, Y - cy:], lambda B: B[:, :, 0] - B[:, :, 1] - B[:, :, 2])
+ ridx[1] += Y - cy
+ gidx, gcol = find_corner(cropped_frame[X - cx:, :cy], lambda B: B[:, :, 1] - B[:, :, 2] - B[:, :, 0])
+ gidx[0] += X - cx
+ bidx, bcol = find_corner(cropped_frame[X - cx:, Y - cy:], lambda B: B[:, :, 2] - B[:, :, 0] - B[:, :, 1])
+ bidx[0] += X - cx
+ bidx[1] += Y - cy
+
+ cch = cheight / 2 - 1
+ ccw = cwidth / 2 - 1
+ M = cv2.getPerspectiveTransform(
+ np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
+ np.float32(
+ [
+ [ccw, cch],
+ [args.width - ccw - 1, cch],
+ [ccw, args.height - cch - 1],
+ [args.width - ccw - 1, args.height - cch - 1],
+ ]
+ ),
+ )
-def find_corner(A, f):
- cx, cy = A.shape[:2]
- # Resize so smaller dim is 8
- scale = min(cx // 8, cy // 8)
- B = cv2.resize(A, (cy // scale, cx // scale), interpolation=cv2.INTER_AREA)
- guess = np.array(np.unravel_index(np.argmax(f(B.astype(np.float64))), B.shape[:2])) * scale + scale // 2
- mask = cv2.floodFill(
- A,
- np.empty(0),
- tuple(np.flip(guess)),
- 0,
- (100, 100, 100),
- (100, 100, 100),
- cv2.FLOODFILL_MASK_ONLY + cv2.FLOODFILL_FIXED_RANGE,
- )[2][1:-1, 1:-1].astype(bool)
- return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)
+ frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
+ return frame, (wcol, rcol, gcol, bcol)
+elif args.version == 1:
+ localize_corners = localize_corners_wrapper(args, input_crop_size)
if args.input.isdecimal():
args.input = int(args.input)
@@ -54,26 +104,23 @@ while data is None:
if not ret:
print("End of stream")
break
- # raw_frame is a uint8 BE CAREFUL
- if type(args.input) == int:
- # Crop image to reduce camera distortion
- X, Y = raw_frame.shape[:2]
- raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
+
+ if args.version == 0:
+ # raw_frame is a uint8 BE CAREFUL
+ if type(args.input) == int:
+ # Crop image to reduce camera distortion
+ X, Y = raw_frame.shape[:2]
+ raw_frame = raw_frame[X // 4: 3 * X // 4, Y // 4: 3 * Y // 4]
+ elif args.version == 1:
+ h, w, _ = raw_frame.shape
+ raw_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
+ (w - input_crop_size) // 2:-(w - input_crop_size) // 2]
+
cv2.imshow("", raw_frame)
cv2.waitKey(1)
raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
- # Find positions and colors of corners
- X, Y = raw_frame.shape[:2]
- cx, cy = X // 3, Y // 3
- widx, wcol = find_corner(raw_frame[:cx, :cy], lambda B: np.sum(B, axis=2) - 2 * np.std(B, axis=2))
- ridx, rcol = find_corner(raw_frame[:cx, Y - cy :], lambda B: B[:, :, 0] - B[:, :, 1] - B[:, :, 2])
- ridx[1] += Y - cy
- gidx, gcol = find_corner(raw_frame[X - cx :, :cy], lambda B: B[:, :, 1] - B[:, :, 2] - B[:, :, 0])
- gidx[0] += X - cx
- bidx, bcol = find_corner(raw_frame[X - cx :, Y - cy :], lambda B: B[:, :, 2] - B[:, :, 0] - B[:, :, 1])
- bidx[0] += X - cx
- bidx[1] += Y - cy
+ frame, (wcol, rcol, gcol, bcol) = localize_corners(raw_frame)
# Find basis of color space
origin = (rcol + gcol + bcol - wcol) / 2
@@ -82,31 +129,17 @@ while data is None:
bcol -= origin
F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
- cch = cheight / 2 - 1
- ccw = cwidth / 2 - 1
- M = cv2.getPerspectiveTransform(
- np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
- np.float32(
- [
- [ccw, cch],
- [args.width - ccw - 1, cch],
- [ccw, args.height - cch - 1],
- [args.width - ccw - 1, args.height - cch - 1],
- ]
- ),
- )
- frame = cv2.warpPerspective(raw_frame, M, (args.width, args.height))
# Convert to new color space
frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
- # import matplotlib.pyplot as pltc
- # plt.imshow(frame * 255)
- # plt.show()
+ import matplotlib.pyplot as plt
+ plt.imshow(frame * 255)
+ plt.show()
frame = np.packbits(
np.concatenate(
(
- frame[:cheight, cwidth : args.width - cwidth].flatten(),
- frame[cheight : args.height - cheight].flatten(),
- frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
+ frame[:cheight, cwidth: args.width - cwidth].flatten(),
+ frame[cheight: args.height - cheight].flatten(),
+ frame[args.height - cheight:, cwidth: args.width - cwidth].flatten(),
)
)
)
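
The color handling in this file is hard to follow from the hunks alone: the four sampled corner colors define an affine color basis, and F maps each origin-relative pixel so that the red, green and blue corner colors land near (255, 0, 0), (0, 255, 0) and (0, 0, 255), after which the >= 192 threshold extracts one bit per channel. The following is a minimal standalone sketch of that arithmetic, using made-up corner colors rather than values sampled from a real capture:

    import numpy as np

    # Hypothetical corner colors, e.g. from a slightly washed-out capture
    wcol = np.array([230.0, 235.0, 240.0])  # white corner
    rcol = np.array([200.0, 60.0, 70.0])    # red corner
    gcol = np.array([60.0, 190.0, 80.0])    # green corner
    bcol = np.array([55.0, 65.0, 210.0])    # blue corner

    # Same construction as decoder.py: origin approximates where "black" lands,
    # and F sends the origin-relative corner colors to the coordinate axes.
    origin = (rcol + gcol + bcol - wcol) / 2
    F = 255 * np.linalg.inv(np.stack((rcol - origin, gcol - origin, bcol - origin)).T)

    pixel = np.array([190.0, 70.0, 200.0])   # a magenta-ish cell
    coords = F @ (pixel - origin)            # roughly (223, -16, 225) with these numbers
    bits = (coords >= 192).astype(np.uint8)  # -> [1 0 1], i.e. R and B bits set
    print(coords, bits)

With ideal corner colors (pure red, green, blue and white), origin is (0, 0, 0) and F reduces to the identity, so the threshold simply picks out channels above 192.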
diff --git a/decoder_cnn.py b/decoder_cnn.py
deleted file mode 100644
index ae4df07..0000000
--- a/decoder_cnn.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import argparse
-import traceback
-import cv2
-import numpy as np
-import torch
-from creedsolo import RSCodec
-from matplotlib import pyplot as plt
-from raptorq import Decoder
-
-from corner_training.models import QuantizedV2, QuantizedV5
-from decoding_utils import localize_corners_wrapper
-
-parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
-parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
-parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
-parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
-parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
-parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
-parser.add_argument("-p", "--psize", help="packet size", type=int)
-parser.add_argument("-v", "--version",
- help="0: 10% corners w/ two-sided one-cell padding; 1: 15% corners w/ four-sided 25% padding.",
- default=0, choices=[0, 1], type=int)
-args = parser.parse_args()
-
-assert args.version == 1
-# cell borders are 0.0375% of width/height
-assert args.height * 3 % 80 == args.width * 3 % 80 == 0
-cheight = int(args.height * 0.15)
-cwidth = int(args.width * 0.15)
-
-frame_size = args.height * args.width - 4 * cheight * cwidth
-frame_bytes = frame_size * 3 // 8
-frame_xor = np.arange(frame_bytes, dtype=np.uint8)
-rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4
-
-rsc = RSCodec(int(args.level * 255))
-decoder = Decoder.with_defaults(args.size, rs_bytes)
-
-
-stage1_model_checkpt_path = "/Users/kevinzhao/Downloads/QuantizedV0_Stage1_128_9.pt"
-
-stage1_model = QuantizedV2()
-stage1_model.eval()
-stage1_model.fuse_modules(is_qat=False)
-stage1_model.qconfig = torch.ao.quantization.default_qconfig
-torch.ao.quantization.prepare(stage1_model, inplace=True)
-torch.ao.quantization.convert(stage1_model, inplace=True)
-stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))
-
-stage2_model = QuantizedV5()
-stage2_model.eval()
-stage2_model.fuse_modules(is_qat=False)
-stage2_model.qconfig = torch.ao.quantization.default_qconfig
-torch.ao.quantization.prepare(stage2_model, inplace=True)
-torch.ao.quantization.convert(stage2_model, inplace=True)
-stage2_model.load_state_dict(torch.load("/Users/kevinzhao/Downloads/QuantizedV5_Stage2_128_9.pt", map_location=torch.device('cpu')))
-
-# stage1_size = 128
-# stage2_size = 128
-
-stage1_size = 128
-stage2_size = 64
-
-input_crop_size = 1024
-
-localize_corners = localize_corners_wrapper(stage1_model, stage2_model, stage1_size, stage2_size)
-
-if args.input.isdecimal():
- args.input = int(args.input)
-cap = cv2.VideoCapture(args.input)
-data = None
-while data is None:
- try:
- ret, raw_frame = cap.read()
- if not ret:
- print("End of stream")
- break
- # # raw_frame is a uint8 BE CAREFUL
- # if type(args.input) == int:
- # # Crop image to reduce camera distortion
- # X, Y = raw_frame.shape[:2]
- # raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
- # cv2.imshow("", raw_frame)
- # cv2.waitKey(1)
- # raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
-
- h, w, _ = raw_frame.shape
- cropped_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
- (w - input_crop_size) // 2:-(w - input_crop_size) // 2]
- cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
- (widx, ridx, gidx, bidx), (wcol, rcol, gcol, bcol) = localize_corners(cropped_frame)
-
- widx = widx[::-1]
- ridx = ridx[::-1]
- gidx = gidx[::-1]
- bidx = bidx[::-1]
- # plt.imshow(cropped_frame)
- # plt.scatter([widx[1]], [widx[0]], color="r")
- # plt.scatter([ridx[1]], [ridx[0]], color="g")
- # plt.scatter([gidx[1]], [gidx[0]], color="b")
- # plt.scatter([bidx[1]], [bidx[0]], color="w")
- # plt.show()
-
- # Find basis of color space
- origin = (rcol + gcol + bcol - wcol) / 2
- rcol -= origin
- gcol -= origin
- bcol -= origin
- F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
-
- # cch = cheight / 2 - 1
- # ccw = cwidth / 2 - 1
- cch = cheight / 4 - 1
- ccw = cwidth / 4 - 1
-
- M = cv2.getPerspectiveTransform(
- np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
- np.float32(
- [
- [ccw, cch],
- [args.width - ccw - 1, cch],
- [ccw, args.height - cch - 1],
- [args.width - ccw - 1, args.height - cch - 1],
- ]
- ),
- )
- frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
- # # Convert to new color space
- # frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
- frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
- # import matplotlib.pyplot as plt
- # # plt.imshow(frame * 255)
- # plt.imshow((1 - frame) * 255)
- # plt.show()
- frame = np.concatenate(
- (
- frame[:cheight, cwidth : args.width - cwidth].flatten(),
- frame[cheight : args.height - cheight].flatten(),
- frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
- )
- )
- data = decoder.decode(bytes(rsc.decode(bytearray(np.packbits(frame) ^ frame_xor))[0][: args.psize]))
- print("Decoded frame")
- except KeyboardInterrupt:
- break
- except:
- traceback.print_exc()
-with open(args.output, "wb") as f:
- f.write(data)
-cap.release()
diff --git a/decoding_utils.py b/decoding_utils.py
new file mode 100644
index 0000000..c8f292a
--- /dev/null
+++ b/decoding_utils.py
@@ -0,0 +1,341 @@
+import itertools
+import time
+
+import cv2
+import numpy as np
+
+import torch
+import torchvision
+import torchvision.transforms.v2 as transforms
+import torchvision.transforms.v2.functional as transforms_f
+
+from corner_training.models import QuantizedV2, QuantizedV5
+from corner_training.utils import get_gaussian_filter, get_bounded_slices
+
+
+torch.backends.quantized.engine = 'qnnpack'
+
+
+def localize_corners_wrapper(args, input_crop_size, debug=False):
+ stage1_model_checkpt_path = "checkpts/QuantizedV2_Stage1_128_9.pt"
+ stage2_model_checkpt_path = "checkpts/QuantizedV5_Stage2_128_9.pt"
+
+ stage1_model = QuantizedV2()
+ stage1_model.eval()
+ stage1_model.fuse_modules(is_qat=False)
+ stage1_model.qconfig = torch.ao.quantization.default_qconfig
+ torch.ao.quantization.prepare(stage1_model, inplace=True)
+ torch.ao.quantization.convert(stage1_model, inplace=True)
+ stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))
+
+ stage2_model = QuantizedV5()
+ stage2_model.eval()
+ stage2_model.fuse_modules(is_qat=False)
+ stage2_model.qconfig = torch.ao.quantization.default_qconfig
+ torch.ao.quantization.prepare(stage2_model, inplace=True)
+ torch.ao.quantization.convert(stage2_model, inplace=True)
+ stage2_model.load_state_dict(torch.load(stage2_model_checkpt_path, map_location=torch.device('cpu')))
+
+ stage1_size = 128
+ stage2_size = input_crop_size // 16
+
+ assert stage1_size & 1 == 0, "Assuming even size when dividing into quadrants"
+ assert stage2_size & 1 == 0, "Assuming even size when center cropping"
+ stage1_model.eval()
+ stage2_model.eval()
+
+ preprocess_img_stage1 = transforms.Compose([
+ transforms.Lambda(lambda img: cv2.resize(img, (stage1_size, stage1_size), interpolation=cv2.INTER_NEAREST)),
+ transforms.ToImage(),
+ transforms.ToDtype(torch.float32, scale=True),
+ ])
+
+ gaussian_filter = get_gaussian_filter(4, 4) # for stage1 NMS heuristic
+
+ preprocess_img_stage2 = transforms.Compose([
+ transforms.ToImage(),
+ transforms.ToDtype(torch.float32, scale=True),
+ ])
+
+ # Flip each cropped corner so it looks like a top-left corner, since that is what the model was trained on
+ transforms_by_corner = [
+ lambda img: img, # identity
+ transforms_f.hflip,
+ transforms_f.vflip,
+ lambda img: transforms_f.vflip(transforms_f.hflip(img))
+ ]
+
+ inv_transforms_by_corner = transforms_by_corner # flipping is a self-inverse
+
+ def localize_corners(cropped_frame: np.ndarray):
+ """
+ Args:
+ cropped_frame: Square numpy array
+ """
+ orig_h, orig_w, _ = cropped_frame.shape
+ assert orig_w == orig_h, "Assuming square img"
+ assert orig_w % stage1_size == 0
+ upscale_factor = orig_w // stage1_size # for stage 2
+
+ start_time = time.time()
+ stage1_img = preprocess_img_stage1(cropped_frame)
+ if debug:
+ print(54, time.time() - start_time)
+
+ with torch.no_grad():
+ stage1_pred = stage1_model(stage1_img.unsqueeze(0)).squeeze(0)
+
+ if debug:
+ print(57, time.time() - start_time)
+
+ quad_size = stage1_size // 2
+
+ corners_by_quad = dict()
+
+ for top_half in (0, 1): # TODO: bot/right to remove all 1 minuses
+ for left_half in (0, 1):
+ quad_i_start = quad_size * (1 - top_half)
+ quad_j_start = quad_size * (1 - left_half)
+ curr_quad_preds = stage1_pred[
+ quad_i_start: quad_i_start + quad_size,
+ quad_j_start: quad_j_start + quad_size,
+ ].clone()
+
+ max_locs = []
+ for i in range(6): # expect 4 points, but get top 6 to be safe
+ max_ind = torch.argmax(curr_quad_preds).item() # TODO: more efficient like segtree, maybe account for neighbors too
+ max_loc = (max_ind // quad_size, max_ind % quad_size)
+ max_locs.append(max_loc)
+
+ # TODO: improve, maybe scale Gaussian peak to val of max_loc, probably better to not subtract from a location multiple times
+ preds_slice, gaussian_slice = get_bounded_slices((quad_size, quad_size), gaussian_filter.size(),
+ *max_loc)
+ curr_quad_preds[preds_slice] -= gaussian_filter[gaussian_slice]
+
+ if debug:
+ print(f"{max_locs=}")
+
+ min_cost = 1e9
+ min_square = None
+ for potential_combo in itertools.combinations(max_locs, 4): # TODO: don't repeat symmetrical squares
+ curr_pts, curr_cost = score_combo(potential_combo)
+ if curr_cost < min_cost:
+ min_cost = curr_cost
+ min_square = curr_pts
+
+ if min_square is None:
+ print("all collinear")
+ return None
+ corners_by_quad[(1 - top_half) * 2 + (1 - left_half)] = [(i + quad_i_start, j + quad_j_start) for (i, j)
+ in min_square]
+ if debug:
+ print(92, time.time() - start_time)
+ print(corners_by_quad)
+
+ outer_corners = []
+ corner_colors = [] # by center, currently rounding to the pixel in the original image
+ origin = (quad_size, quad_size)
+ for quad in range(4): # TODO: consistent (x, y) or (i, j)
+ outer_corners.append(max((l2_dist(corner, origin), corner) for corner in corners_by_quad[quad])[1])
+ corner_colors.append(cropped_frame[int((sum(corner[0] for corner in corners_by_quad[quad]) / 4 * upscale_factor)),
+ int((sum(corner[1] for corner in corners_by_quad[quad]) / 4 * upscale_factor))]
+ .astype(np.float64))
+
+ stage2_imgs = []
+
+ for top_half in (0, 1): # TODO: bot/right to remove all 1 minuses
+ for left_half in (0, 1):
+ corner_ind = top_half * 2 + left_half
+ y, x = outer_corners[corner_ind]
+ upscaled_y, upscaled_x = y * upscale_factor, x * upscale_factor
+
+ top = max(0, upscaled_y - stage2_size // 2)
+ bottom = min(orig_h, upscaled_y + stage2_size // 2)
+ left = max(0, upscaled_x - stage2_size // 2)
+ right = min(orig_w, upscaled_x + stage2_size // 2)
+
+ # Need padding if detected corner is within `stage2_size // 2` of border
+ corner_padding = [0] * 4 # pad the side that does not affect extracted coordinates
+ corner_padding[(1 - top_half) * 2 + 1] = stage2_size - (bottom - top)
+ corner_padding[(1 - left_half) * 2] = stage2_size - (right - left)
+ cropped_corner_img = transforms_f.pad( # TODO: don't pad since that should speed up inference
+ preprocess_img_stage2(cropped_frame[top:bottom, left:right]),
+ corner_padding
+ )
+ stage2_imgs.append(cropped_corner_img)
+
+ transformed_corner_imgs = torch.stack([transforms_by_corner[corner_ind](stage2_img)
+ for corner_ind, stage2_img in enumerate(stage2_imgs)])
+
+ if debug:
+ print(121, time.time() - start_time)
+
+ with torch.no_grad():
+ transformed_preds = stage2_model(transformed_corner_imgs)
+
+ if debug:
+ print(125, time.time() - start_time)
+
+ transformed_pred_pts = [
+ torchvision.tv_tensors.BoundingBoxes(
+ [(max_ind := pred.argmax()) % stage2_size, max_ind // stage2_size, 0, 0],
+ format="XYWH", canvas_size=(stage2_size, stage2_size)
+ )
+ for pred in transformed_preds
+ ]
+
+ stage2_pred_pts = [inv_transforms_by_corner[corner_ind](transformed_pred_pt)[0, :2].tolist()
+ for corner_ind, transformed_pred_pt in enumerate(transformed_pred_pts)]
+
+ if debug:
+ print(137, time.time() - start_time)
+
+ orig_pred_pts = [(orig_x * upscale_factor + stage2_pred_x - stage2_size // 2,
+ orig_y * upscale_factor + stage2_pred_y - stage2_size // 2)
+ for (orig_y, orig_x), (stage2_pred_x, stage2_pred_y) in zip(outer_corners, stage2_pred_pts)]
+
+ if debug:
+ print(142, time.time() - start_time)
+
+ cch = int(args.height * 0.15) / 4 - 1
+ ccw = int(args.width * 0.15) / 4 - 1
+
+ M = cv2.getPerspectiveTransform(
+ np.float32(orig_pred_pts),
+ np.float32(
+ [
+ [ccw, cch],
+ [args.width - ccw - 1, cch],
+ [ccw, args.height - cch - 1],
+ [args.width - ccw - 1, args.height - cch - 1],
+ ]
+ ),
+ )
+
+ cropped_frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
+
+ return cropped_frame, corner_colors
+
+ return localize_corners
+
+
+def l2_dist(loc, origin):
+ """ No sqrt """
+ return (loc[0] - origin[0]) ** 2 + (loc[1] - origin[1]) ** 2
+
+
+def score_combo(combo):
+ """
+ Plan:
+ 1. Check if pts are convex. If no, very bad quadrilateral.
+ 2. Check if diagonal lengths are within a factor of 1.5. If no, somewhat bad since far from right angles.
+ 3. If the above are satisfied, then simply return how close the side lengths are to being equal.
+ """
+ hull = convex_hull([Point(x, y) for x, y in combo]) # TODO: check how collinear case is handled
+ hull = [(pt.x, pt.y) for pt in hull] # convert back to tuple
+ if len(hull) != 4:
+ return None, 1e9
+
+ squared_diag0 = l2_dist(hull[0], hull[2])
+ squared_diag1 = l2_dist(hull[1], hull[3])
+ if squared_diag0 < squared_diag1: # swap so that diag0 is larger
+ squared_diag0, squared_diag1 = squared_diag1, squared_diag0
+
+ if squared_diag0 / squared_diag1 > 1.5**2:
+ return hull, 1e8
+
+ cyclic_pts = hull + [hull[0]]
+ side_lens = [l2_dist(cyclic_pts[i], cyclic_pts[i + 1]) for i in range(4)]
+
+ return hull, (max(side_lens) - min(side_lens)) / min(side_lens)
+
+
+# Gift wrapping code, adapted from GeeksForGeeks.
+# "This code is contributed by Akarsh Somani, IIIT Kalyani"
+class Point:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+
+def left_index(points):
+ """
+ Finding the left most point
+ """
+ minn = 0
+ for i in range(1,len(points)):
+ if points[i].x < points[minn].x:
+ minn = i
+ elif points[i].x == points[minn].x:
+ if points[i].y > points[minn].y:
+ minn = i
+ return minn
+
+
+def orientation(p, q, r):
+ """
+ To find orientation of ordered triplet (p, q, r).
+ The function returns following values
+ 0 --> p, q and r are collinear
+ 1 --> Clockwise
+ 2 --> Counterclockwise
+ """
+ val = (q.y - p.y) * (r.x - q.x) - \
+ (q.x - p.x) * (r.y - q.y)
+
+ if val == 0:
+ return 0
+ elif val > 0:
+ return 1
+ else:
+ return 2
+
+
+def convex_hull(points):
+ n = len(points)
+ assert n >= 3, "There must be at least 3 points."
+
+ # Find the leftmost point
+ l = left_index(points)
+
+ hull = []
+
+ '''
+ Start from leftmost point, keep moving counterclockwise
+ until reach the start point again. This loop runs O(h)
+ times where h is number of points in result or output.
+ '''
+ p = l
+ q = 0
+ while True:
+ # Add current point to result
+ hull.append(points[p])
+
+ '''
+ Search for a point 'q' such that orientation(p, q,
+ x) is counterclockwise for all points 'x'. The idea
+ is to keep track of last visited most counterclock-
+ wise point in q. If any point 'i' is more counterclock-
+ wise than q, then update q.
+ '''
+ q = (p + 1) % n
+
+ for i in range(n):
+ # If i is more counterclockwise
+ # than current q, then update q
+ if(orientation(points[p],
+ points[i], points[q]) == 2):
+ q = i
+
+ '''
+ Now q is the most counterclockwise with respect to p
+ Set p as q for next iteration, so that q is added to
+ result 'hull'
+ '''
+ p = q
+
+ # While we don't come to first point
+ if p == l:
+ break
+
+ return hull
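
Taken together with the decoder.py changes above, the new module is used roughly as follows. This is a trimmed sketch of the -v 1 wiring rather than code from the repository; the capture path and grid size are placeholders, and it assumes the camera frame is larger than the 1024-pixel center crop and that the checkpts/ weights and corner_training package are importable:

    import argparse
    import cv2
    from decoding_utils import localize_corners_wrapper

    # Placeholder arguments; height/width must satisfy the
    # `height * 3 % 80 == width * 3 % 80 == 0` check in decoder.py (80 does).
    args = argparse.Namespace(height=80, width=80)
    input_crop_size = 1024

    localize_corners = localize_corners_wrapper(args, input_crop_size)

    raw_frame = cv2.imread("frame.png")  # hypothetical capture, larger than 1024x1024
    h, w, _ = raw_frame.shape
    cropped = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
                        (w - input_crop_size) // 2:-(w - input_crop_size) // 2]
    cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)

    # Perspective-corrected height x width grid plus the four sampled corner colors
    frame, (wcol, rcol, gcol, bcol) = localize_corners(cropped)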