author    Kevin Zhao 2024-05-07 14:02:26 -0400
committer Kevin Zhao 2024-05-07 14:02:26 -0400
commit    9205312ad56d5d04987797b1c796cf0b57cd1835 (patch)
tree      8134c7b186570a1e74122a10c65bf2454f62060c
parent    3d4862e725d29ab910f4a958717749abe12aad7a (diff)
Preliminary CNN decoding
-rw-r--r--  corner_training/models.py  |  28
-rw-r--r--  decoder_cnn.py             | 151
-rw-r--r--  encoder.py                 |  42
3 files changed, 211 insertions, 10 deletions
diff --git a/corner_training/models.py b/corner_training/models.py
index dff0ac2..26d67cc 100644
--- a/corner_training/models.py
+++ b/corner_training/models.py
@@ -151,3 +151,31 @@ class QuantizedV3(nn.Module):
["block1.3", "block1.4", "block1.5"],
["block2.0", "block2.1", "block2.2"],
], inplace=True)
+
+
+class QuantizedV5(nn.Module):
+ """normal convs, biasless"""
+ def __init__(self, in_channels=3):
+ super().__init__()
+
+ self.layers = nn.Sequential(
+ torch.quantization.QuantStub(),
+ # nn.BatchNorm2d(in_channels),
+ nn.Conv2d(in_channels, 32, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.Conv2d(32, 8, kernel_size=3, padding=1),
+ torch.quantization.DeQuantStub(),
+ )
+
+ def forward(self, imgs):
+ x = self.layers(imgs)
+ x = torch.max(x, dim=-3)[0] # max over channel dim
+ return torch.sigmoid(x)
+
+ def fuse_modules(self, is_qat=False):
+ fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
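+        # i = 1 and i = 4: fuse each (Conv2d, BatchNorm2d, ReLU) triplet in the
+        # Sequential into a single module ahead of quantization.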
+ fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)
diff --git a/decoder_cnn.py b/decoder_cnn.py
new file mode 100644
index 0000000..ae4df07
--- /dev/null
+++ b/decoder_cnn.py
@@ -0,0 +1,151 @@
+import argparse
+import traceback
+import cv2
+import numpy as np
+import torch
+from creedsolo import RSCodec
+from matplotlib import pyplot as plt
+from raptorq import Decoder
+
+from corner_training.models import QuantizedV2, QuantizedV5
+from decoding_utils import localize_corners_wrapper
+
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
+parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
+parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
+parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
+parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
+parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
+parser.add_argument("-p", "--psize", help="packet size", type=int)
+parser.add_argument("-v", "--version",
+ help="0: 10% corners w/ two-sided one-cell padding; 1: 15% corners w/ four-sided 25% padding.",
+ default=0, choices=[0, 1], type=int)
+args = parser.parse_args()
+
+assert args.version == 1, "CNN decoding is currently only implemented for version 1"
+# cell borders are 3.75% (3/80) of width/height
+assert args.height * 3 % 80 == args.width * 3 % 80 == 0
+cheight = int(args.height * 0.15)
+cwidth = int(args.width * 0.15)
+
+frame_size = args.height * args.width - 4 * cheight * cwidth
+frame_bytes = frame_size * 3 // 8
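+# frame_xor: XOR whitening mask, presumably to keep the payload from producing
+# long constant runs on screen.
+# rs_bytes: usable bytes per frame after int(level * 255) parity bytes for each of
+# the ceil(frame_bytes / 255) RS blocks, plus 4 reserved bytes (presumably a
+# per-frame header added by the encoder).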
+frame_xor = np.arange(frame_bytes, dtype=np.uint8)
+rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4
+
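+# Reed-Solomon corrects byte errors within a frame; the RaptorQ fountain decoder
+# then reassembles the file from whichever packets survive (with_defaults takes
+# the transfer length and the packet size).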
+rsc = RSCodec(int(args.level * 255))
+decoder = Decoder.with_defaults(args.size, rs_bytes)
+
+
+stage1_model_checkpt_path = "/Users/kevinzhao/Downloads/QuantizedV0_Stage1_128_9.pt"
+
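+# Rebuild each network exactly as it was exported: fuse modules, attach a qconfig,
+# insert observers (prepare), swap in quantized modules (convert), and only then
+# load the int8 state dict.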
+stage1_model = QuantizedV2()
+stage1_model.eval()
+stage1_model.fuse_modules(is_qat=False)
+stage1_model.qconfig = torch.ao.quantization.default_qconfig
+torch.ao.quantization.prepare(stage1_model, inplace=True)
+torch.ao.quantization.convert(stage1_model, inplace=True)
+stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))
+
+stage2_model = QuantizedV5()
+stage2_model.eval()
+stage2_model.fuse_modules(is_qat=False)
+stage2_model.qconfig = torch.ao.quantization.default_qconfig
+torch.ao.quantization.prepare(stage2_model, inplace=True)
+torch.ao.quantization.convert(stage2_model, inplace=True)
+stage2_model.load_state_dict(torch.load("/Users/kevinzhao/Downloads/QuantizedV5_Stage2_128_9.pt", map_location=torch.device('cpu')))
+
+# stage1_size = 128
+# stage2_size = 128
+
+stage1_size = 128
+stage2_size = 64
+
+input_crop_size = 1024
+
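+# Two-stage corner localization: stage 1 finds the corner regions on a downscaled
+# stage1_size crop, and stage 2 presumably refines each corner at stage2_size.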
+localize_corners = localize_corners_wrapper(stage1_model, stage2_model, stage1_size, stage2_size)
+
+if args.input.isdecimal():
+ args.input = int(args.input)
+cap = cv2.VideoCapture(args.input)
+data = None
+while data is None:
+ try:
+ ret, raw_frame = cap.read()
+ if not ret:
+ print("End of stream")
+ break
+ # # raw_frame is a uint8 BE CAREFUL
+ # if type(args.input) == int:
+ # # Crop image to reduce camera distortion
+ # X, Y = raw_frame.shape[:2]
+ # raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
+ # cv2.imshow("", raw_frame)
+ # cv2.waitKey(1)
+ # raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
+
+ h, w, _ = raw_frame.shape
+        # center-crop to input_crop_size x input_crop_size
+        cropped_frame = raw_frame[(h - input_crop_size) // 2:(h + input_crop_size) // 2,
+                                  (w - input_crop_size) // 2:(w + input_crop_size) // 2]
+ cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
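+        # localize_corners returns the pixel coordinates and measured RGB colors of
+        # the white/red/green/blue corner markers, in that order.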
+ (widx, ridx, gidx, bidx), (wcol, rcol, gcol, bcol) = localize_corners(cropped_frame)
+
+ widx = widx[::-1]
+ ridx = ridx[::-1]
+ gidx = gidx[::-1]
+ bidx = bidx[::-1]
+ # plt.imshow(cropped_frame)
+ # plt.scatter([widx[1]], [widx[0]], color="r")
+ # plt.scatter([ridx[1]], [ridx[0]], color="g")
+ # plt.scatter([gidx[1]], [gidx[0]], color="b")
+ # plt.scatter([bidx[1]], [bidx[0]], color="w")
+ # plt.show()
+
+ # Find basis of color space
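+        # Ideally w - origin = (r - origin) + (g - origin) + (b - origin), hence
+        # origin = (r + g + b - w) / 2; F maps measured colors (minus origin) back
+        # onto the {0, 255}^3 color cube.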
+ origin = (rcol + gcol + bcol - wcol) / 2
+ rcol -= origin
+ gcol -= origin
+ bcol -= origin
+ F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
+
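+        # cch/ccw: grid-space coordinates assumed for the detected corner keypoints
+        # (version 1 insets the colored squares a quarter of the corner from each edge).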
+ # cch = cheight / 2 - 1
+ # ccw = cwidth / 2 - 1
+ cch = cheight / 4 - 1
+ ccw = cwidth / 4 - 1
+
+ M = cv2.getPerspectiveTransform(
+ np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
+ np.float32(
+ [
+ [ccw, cch],
+ [args.width - ccw - 1, cch],
+ [ccw, args.height - cch - 1],
+ [args.width - ccw - 1, args.height - cch - 1],
+ ]
+ ),
+ )
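+        # Warp the photo so each data cell lands on exactly one pixel.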
+ frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
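+        # Project every pixel into the calibrated basis and threshold per channel to
+        # recover 3 bits per cell (192 is stricter than the 128 midpoint).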
+ # # Convert to new color space
+ # frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
+ frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
+ # import matplotlib.pyplot as plt
+ # # plt.imshow(frame * 255)
+ # plt.imshow((1 - frame) * 255)
+ # plt.show()
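+        # Flatten the payload cells in encoder order: top strip between the corners,
+        # the full-width middle rows, then the bottom strip between the corners.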
+ frame = np.concatenate(
+ (
+ frame[:cheight, cwidth : args.width - cwidth].flatten(),
+ frame[cheight : args.height - cheight].flatten(),
+ frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
+ )
+ )
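+        # Pack bits into bytes, undo the XOR whitening, Reed-Solomon-correct the
+        # blocks, and feed the packet to the RaptorQ decoder, which returns None
+        # until enough packets have arrived.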
+ data = decoder.decode(bytes(rsc.decode(bytearray(np.packbits(frame) ^ frame_xor))[0][: args.psize]))
+ print("Decoded frame")
+ except KeyboardInterrupt:
+ break
+    except Exception:  # keep scanning frames even if one fails to decode
+        traceback.print_exc()
+if data is not None:
+    with open(args.output, "wb") as f:
+        f.write(data)
+cap.release()
diff --git a/encoder.py b/encoder.py
index 634a166..fd7958b 100644
--- a/encoder.py
+++ b/encoder.py
@@ -12,10 +12,21 @@ parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
parser.add_argument("-f", "--fps", help="frame rate", default=30, type=int)
parser.add_argument("-m", "--mix", help="mix frames with original video", action="store_true")
+parser.add_argument("-v", "--version",
+ help="0: 10% corners w/ two-sided one-cell padding; 1: 15% corners w/ four-sided 25% padding.",
+ default=0, choices=[0, 1], type=int)
args = parser.parse_args()
-cheight = cwidth = max(args.height // 10, args.width // 10)
+if args.version == 0:
+    cheight = cwidth = max(args.height // 10, args.width // 10)
+elif args.version == 1:
+    # cell borders are 3.75% (3/80) of width/height
+    assert args.height * 3 % 80 == args.width * 3 % 80 == 0  # TODO: relax to a less strict ratio
+    cheight = int(args.height * 0.15)
+    cwidth = int(args.width * 0.15)
+else:
+    raise NotImplementedError
midwidth = args.width - 2 * cwidth
frame_size = args.height * args.width - 4 * cheight * cwidth
# Divide by 8 / 3 for 3-bit color
@@ -32,15 +43,26 @@ encoder = Encoder.with_defaults(data, rs_bytes)
packets = encoder.get_encoded_packets(int(len(data) / rs_bytes * (1 / (1 - args.level) - 1)))
# Make corners
-ones = np.ones((cheight - 1, cwidth - 1))
-zeros = np.zeros((cheight - 1, cwidth - 1))
-wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
-rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
-gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
-bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+if args.version == 0:
+ ones = np.ones((cheight - 1, cwidth - 1))
+ zeros = np.zeros((cheight - 1, cwidth - 1))
+ wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
+ rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
+ gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
+ bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+elif args.version == 1:
+ zeros = np.zeros((cheight, cwidth, 3))
+ wcorner = zeros.copy()
+ rcorner = zeros.copy()
+ gcorner = zeros.copy()
+ bcorner = zeros.copy()
+ black_border_h, black_border_w = cheight // 4, cwidth // 4
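+    # Each corner is a black cheight x cwidth patch with a centered half-size
+    # square: the white corner sets all three channels, each color corner one.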
+ for corner_arr, ones_channel_ind in [(wcorner, 0), (wcorner, 1), (wcorner, 2),
+ (rcorner, 0), (gcorner, 1), (bcorner, 2)]:
+ corner_arr[black_border_h:-black_border_h, black_border_w:-black_border_w, ones_channel_ind] = np.ones((cheight // 2, cwidth // 2))
# Output flags for decoder
-print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])}", end="")
+print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])} -v {args.version}", end="")
def mkframe(packet):
@@ -55,11 +77,11 @@ def mkframe(packet):
(wcorner, frame[: cheight * midwidth].reshape((cheight, midwidth, 3)), rcorner),
axis=1,
),
- frame[cheight * midwidth : frame_size - cheight * midwidth].reshape(
+ frame[cheight * midwidth: frame_size - cheight * midwidth].reshape(
(args.height - 2 * cheight, args.width, 3)
),
np.concatenate(
- (gcorner, frame[frame_size - cheight * midwidth :].reshape((cheight, midwidth, 3)), bcorner),
+ (gcorner, frame[frame_size - cheight * midwidth:].reshape((cheight, midwidth, 3)), bcorner),
axis=1,
),
)