# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

""" Helper functions for working with images

This module provides helper functions for converting between numpy arrays
and Qt QImages in a standardized way.
"""

# Imports of numpy are deferred so we can keep it an optional dependency.

from pyface.qt import qt_api
from pyface.qt.QtGui import QImage


def QImage_to_array(qimage):
    """ Convert a QImage to a numpy array.

    This copies the data returned from Qt.

    Parameters
    ----------
    qimage : QImage
        The QImage that we want to extract the values from.  The format must
        be either RGB32 or ARGB32.

    Returns
    -------
    array : ndarray
        An N x M x 4 array of unsigned 8-bit ints as RGBA values, where N is
        the image height and M is the image width.

    Raises
    ------
    ValueError
        If the QImage format is not RGB32 or ARGB32 (a null QImage reports
        Format_Invalid and is rejected here as well).
    """
    import numpy as np

    # Validate the format before touching the pixel buffer: for other formats
    # the buffer layout does not match the reshape below, and a null image
    # has no buffer at all.
    if qimage.format() not in {QImage.Format_RGB32, QImage.Format_ARGB32}:
        raise ValueError(
            "Unsupported QImage format {}".format(qimage.format())
        )

    width, height = qimage.width(), qimage.height()
    # Both supported formats use 4 bytes per pixel, so scanlines are already
    # 32-bit aligned and contain no padding.
    channels = 4
    data = qimage.bits()
    if qt_api in {'pyqt', 'pyqt5'}:
        # PyQt returns a sip.voidptr, which needs an explicit size to be
        # exposed as a buffer.
        data = data.asarray(width * height * channels)
    array = np.array(data, dtype='uint8')
    array.shape = (height, width, channels)
    # Qt stores these formats as native-endian 0xAARRGGBB words, which comes
    # in as BGRA bytes on little-endian machines; reorder to RGBA.
    return array[:, :, [2, 1, 0, 3]]


def array_to_QImage(array):
    """ Convert a numpy array to a QImage.

    This copies the data before passing it to Qt.

    Parameters
    ----------
    array : array_like
        An N x M x {3, 4} array of unsigned 8-bit ints.  The image
        format is assumed to be RGB or RGBA, based on the shape.

    Returns
    -------
    qimage : QImage
        The QImage created from the data.  The pixel format is
        QImage.Format_RGB32 for RGB input and QImage.Format_ARGB32 for
        RGBA input.

    Raises
    ------
    ValueError
        If the array is not 3-dimensional, or its last dimension is not
        3 or 4.
    """
    import numpy as np

    # No-op for ndarrays; lets nested sequences be passed as well.
    array = np.asarray(array)
    if array.ndim != 3 or array.shape[2] not in (3, 4):
        raise ValueError("Array must be either RGB or RGBA values.")

    height, width, channels = array.shape
    # Qt reads these formats as native-endian 0xAARRGGBB words, i.e. BGRA
    # bytes on little-endian machines, so the colour channels are reversed.
    data = np.empty((height, width, 4), dtype='uint8')
    if channels == 3:
        data[:, :, [2, 1, 0]] = array
        # Fill the unused alpha byte with opaque.
        data[:, :, 3] = 0xff
        image_format = QImage.Format_RGB32
    else:
        data[:, :, [2, 1, 0, 3]] = array
        image_format = QImage.Format_ARGB32

    bytes_per_line = 4 * width
    image = QImage(data.data, width, height, bytes_per_line, image_format)
    # QImage does not take ownership of an externally supplied buffer, so
    # keep the numpy array alive for as long as the image exists.
    image._numpy_data = data
    return image

