#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019-2020, 2024 Daniel Estevez <daniel@destevez.net>
#
# This file is part of gr-satellites
#
# SPDX-License-Identifier: GPL-3.0-or-later
#

import argparse
from enum import Enum
import struct
import subprocess
import sys

import numpy as np

from satellites.kiss import *
from satellites.telemetry.by02 import TMPrimaryHeader as TMPrimaryHeaderShort
from satellites.telemetry.erminaz import TMPrimaryHeader


class Satellite(Enum):
    """Satellites whose SSDV downlink this tool knows how to process.

    Each member's value is the exact text accepted by the ``--satellite``
    command-line option.
    """

    JY1SAT = 'JY1SAT'
    ERMINAZ = 'ERMINAZ'
    DSLWP = 'DSLWP'

    def __str__(self):
        """Return the bare satellite name instead of ``Satellite.NAME``."""
        # Presumably so that the choices listed by argparse read as plain
        # names; the members are passed to it as ``choices``.
        label = self.value
        return label


def parse_args():
    """Parse the command-line arguments of the tool.

    Returns:
        An ``argparse.Namespace`` with the attributes ``satellite`` (a
        ``Satellite`` member), ``input`` (path of the KISS file to read)
        and ``output`` (prefix used to build the output filenames).
    """
    parser = argparse.ArgumentParser()
    # The option is mandatory: every per-satellite lookup table in this
    # tool is keyed by the satellite, so running without it would only
    # fail later with a KeyError on None.
    parser.add_argument(
        '--satellite', type=Satellite, choices=list(Satellite),
        required=True,
        help='Satellite')
    parser.add_argument('input', help='Input KISS file')
    parser.add_argument(
        'output',
        help='Output filenames (start of filenames)')
    return parser.parse_args()


def seqnum(packet, satellite):
    """Return the SSDV packet ID (sequence number) of a packet.

    Args:
        packet: Bytes-like SSDV packet, already trimmed so that it starts
            at the beginning of the satellite's SSDV packet layout.
        satellite: ``Satellite`` member selecting the packet layout.

    Returns:
        The 16-bit big-endian packet ID as an ``int``.

    Raises:
        KeyError: If ``satellite`` is not a known ``Satellite`` member.
        struct.error: If the packet is too short to hold the packet ID.
    """
    # Byte offset of the 16-bit packet ID. In every layout it directly
    # follows the one-byte image ID (at index 2, 6 and 0 respectively).
    # ERMINAZ has its image ID at index 6, so its packet ID occupies
    # bytes 7-8; starting at 8 would mix the low packet ID byte with the
    # next header field and break the ordering of images with 256 or
    # more packets.
    offsets = {
        Satellite.JY1SAT: 3,
        Satellite.ERMINAZ: 7,
        Satellite.DSLWP: 1,
    }
    off = offsets[satellite]
    return struct.unpack('>H', packet[off:off+2])[0]


def read_kiss_file(path, framesize=None):
    """Read the data frames contained in a KISS file.

    A frame is kept only if its first byte (the KISS command byte) has a
    zero low nibble, it carries at least one payload byte and, when
    ``framesize`` is given, its payload is exactly ``framesize`` bytes
    long. The command byte is stripped from the returned frames. Bytes
    after the last FEND are discarded, since they do not form a
    terminated frame.

    Args:
        path: Path of the KISS file.
        framesize: Required payload length in bytes, or ``None`` to accept
            any length.

    Returns:
        A 2-D ``uint8`` array with one row per frame. If no frame is
        accepted, the array has shape ``(0, framesize or 0)`` so that
        callers can still index it along two axes.

    Raises:
        ValueError: If ``framesize`` is ``None`` and the accepted frames
            have different lengths, because NumPy cannot build a
            rectangular array from them.
    """
    frames = []
    frame = []
    transpose = False
    with open(path, 'rb') as f:
        for c in f.read():
            if c == FEND:
                if ((framesize is None or len(frame) == framesize + 1)
                        and len(frame) > 1
                        and (frame[0] & 0x0f) == 0):
                    frames.append(frame[1:])
                frame = []
                # An escape left pending by a malformed frame must not
                # leak into the next frame, where it would swallow the
                # command byte.
                transpose = False
            elif transpose:
                # Byte following FESC: only TFEND and TFESC are valid
                # escape codes; anything else is dropped.
                if c == TFEND:
                    frame.append(FEND)
                elif c == TFESC:
                    frame.append(FESC)
                transpose = False
            elif c == FESC:
                transpose = True
            else:
                frame.append(c)
    if not frames:
        # np.array([]) would be 1-D; keep the result 2-D in every case.
        return np.empty((0, framesize or 0), dtype='uint8')
    return np.array(frames, dtype='uint8')


def main():
    """Extract the SSDV images from a KISS file and decode them.

    The frames of the selected satellite are reduced to their SSDV
    packets, grouped by image ID and sorted by packet sequence number.
    For each image the packets are written to ``<output>_<id>.ssdv`` and
    the external ``ssdv`` program is called to decode them into
    ``<output>_<id>.jpg``. Nothing is written if the input contains no
    SSDV packets.
    """
    args = parse_args()
    satellite = args.satellite

    # Read frames
    framesizes = {Satellite.JY1SAT: 256}
    framesize = framesizes.get(satellite)
    x = read_kiss_file(args.input, framesize)

    if x.size == 0:
        # There are no usable frames. Return before filtering, because
        # an empty array may be 1-D and the 2-D indexing below would
        # raise IndexError on it.
        return

    if satellite == Satellite.JY1SAT:
        # Filter out by frame id and trim to payload
        frame_ok = (((x[:, 0] == 0xe0) | (x[:, 0] == 0xe1))
                    & (x[:, 1] == 0x10))
        x = x[frame_ok, 56:]

        # Filter SSDV packets
        x = x[(x[:, 0] == 0x55) & (x[:, 1] == 0x68), :]
    elif satellite == Satellite.ERMINAZ:
        vcid = 4
        # Select rows with a boolean mask so that the result stays 2-D
        # even when no frame belongs to the virtual channel.
        keep = np.array(
            [TMPrimaryHeader.parse(y).virtual_channel_id == vcid
             for y in x],
            dtype=bool)
        # Drop the TM primary header and the 2 bytes that follow it.
        x = x[keep, TMPrimaryHeader.sizeof() + 2:]
    elif satellite == Satellite.DSLWP:
        vcid = 1
        keep = np.array(
            [TMPrimaryHeaderShort.parse(y).virtual_channel_id == vcid
             for y in x],
            dtype=bool)
        x = x[keep, TMPrimaryHeaderShort.sizeof():]

    if x.size == 0:
        # there are no SSDV packets
        return

    # Index of the image ID byte within the trimmed SSDV packet
    id_idxs = {
        Satellite.JY1SAT: 2,
        Satellite.ERMINAZ: 6,
        Satellite.DSLWP: 0,
    }
    id_idx = id_idxs[satellite]
    # Sorted so that the images are processed in a deterministic order
    ids = sorted(set(x[:, id_idx]))

    # Options that tell the ssdv decoder which packet format to expect
    decoder_args = {
        Satellite.JY1SAT: ['-J'],
        Satellite.ERMINAZ: ['-l', str(x.shape[1])],
        Satellite.DSLWP: ['-D'],
    }[satellite]
    for i in ids:
        packets = list(x[x[:, id_idx] == i, :])
        packets.sort(key=lambda p: seqnum(p, satellite=satellite))
        ssdv = '{}_{}.ssdv'.format(args.output, i)
        jpeg = '{}_{}.jpg'.format(args.output, i)
        np.array(packets).tofile(ssdv)
        print('Calling SSDV decoder for image {}'.format(hex(i)))
        status = subprocess.call(
            ['ssdv', '-d'] + decoder_args + [ssdv, jpeg])
        if status != 0:
            # Report the failure but keep going with the other images
            print('SSDV decoder exited with status {} for image {}'
                  .format(status, hex(i)), file=sys.stderr)
        print()


# Run the tool only when this file is executed as a script, so that
# importing it has no side effects.
if __name__ == '__main__':
    main()
