#!/usr/bin/python

# pripamtopng
#
# Python Raster Image PAM to PNG

from __future__ import print_function

import struct
import sys

from array import array

import png


def read_pam_header(infile):
    """
    Read (the rest of a) PAM header.
    `infile` should be a binary file positioned immediately after
    the initial 'P7' line (at the beginning of the second line).
    Returns are as for `read_pnm_header`:
    a tuple ("P7", width, height, depth, maxval).

    Comment lines (first non-blank character is '#') and
    lines containing no tokens are skipped.
    When a key appears on more than one line, its values are joined
    with a single space.

    Raises EOFError if the file ends before the ENDHDR line.
    Raises png.Error if any of WIDTH, HEIGHT, DEPTH, MAXVAL is missing
    or is not a positive integer.
    """

    # Unlike PBM, PGM, and PPM, we can read the header a line at a time.
    header = dict()
    while True:
        raw = infile.readline()
        if not raw:
            # readline() gives an empty result only at end of file;
            # a blank line still carries its newline.
            raise EOFError("PAM ended prematurely")
        line = raw.strip()
        if line == b"ENDHDR":
            break
        # Skip blank lines and comments.
        # startswith() is used because indexing bytes yields an int
        # on Python 3, which never compares equal to b"#".
        if not line or line.startswith(b"#"):
            continue
        fields = line.split(None, 1)
        key = fields[0]
        # A key with no value is recorded as empty, so that a required
        # field without a value is reported by the integer check below.
        value = fields[1] if len(fields) > 1 else b""
        if key not in header:
            header[key] = value
        else:
            header[key] += b" " + value

    required = [b"WIDTH", b"HEIGHT", b"DEPTH", b"MAXVAL"]
    required_str = b", ".join(required).decode("ascii")
    result = []
    for token in required:
        if token not in header:
            raise png.Error("PAM file must specify " + required_str)
        try:
            x = int(header[token])
        except ValueError:
            raise png.Error(required_str + " must all be valid integers")
        if x <= 0:
            raise png.Error(required_str + " must all be positive integers")
        result.append(x)

    return ("P7",) + tuple(result)


def read_pnm_header(infile):
    """
    Parse the header of a binary PNM file and
    return the tuple (format, width, height, depth, maxval).
    PAM ('P7') headers are handed off to `read_pam_header`.
    `width` and `height` are measured in pixels.
    `depth` is the channel count: 3 for PPM ('P6'), otherwise 1
    (PAM files carry their own depth in the header).
    On return `infile` is positioned just after the single whitespace
    byte that terminates the header.

    Raises NotImplementedError for any magic number other than
    P5, P6, P7; raises png.Error for a malformed or truncated header.
    """

    # Format references: http://netpbm.sourceforge.net/doc/ppm.html
    # and http://netpbm.sourceforge.net/doc/pam.html

    # Strictly, 'P7' has to be followed by a newline;
    # stripping any trailing whitespace is deliberately more lenient.
    magic = infile.read(3).rstrip()
    if magic not in (b"P5", b"P6", b"P7"):
        raise NotImplementedError("file format %s not supported" % magic)
    if magic == b"P7":
        # PAM uses an entirely different, line-oriented header.
        return read_pam_header(infile)

    # Bitmap formats have no maxval, so one fewer header field.
    bitmap_magics = (b"P1", b"P4")
    wanted = 3 if magic in bitmap_magics else 4
    fields = [magic]

    def read_byte():
        """Fetch exactly one byte; running out of input is an error."""
        byte = infile.read(1)
        if not byte:
            raise png.Error("premature EOF reading PNM header")
        return byte

    # The header is consumed one byte at a time (never with readline),
    # because the whitespace byte that ends it need not be a newline and
    # pixel data follows immediately.
    # `lookahead` is always the next byte that has not been interpreted.
    lookahead = read_byte()
    while True:
        if lookahead == b"#":
            # Comment: discard up to and including the end of line.
            while lookahead not in b"\n\r":
                lookahead = read_byte()
            lookahead = read_byte()
        elif lookahead.isspace():
            # Whitespace in front of a token.
            lookahead = read_byte()
        elif lookahead.isdigit():
            # Every token is a decimal integer.
            # A comment embedded in the middle of a number (which the
            # specification permits) is not supported: the '#' simply
            # terminates the number.
            digits = bytearray(lookahead)
            lookahead = read_byte()
            while lookahead.isdigit():
                digits.extend(lookahead)
                lookahead = read_byte()
            fields.append(int(digits))
            if len(fields) == wanted:
                break
        else:
            raise png.Error("unexpected byte %r found in header" % lookahead)

    # The byte that ended the last number must be whitespace.
    if not lookahead.isspace():
        raise png.Error(
            "expected header to end with whitespace, not %r" % lookahead
        )

    if magic in bitmap_magics:
        # Bitmaps get a synthetic maxval of 1.
        fields.append(1)
    depth = 3 if magic == b"P6" else 1
    return fields[0], fields[1], fields[2], depth, fields[3]


def convert_pnm(w, infile, outfile):
    """
    Write the raw pixel data that follows a PNM header as a PNG.
    `w` is the writer object; its `width`, `height`, `planes` and
    `bitdepth` attributes determine how `infile` is read,
    and its `write` method produces the PNG on `outfile`.
    Suitable for the binary PGM, PPM, and PAM formats.
    """

    geometry = (w.width, w.height, w.planes, w.bitdepth)
    pixel_rows = scan_rows_from_file(infile, *geometry)
    w.write(outfile, pixel_rows)


def scan_rows_from_file(infile, width, height, planes, bitdepth):
    """
    Generate a sequence of rows from the input file `infile`.
    The input file should be in a "Netpbm-like" binary format.
    The input file should be positioned at the beginning of the first pixel.
    The number of pixels to read is taken from
    the image dimensions (`width`, `height`, `planes`);
    the number of bytes per value is implied by `bitdepth`:
    one byte when `bitdepth` is 8 or less,
    two bytes (big-endian) when it is 16.
    Each row is yielded as a single sequence of values
    (an `array` of typecode "B" or "H" respectively).

    Raises png.Error if the file ends before all rows have been read.
    Because this is a generator, the error surfaces when
    the offending row is requested, not when the function is called.
    """

    # Values per row
    vpr = width * planes
    # Bytes per row
    bpr = vpr
    if bitdepth > 8:
        assert bitdepth == 16
        bpr *= 2
        # Compile the row format once rather than once per row.
        unpack_row = struct.Struct(">%dH" % vpr).unpack

        def decode(data):
            return array("H", unpack_row(data))

    else:

        def decode(data):
            return array("B", data)

    for y in range(height):
        data = infile.read(bpr)
        if len(data) != bpr:
            # Report truncation with the same exception type as the
            # header parser, instead of letting struct.error escape
            # (16-bit) or yielding a short row (8-bit).
            raise png.Error(
                "premature EOF reading PNM pixel data: "
                "row %d has %d of %d bytes" % (y, len(data), bpr)
            )
        yield decode(data)


def parse_args(args):
    """
    Parse the command line arguments `args`
    (the argument list without the program name) and
    return the resulting namespace, which has
    a `compression` attribute (an int, or None when not given) and
    a `file` attribute (the input file argument).
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--version",
        action="version",
        version="%(prog)s " + png.__version__,
    )
    cli.add_argument(
        "-c",
        "--compression",
        metavar="level",
        type=int,
        help="zlib compression level (0-9)",
    )
    cli.add_argument("file", help="input PAM/PNM file to convert")
    return cli.parse_args(args)


def main(argv=None):
    """
    Command line entry point: convert the PNM/PAM file named on
    the command line to a PNG written to the binary standard output
    (as obtained from `png.binary_stdout()`).
    `argv` is the full argument vector including the program name;
    it defaults to `sys.argv`.

    Raises NotImplementedError when the input has more than 4 channels
    or a maxval that is not of the form 2**n - 1 for n in 1..16.
    Errors from header parsing (png.Error, EOFError,
    NotImplementedError) propagate to the caller.
    """
    if argv is None:
        argv = sys.argv

    args = parse_args(argv[1:])

    # Prepare input and output files
    infile = png.cli_open(args.file)

    # Call after parsing, so that --version and --help work.
    outfile = png.binary_stdout()

    # Encode PNM to PNG
    # (the format element of the header tuple is not needed here).
    _, width, height, depth, maxval = read_pnm_header(infile)

    # The NetPBM depth (number of channels) completely
    # determines the PNG format.
    # Observe:
    # - L, LA, RGB, RGBA are the 4 modes supported by PNG;
    # - they correspond to 1, 2, 3, 4 channels respectively.
    # We use the number of channels in the source image to
    # determine which one we have.
    # We ignore the NetPBM image type and the PAM TUPLTYPE.
    if depth > 4:
        # A PAM DEPTH above 4 has no PNG equivalent. Without this check
        # it would be treated as a 3-channel image and the pixel data
        # would be read with the wrong row stride.
        raise NotImplementedError(
            "input depth (%s) not supported; PNG allows 1 to 4 channels" % depth
        )
    greyscale = depth <= 2
    pamalpha = depth in (2, 4)
    supported = [2 ** x - 1 for x in range(1, 17)]
    try:
        mi = supported.index(maxval)
    except ValueError:
        raise NotImplementedError(
            "input maxval (%s) not in supported list %s" % (maxval, str(supported))
        )
    # supported[i] == 2**(i+1) - 1, so the index gives the bit depth.
    bitdepth = mi + 1
    writer = png.Writer(
        width,
        height,
        greyscale=greyscale,
        bitdepth=bitdepth,
        alpha=pamalpha,
        compression=args.compression,
    )
    convert_pnm(writer, infile, outfile)


if __name__ == "__main__":
    # Run the converter; a png.Error is reported on stderr and
    # turned into exit status 99, any other outcome exits with
    # whatever main() returned.
    try:
        status = main()
    except png.Error as err:
        print(err, file=sys.stderr)
        status = 99
    sys.exit(status)
