#!/usr/bin/python3

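# Read pixel data files created by a method call to
# de.rainerhock.eightbitwonders.C64EmulationTestBase.takeScreenShot
# Command line: `view-pixeldata.py [<size>] <filename>`
# <size> can be a machine type (c64, c64-ntsc, vic, vic-ntsc, 40col-borderless,
# vic-borderless, pet40, pet80, c128, auto) or <width>x<height>; when <size> is
# omitted, it is auto-detected from the file length.
# Invalid arguments print a usage message that also describes the
# difference and show-both forms.
#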
import sys
import os
from PIL import Image, ImageTk


def get_dimensions(param, filename):
    """Return the (width, height) in pixels for a screenshot pixel file.

    :param param: a machine type from the table below, ``"auto"`` to pick the
        machine type whose expected file size equals the size of *filename*,
        or an explicit ``"<width>x<height>"`` string such as ``"320x200"``.
    :param filename: the pixel file; it is only stat'ed, and only for ``"auto"``.
    :returns: ``(width, height)`` as a tuple of ints.
    :raises ValueError: if ``"auto"`` finds no machine type with a matching
        file size, or if *param* is neither a known machine type nor a valid
        ``"<width>x<height>"`` string with positive dimensions.
    :raises OSError: if ``"auto"`` is used and *filename* cannot be stat'ed.
    """
    # machine type -> (width, height, expected file size in bytes)
    # NOTE(review): every entry satisfies size == width * height * 2 except
    # c64-ntsc (758784 == 768 * 494 * 2) and vic (1017856 == 896 * 568 * 2).
    # Confirm whether the dimensions or the byte counts are the intended values.
    sizes = {
        "c64": (768, 544, 835584),
        "c64-ntsc": (768, 524, 758784),
        "vic": (896, 586, 1017856),
        "vic-ntsc": (800, 468, 748800),
        "40col-borderless": (640, 400, 512000),
        "vic-borderless": (704, 368, 518144),
        "pet40": (384, 266, 204288),
        "pet80": (704, 532, 749056),
        "c128": (1712, 1152, 3944448),
    }

    if param == "auto":
        size = os.stat(filename).st_size
        for width, height, expected_size in sizes.values():
            if size == expected_size:
                return width, height
        # Previously (-1, -1) was returned here, which only failed later
        # inside Pillow with an unrelated-looking message.
        raise ValueError(
            "cannot auto-detect dimensions of {}: no machine type has a file "
            "size of {} bytes".format(filename, size))

    if param in sizes:
        width, height, _ = sizes[param]
        return width, height

    width_text, separator, height_text = param.partition("x")
    try:
        width, height = int(width_text), int(height_text)
    except ValueError:
        width = height = -1  # reported by the check below
    if not separator or width <= 0 or height <= 0:
        raise ValueError(
            "invalid size '{}': expected a machine type or "
            "<width>x<height>".format(param))
    return width, height


def showfile(param, filename):
    """Decode one pixel file and open it in the default image viewer.

    The file is read as consecutive little-endian 16-bit RGB565 pixels in
    row-major order. If the file is shorter than ``2 * width * height`` bytes,
    the remaining pixels are black.

    :param param: size specification passed to get_dimensions
        (machine type, ``"auto"`` or ``"<width>x<height>"``).
    :param filename: path of the pixel file.
    :returns: True if the image was shown. False if an error occurred; the
        error message is then written to stderr.
    """
    try:
        w, h = get_dimensions(param, filename)
        nbytes = 2 * w * h
        with open(filename, "rb") as f:
            # Zero padding yields exactly what a short per-pixel read did:
            # b"" -> 0 (black), and a lone trailing byte keeps its value.
            data = f.read(nbytes).ljust(nbytes, b"\0")
        values = (data[i] | (data[i + 1] << 8) for i in range(0, nbytes, 2))
        img = Image.new("RGB", (w, h), 0)
        # Expand RGB565 to 8-bit channels: red bits 15-11, green bits 10-5
        # (6 bits, hence mask 0xFC), blue bits 4-0.
        img.putdata([((p >> 8) & 0xF8, (p >> 3) & 0xFC, (p << 3) & 0xF8)
                     for p in values])
        img.show()
        return True
    except Exception as e:
        # Exception, not BaseException: Ctrl+C / SystemExit must still abort
        # instead of being reported as a failure followed by the usage text.
        sys.stderr.write(str(e) + "\n")
        return False


def _rgb565_to_rgb(pixel):
    """Expand a 16-bit RGB565 value into an (r, g, b) tuple of 8-bit channels.

    Negative inputs (from the difference mode) are masked like any other int.
    Every 16-bit pattern other than 0 maps to a non-black colour.
    """
    return ((pixel >> 8) & 0xF8, (pixel >> 3) & 0xFC, (pixel << 3) & 0xF8)


def _read_rgb565(filename, w, h):
    """Read w*h little-endian 16-bit pixel values from *filename*.

    A file shorter than 2*w*h bytes is zero-padded, so missing pixels are 0.
    """
    nbytes = 2 * w * h
    with open(filename, "rb") as f:
        data = f.read(nbytes).ljust(nbytes, b"\0")
    return [data[i] | (data[i + 1] << 8) for i in range(0, nbytes, 2)]


def _to_image(values, w, h):
    """Build a w x h RGB image from row-major 16-bit pixel values."""
    img = Image.new("RGB", (w, h), 0)
    img.putdata([_rgb565_to_rgb(v) for v in values])
    return img


def _show_diff(param, filename1, filename2):
    """Show the per-pixel difference filename1 - filename2.

    Both files are decoded with the dimensions derived from *filename1*.
    Identical pixels come out black.
    """
    w, h = get_dimensions(param, filename1)
    first = _read_rgb565(filename1, w, h)
    second = _read_rgb565(filename2, w, h)
    _to_image([a - b for a, b in zip(first, second)], w, h).show()


def _show_both(param, filename1, filename2):
    """Show two pixel files side by side, separated by a 2-pixel black gap."""
    w0, h0 = get_dimensions(param, filename1)
    w1, h1 = get_dimensions(param, filename2)
    img = Image.new("RGB", (2 + w0 + w1, max(h0, h1)), 0)
    img.paste(_to_image(_read_rgb565(filename1, w0, h0), w0, h0), (0, 0))
    img.paste(_to_image(_read_rgb565(filename2, w1, h1), w1, h1), (w0 + 2, 0))
    img.show()


def _print_usage():
    """Write the command-line synopsis to stderr."""
    prog = sys.argv[0]
    sys.stderr.write("usage:\n")
    sys.stderr.write(prog
                     + " <filename> # (show image, size auto-detected from file length)\n")
    sys.stderr.write(prog
                     + " <width>x<height> <filename> # (show image with w x h pixels in size)\n")
    sys.stderr.write(prog
                     + " <machinetype> <filename> # (show image with size from machinetype)\n")
    sys.stderr.write("    machinetypes are c64, c64-ntsc, vic, vic-ntsc, 40col-borderless,\n"
                     "    vic-borderless, pet40, pet80, c128, auto\n")
    # Implicit concatenation instead of a backslash inside the literal, which
    # used to embed the continuation line's indentation into the message.
    sys.stderr.write(prog
                     + " <<width>x<height> or <machinetype>> <filename1> <filename2>"
                     " # (show different pixels)\n")
    sys.stderr.write(prog
                     + " <<width>x<height> or <machinetype>> show-both <filename1> <filename2>"
                     " # (show both pictures)\n")


def main():
    """Dispatch on the command line and show the requested image(s).

    Accepted forms:
      <filename>                                auto-detect size from file length
      <size> <filename>                         show a single image
      <size> <filename1> <filename2>            show the per-pixel difference
      <size> show-both <filename1> <filename2>  show both images side by side

    :returns: 0 when an image was shown. Otherwise 1, after any error message
        and the usage text have been written to stderr.
    """
    argc = len(sys.argv)
    if argc == 2 and showfile("auto", sys.argv[1]):
        return 0
    if argc == 3 and showfile(sys.argv[1], sys.argv[2]):
        return 0
    if argc in (4, 5):
        try:
            if argc == 4:
                _show_diff(sys.argv[1], sys.argv[2], sys.argv[3])
                return 0
            if sys.argv[2] == "show-both":
                _show_both(sys.argv[1], sys.argv[3], sys.argv[4])
                return 0
        except Exception as e:
            # Report and fall through to the usage text, as the single-image
            # mode does; Ctrl+C is deliberately not caught.
            sys.stderr.write(str(e) + "\n")
    _print_usage()
    return 1


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (a return value of None exits with status 0).
    sys.exit(main())
