author    | Kevin Zhao | 2024-05-07 14:02:26 -0400
committer | Kevin Zhao | 2024-05-07 14:02:26 -0400
commit    | 9205312ad56d5d04987797b1c796cf0b57cd1835 (patch)
tree      | 8134c7b186570a1e74122a10c65bf2454f62060c
parent    | 3d4862e725d29ab910f4a958717749abe12aad7a (diff)
Preliminary CNN decoding
-rw-r--r-- | corner_training/models.py | 28
-rw-r--r-- | decoder_cnn.py            | 151
-rw-r--r-- | encoder.py                | 42
3 files changed, 211 insertions, 10 deletions
diff --git a/corner_training/models.py b/corner_training/models.py
index dff0ac2..26d67cc 100644
--- a/corner_training/models.py
+++ b/corner_training/models.py
@@ -151,3 +151,31 @@ class QuantizedV3(nn.Module):
             ["block1.3", "block1.4", "block1.5"],
             ["block2.0", "block2.1", "block2.2"],
         ], inplace=True)
+
+
+class QuantizedV5(nn.Module):
+    """normal convs, biasless"""
+    def __init__(self, in_channels=3):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            torch.quantization.QuantStub(),
+            # nn.BatchNorm2d(in_channels),
+            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2, bias=False),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 8, kernel_size=3, padding=1),
+            torch.quantization.DeQuantStub(),
+        )
+
+    def forward(self, imgs):
+        x = self.layers(imgs)
+        x = torch.max(x, dim=-3)[0]  # max over channel dim
+        return torch.sigmoid(x)
+
+    def fuse_modules(self, is_qat=False):
+        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
+        fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)
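For context on the new stage-2 model: QuantizedV5 preserves spatial resolution through both fused conv-BN-ReLU blocks (same-padding 3x3 convs, one dilated), takes a max over the 8 output channels, and applies a sigmoid, so each image yields a single heatmap in [0, 1]. A minimal float-mode smoke test, as a sketch: no fuse/prepare/convert quantization steps, and the 128-pixel input size is an arbitrary choice for illustration.

import torch

from corner_training.models import QuantizedV5

# Eager float mode: QuantStub/DeQuantStub act as identity until the model
# is prepared and converted for quantization.
model = QuantizedV5()
model.eval()

imgs = torch.rand(1, 3, 128, 128)  # one RGB crop; size is illustrative
with torch.no_grad():
    heatmap = model(imgs)

# Same-padding convs keep H and W; the max over dim -3 removes the channel dim.
print(heatmap.shape)                                            # torch.Size([1, 128, 128])
print(heatmap.min().item() >= 0 and heatmap.max().item() <= 1)  # True (sigmoid output)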
diff --git a/decoder_cnn.py b/decoder_cnn.py
new file mode 100644
index 0000000..ae4df07
--- /dev/null
+++ b/decoder_cnn.py
@@ -0,0 +1,151 @@
+import argparse
+import traceback
+import cv2
+import numpy as np
+import torch
+from creedsolo import RSCodec
+from matplotlib import pyplot as plt
+from raptorq import Decoder
+
+from corner_training.models import QuantizedV2, QuantizedV5
+from decoding_utils import localize_corners_wrapper
+
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("-i", "--input", help="camera device index or input video file", default="0")
+parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
+parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
+parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
+parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
+parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
+parser.add_argument("-p", "--psize", help="packet size", type=int)
+parser.add_argument("-v", "--version",
+                    help="0: 10%% corners w/ two-sided one-cell padding; 1: 15%% corners w/ four-sided 25%% padding.",
+                    default=0, choices=[0, 1], type=int)
+args = parser.parse_args()
+
+assert args.version == 1
+# cell borders are 3.75% of width/height
+assert args.height * 3 % 80 == args.width * 3 % 80 == 0
+cheight = int(args.height * 0.15)
+cwidth = int(args.width * 0.15)
+
+frame_size = args.height * args.width - 4 * cheight * cwidth
+frame_bytes = frame_size * 3 // 8
+frame_xor = np.arange(frame_bytes, dtype=np.uint8)
+rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4
+
+rsc = RSCodec(int(args.level * 255))
+decoder = Decoder.with_defaults(args.size, rs_bytes)
+
+
+stage1_model_checkpt_path = "/Users/kevinzhao/Downloads/QuantizedV0_Stage1_128_9.pt"
+
+stage1_model = QuantizedV2()
+stage1_model.eval()
+stage1_model.fuse_modules(is_qat=False)
+stage1_model.qconfig = torch.ao.quantization.default_qconfig
+torch.ao.quantization.prepare(stage1_model, inplace=True)
+torch.ao.quantization.convert(stage1_model, inplace=True)
+stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))
+
+stage2_model = QuantizedV5()
+stage2_model.eval()
+stage2_model.fuse_modules(is_qat=False)
+stage2_model.qconfig = torch.ao.quantization.default_qconfig
+torch.ao.quantization.prepare(stage2_model, inplace=True)
+torch.ao.quantization.convert(stage2_model, inplace=True)
+stage2_model.load_state_dict(torch.load("/Users/kevinzhao/Downloads/QuantizedV5_Stage2_128_9.pt", map_location=torch.device('cpu')))
+
+# stage1_size = 128
+# stage2_size = 128
+
+stage1_size = 128
+stage2_size = 64
+
+input_crop_size = 1024
+
+localize_corners = localize_corners_wrapper(stage1_model, stage2_model, stage1_size, stage2_size)
+
+if args.input.isdecimal():
+    args.input = int(args.input)
+cap = cv2.VideoCapture(args.input)
+data = None
+while data is None:
+    try:
+        ret, raw_frame = cap.read()
+        if not ret:
+            print("End of stream")
+            break
+        # # raw_frame is a uint8 BE CAREFUL
+        # if type(args.input) == int:
+        #     # Crop image to reduce camera distortion
+        #     X, Y = raw_frame.shape[:2]
+        #     raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
+        # cv2.imshow("", raw_frame)
+        # cv2.waitKey(1)
+        # raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
+
+        h, w, _ = raw_frame.shape
+        cropped_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
+                                  (w - input_crop_size) // 2:-(w - input_crop_size) // 2]
+        cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
+        (widx, ridx, gidx, bidx), (wcol, rcol, gcol, bcol) = localize_corners(cropped_frame)
+
+        widx = widx[::-1]
+        ridx = ridx[::-1]
+        gidx = gidx[::-1]
+        bidx = bidx[::-1]
+        # plt.imshow(cropped_frame)
+        # plt.scatter([widx[1]], [widx[0]], color="r")
+        # plt.scatter([ridx[1]], [ridx[0]], color="g")
+        # plt.scatter([gidx[1]], [gidx[0]], color="b")
+        # plt.scatter([bidx[1]], [bidx[0]], color="w")
+        # plt.show()
+
+        # Find basis of color space
+        origin = (rcol + gcol + bcol - wcol) / 2
+        rcol -= origin
+        gcol -= origin
+        bcol -= origin
+        F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
+
+        # cch = cheight / 2 - 1
+        # ccw = cwidth / 2 - 1
+        cch = cheight / 4 - 1
+        ccw = cwidth / 4 - 1
+
+        M = cv2.getPerspectiveTransform(
+            np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
+            np.float32(
+                [
+                    [ccw, cch],
+                    [args.width - ccw - 1, cch],
+                    [ccw, args.height - cch - 1],
+                    [args.width - ccw - 1, args.height - cch - 1],
+                ]
+            ),
+        )
+        frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
+        # # Convert to new color space
+        # frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
+        frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
+        # import matplotlib.pyplot as plt
+        # # plt.imshow(frame * 255)
+        # plt.imshow((1 - frame) * 255)
+        # plt.show()
+        frame = np.concatenate(
+            (
+                frame[:cheight, cwidth : args.width - cwidth].flatten(),
+                frame[cheight : args.height - cheight].flatten(),
+                frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
+            )
+        )
+        data = decoder.decode(bytes(rsc.decode(bytearray(np.packbits(frame) ^ frame_xor))[0][: args.psize]))
+        print("Decoded frame")
+    except KeyboardInterrupt:
+        break
+    except:
+        traceback.print_exc()
+with open(args.output, "wb") as f:
+    f.write(data)
+cap.release()
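The "Find basis of color space" step in decoder_cnn.py assumes each corner's observed color is the ambient offset plus its own channel vector, while white is the offset plus all three; summing the three primaries then gives 3*origin + (wcol - origin), which rearranges to the origin formula above. A standalone sketch of the same recovery and 192-threshold, with made-up corner colors (all numbers hypothetical):

import numpy as np

# Hypothetical mean RGB colors sampled at the four localized corners.
wcol = np.array([231.0, 226.0, 235.0])
rcol = np.array([190.0, 60.0, 70.0])
gcol = np.array([55.0, 200.0, 65.0])
bcol = np.array([50.0, 65.0, 205.0])

# white = origin + r + g + b and each primary = origin + its channel vector,
# so rcol + gcol + bcol - wcol = 2 * origin.
origin = (rcol + gcol + bcol - wcol) / 2

# Columns of the stacked matrix are the recovered channel vectors; F maps an
# observed pixel (minus origin) back to ideal 0..255 channel intensities.
F = 255 * np.linalg.inv(np.stack((rcol - origin, gcol - origin, bcol - origin)).T)

pixel = np.array([185.0, 58.0, 72.0])           # a pixel close to the red corner
print(np.squeeze(F @ (pixel - origin)) >= 192)  # [ True False False] -> red bit set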
diff --git a/encoder.py b/encoder.py
--- a/encoder.py
+++ b/encoder.py
@@ -12,10 +12,21 @@ parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
 parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
 parser.add_argument("-f", "--fps", help="frame rate", default=30, type=int)
 parser.add_argument("-m", "--mix", help="mix frames with original video", action="store_true")
+parser.add_argument("-v", "--version",
+                    help="0: 10%% corners w/ two-sided one-cell padding; 1: 15%% corners w/ four-sided 25%% padding.",
+                    default=0, choices=[0, 1], type=int)
 args = parser.parse_args()
 
+if args.version == 0:
+    cheight = cwidth = max(args.height // 10, args.width // 10)
+elif args.version == 1:
+    # cell borders are 3.75% of width/height
+    assert args.height * 3 % 80 == args.width * 3 % 80 == 0  # TODO: less strict better ratio
+    cheight = int(args.height * 0.15)
+    cwidth = int(args.width * 0.15)
+else:
+    raise NotImplementedError
 
-cheight = cwidth = max(args.height // 10, args.width // 10)
 midwidth = args.width - 2 * cwidth
 frame_size = args.height * args.width - 4 * cheight * cwidth
 # Divide by 8 / 3 for 3-bit color
@@ -32,15 +43,26 @@ encoder = Encoder.with_defaults(data, rs_bytes)
 packets = encoder.get_encoded_packets(int(len(data) / rs_bytes * (1 / (1 - args.level) - 1)))
 
 # Make corners
-ones = np.ones((cheight - 1, cwidth - 1))
-zeros = np.zeros((cheight - 1, cwidth - 1))
-wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
-rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
-gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
-bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+if args.version == 0:
+    ones = np.ones((cheight - 1, cwidth - 1))
+    zeros = np.zeros((cheight - 1, cwidth - 1))
+    wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
+    rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
+    gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
+    bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+elif args.version == 1:
+    zeros = np.zeros((cheight, cwidth, 3))
+    wcorner = zeros.copy()
+    rcorner = zeros.copy()
+    gcorner = zeros.copy()
+    bcorner = zeros.copy()
+    black_border_h, black_border_w = cheight // 4, cwidth // 4
+    for corner_arr, ones_channel_ind in [(wcorner, 0), (wcorner, 1), (wcorner, 2),
+                                         (rcorner, 0), (gcorner, 1), (bcorner, 2)]:
+        corner_arr[black_border_h:-black_border_h, black_border_w:-black_border_w, ones_channel_ind] = np.ones((cheight // 2, cwidth // 2))
 
 # Output flags for decoder
-print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])}", end="")
+print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])} -v {args.version}", end="")
@@ -55,11 +77,11 @@ def mkframe(packet):
                 (wcorner, frame[: cheight * midwidth].reshape((cheight, midwidth, 3)), rcorner),
                 axis=1,
             ),
-            frame[cheight * midwidth : frame_size - cheight * midwidth].reshape(
+            frame[cheight * midwidth: frame_size - cheight * midwidth].reshape(
                 (args.height - 2 * cheight, args.width, 3)
            ),
            np.concatenate(
-                (gcorner, frame[frame_size - cheight * midwidth :].reshape((cheight, midwidth, 3)), bcorner),
+                (gcorner, frame[frame_size - cheight * midwidth:].reshape((cheight, midwidth, 3)), bcorner),
                axis=1,
            ),
        )
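The version-1 corners built in encoder.py are cheight x cwidth tiles whose central cheight//2 x cwidth//2 block carries the color while a 25% border on every side stays black; the height * 3 % 80 == 0 assert makes cheight a multiple of 4, so border and center tile the corner exactly. A tiny sketch of one red corner (the 80-cell grid size is hypothetical):

import numpy as np

cheight = cwidth = 12                  # 15% of a hypothetical 80-cell grid
rcorner = np.zeros((cheight, cwidth, 3))
bh, bw = cheight // 4, cwidth // 4     # 3-cell black border on each side
rcorner[bh:-bh, bw:-bw, 0] = np.ones((cheight // 2, cwidth // 2))  # 6x6 red center

# Rows 0-2 and 9-11 (and the matching columns) stay zero: the black border
# that separates the colored center from the surrounding data cells.
print(rcorner[:, :, 0].astype(int))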