from krita import Extension, Krita
from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QTableWidget,
                             QTableWidgetItem, QPushButton, QCheckBox, QLabel,
                             QComboBox, QWidget, QSpinBox)
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import Qt

# Identifier string for this extension.
EXTENSION_ID: str = "pykrita_convolution_lab"
# Menu location string; same value as the menu path given to window.createAction.
MENU_PATH: str = "tools/scripts"


def active_doc_node_roi():
    """Return the active document, its active node and the region to process.

    The region is the bounding box of the document's selection when a
    non-empty selection exists, otherwise the whole image. In both cases it is
    clipped to the image bounds.

    Returns:
        ``(doc, node, (x, y, w, h))``. When there is no active document, no
        active node, or the clipped region is empty, returns
        ``(None, None, (0, 0, 0, 0))`` so callers can bail out with a single
        ``doc is None`` check.
    """
    nothing = (None, None, (0, 0, 0, 0))
    doc = Krita.instance().activeDocument()
    if doc is None:
        return nothing
    node = doc.activeNode()
    if node is None:
        return nothing

    img_w, img_h = doc.width(), doc.height()
    x, y, w, h = 0, 0, img_w, img_h
    sel = doc.selection()
    # The selection's bounding box is exposed through x()/y()/width()/height();
    # an empty selection has zero size and falls back to the whole image.
    if sel is not None and sel.width() > 0 and sel.height() > 0:
        x, y, w, h = sel.x(), sel.y(), sel.width(), sel.height()

    # Clip to the canvas so the pixel read/write never addresses pixels
    # outside the image.
    x0, y0 = max(0, x), max(0, y)
    x1, y1 = min(img_w, x + w), min(img_h, y + h)
    if x1 <= x0 or y1 <= y0:
        return nothing
    return doc, node, (x0, y0, x1 - x0, y1 - y0)


def clamp(i, lo, hi):
    """Limit *i* to the closed range [lo, hi].

    Values below ``lo`` become ``lo``, values above ``hi`` become ``hi``;
    anything else is returned unchanged.
    """
    if i < lo:
        return lo
    if i > hi:
        return hi
    return i


def mirror(i, lo, hi):
    """Reflect index *i* into the half-open range [lo, hi).

    Out-of-range indices are reflected across the nearest edge with the edge
    element repeated: ``lo-1 -> lo``, ``lo-2 -> lo+1``, ``hi -> hi-1``,
    ``hi+1 -> hi-2``, and so on. The pattern repeats with period
    ``2 * (hi - lo)``, so offsets larger than the range itself still land
    inside it. This covers cases such as a wide kernel on a tiny image.

    Raises:
        ValueError: if the range is empty (``hi <= lo``).
    """
    n = hi - lo
    if n <= 0:
        raise ValueError("mirror() needs a non-empty range (hi > lo)")
    period = 2 * n
    r = (i - lo) % period
    if r >= n:
        # Second half of the period walks back down from hi-1 to lo.
        r = period - 1 - r
    return lo + r


def convolve_bgra_block(src_bytes, w, h, kernel, ksize=3, border='clamp'):
    """Apply a ``ksize`` x ``ksize`` kernel to the B, G and R channels of a pixel buffer.

    The buffer is read as ``w * h`` pixels of 4 bytes each, row-major, in
    B, G, R, A byte order. Each colour channel of every output pixel is the
    weighted sum of the neighbourhood around it, rounded and clamped to
    0..255. Weight ``kernel[ky * ksize + kx]`` applies to the sample at offset
    ``(kx - ksize // 2, ky - ksize // 2)``, so the kernel is not flipped
    (a cross-correlation). Alpha is copied unchanged.

    Args:
        src_bytes: bytes-like buffer holding at least ``w * h * 4`` bytes.
        w: buffer width in pixels.
        h: buffer height in pixels.
        kernel: flat, row-major sequence of at least ``ksize * ksize`` weights.
        ksize: kernel side length.
        border: ``'clamp'`` repeats the nearest edge pixel. Any other value
            (the dialog offers ``'mirror'``) reflects coordinates back across
            the edge.

    Returns:
        A new bytearray with the same length as ``src_bytes``.

    Raises:
        ValueError: if ``kernel`` holds fewer than ``ksize * ksize`` weights.
    """
    if len(kernel) < ksize * ksize:
        raise ValueError(
            f"kernel has {len(kernel)} weights, expected {ksize * ksize}")

    half = ksize // 2
    if border == 'clamp':
        def resolve(i, n):
            return clamp(i, 0, n - 1)
    else:
        def resolve(i, n):
            return mirror(i, 0, n)

    # Border handling depends only on (coordinate, kernel offset), so map the
    # sample positions once per axis instead of once per pixel per tap.
    col_taps = [[resolve(x + kx - half, w) for kx in range(ksize)]
                for x in range(w)]
    # Row entries are pre-multiplied by the row stride (in pixels).
    row_taps = [[resolve(y + ky - half, h) * w for ky in range(ksize)]
                for y in range(h)]
    weight_rows = [kernel[ky * ksize:(ky + 1) * ksize] for ky in range(ksize)]

    out = bytearray(len(src_bytes))
    for y in range(h):
        rows = row_taps[y]
        for x in range(w):
            cols = col_taps[x]
            b = g = r = 0
            # Accumulate in the same (ky, kx) order as a direct double loop so
            # floating-point results are reproducible.
            for ky in range(ksize):
                row_base = rows[ky]
                weights = weight_rows[ky]
                for kx in range(ksize):
                    idx = (row_base + cols[kx]) * 4
                    k = weights[kx]
                    b += src_bytes[idx] * k
                    g += src_bytes[idx + 1] * k
                    r += src_bytes[idx + 2] * k
            dst = (y * w + x) * 4
            out[dst] = int(clamp(round(b), 0, 255))
            out[dst + 1] = int(clamp(round(g), 0, 255))
            out[dst + 2] = int(clamp(round(r), 0, 255))
            # Alpha is passed through unfiltered.
            out[dst + 3] = src_bytes[dst + 3]
    return out


class ConvolutionDialog(QDialog):
    """Dialog for editing a small kernel, previewing it and applying it.

    The kernel is applied to the active node of the active Krita document.
    The work is limited to the selection's bounding box when a selection
    exists. Only 8-bit RGBA layers are processed, because the pixel loop
    reads the raw buffer as 4 bytes per pixel in B, G, R, A order.
    """

    # Side length in pixels of the preview area and of the downsampled
    # preview buffer.
    PREVIEW_SIZE = 256

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle('ConvolutionLab')
        self.kSize = 3
        self.kernel_table = QTableWidget(self.kSize, self.kSize)
        for y in range(self.kSize):
            for x in range(self.kSize):
                self.kernel_table.setItem(y, x, QTableWidgetItem('0'))
        # Default: identity kernel (single 1 in the centre cell).
        centre = self.kSize // 2
        self.kernel_table.item(centre, centre).setText('1')

        self.normalize_cb = QCheckBox('Normalize sum')
        # NOTE(review): this option is not read anywhere in this dialog, so it
        # currently has no effect on preview or apply.
        self.preserve_cb = QCheckBox('Preserve brightness')
        self.border_combo = QComboBox()
        self.border_combo.addItems(['clamp', 'mirror'])

        self.preview_label = QLabel('Preview')
        self.preview_label.setFixedSize(self.PREVIEW_SIZE, self.PREVIEW_SIZE)
        # The label doubles as a status area for short error messages.
        self.preview_label.setWordWrap(True)

        self.preview_btn = QPushButton('Preview')
        self.apply_btn = QPushButton('Apply')

        self.preview_btn.clicked.connect(self.on_preview)
        self.apply_btn.clicked.connect(self.on_apply)

        left = QVBoxLayout()
        left.addWidget(self.kernel_table)
        left.addWidget(self.normalize_cb)
        left.addWidget(self.preserve_cb)
        left.addWidget(QLabel('Border'))
        left.addWidget(self.border_combo)
        left.addStretch()

        right = QVBoxLayout()
        right.addWidget(self.preview_label)
        right.addWidget(self.preview_btn)
        right.addWidget(self.apply_btn)
        right.addStretch()

        top_layout = QHBoxLayout()
        top_layout.addLayout(left)
        top_layout.addLayout(right)
        self.setLayout(top_layout)

    def read_kernel(self):
        """Return the table contents as a flat, row-major list of kSize * kSize floats.

        Cells that are missing, not numeric, or not finite (nan/inf) count
        as 0.0.
        """
        import math

        values = []
        for y in range(self.kSize):
            for x in range(self.kSize):
                item = self.kernel_table.item(y, x)
                try:
                    v = float(item.text())
                except (AttributeError, ValueError):
                    # AttributeError: item() returned None for an unset cell.
                    v = 0.0
                if not math.isfinite(v):
                    # nan/inf would make round() fail inside the convolution.
                    v = 0.0
                values.append(v)
        return values

    def _effective_kernel(self):
        """Return the kernel, scaled to sum to 1 when 'Normalize sum' is checked.

        A kernel whose weights sum to zero (e.g. an edge detector) is
        returned unscaled.
        """
        k = self.read_kernel()
        if self.normalize_cb.isChecked():
            s = sum(k)
            if s != 0:
                k = [vi / s for vi in k]
        return k

    def _load_roi(self):
        """Fetch the raw pixels of the active node's region of interest.

        Returns:
            ``(doc, node, (x, y, w, h), src)`` with ``src`` a bytearray of the
            node's pixel bytes. Returns None when there is nothing to process,
            or when the layer format is unsupported; in that case a message
            is shown in the preview area.
        """
        doc, node, (x, y, w, h) = active_doc_node_roi()
        if doc is None or node is None or w <= 0 or h <= 0:
            return None
        # convolve_bgra_block assumes 4 bytes per pixel in B, G, R, A order;
        # other colour models or bit depths would be silently corrupted.
        if node.colorModel() != 'RGBA' or node.colorDepth() != 'U8':
            self.preview_label.setText('Only 8-bit RGBA layers are supported')
            return None
        src = bytearray(node.pixelData(x, y, w, h))
        if len(src) < w * h * 4:
            self.preview_label.setText('Could not read layer pixels')
            return None
        return doc, node, (x, y, w, h), src

    def on_preview(self):
        """Show the kernel applied to a downsampled copy of the region."""
        roi = self._load_roi()
        if roi is None:
            return
        _doc, _node, (_x, _y, w, h), src = roi

        max_side = self.PREVIEW_SIZE
        scale = 1
        if max(w, h) > max_side:
            scale = max_side / max(w, h)
        sw = max(1, int(w * scale))
        sh = max(1, int(h * scale))

        # Nearest-neighbour downsample. min() guards against float rounding
        # producing an index one past the last source column/row.
        small = bytearray(sw * sh * 4)
        for yy in range(sh):
            sy = min(h - 1, int(yy / scale))
            for xx in range(sw):
                sx = min(w - 1, int(xx / scale))
                sidx = (sy * w + sx) * 4
                didx = (yy * sw + xx) * 4
                small[didx:didx + 4] = src[sidx:sidx + 4]

        # NOTE: the kernel runs on the downsampled pixels. For regions larger
        # than the preview, each tap covers several source pixels, so the
        # preview spans more of the image than Apply will.
        out = convolve_bgra_block(small, sw, sh, self._effective_kernel(),
                                  ksize=self.kSize,
                                  border=self.border_combo.currentText())

        # The buffer is B, G, R, A; swap B and R to match Format_RGBA8888,
        # which is R, G, B, A byte order on every platform.
        rgba = bytearray(out)
        rgba[0::4] = out[2::4]
        rgba[2::4] = out[0::4]
        # QImage wraps the buffer without copying it; copy() detaches the
        # image so it stays valid once the local bytes object is released.
        data = bytes(rgba)
        img = QImage(data, sw, sh, sw * 4, QImage.Format_RGBA8888).copy()
        pix = QPixmap.fromImage(img).scaled(self.preview_label.size(), Qt.KeepAspectRatio)
        self.preview_label.setPixmap(pix)

    def on_apply(self):
        """Convolve the region at full resolution and write it back to the node."""
        roi = self._load_roi()
        if roi is None:
            return
        doc, node, (x, y, w, h), src = roi
        # NOTE: the whole bounding box is rewritten, including pixels that lie
        # inside the box but outside a non-rectangular selection.
        out = convolve_bgra_block(src, w, h, self._effective_kernel(),
                                  ksize=self.kSize,
                                  border=self.border_combo.currentText())
        node.setPixelData(bytes(out), x, y, w, h)
        doc.refreshProjection()


class ConvolutionLab(Extension):
    """Krita extension that exposes the ConvolutionLab dialog as a menu action."""

    def __init__(self, parent=None):
        super().__init__(parent)

    def setup(self):
        # Nothing to initialise; all work happens when the action is triggered.
        pass

    def createActions(self, window):
        """Create the 'ConvolutionLab' action for *window* and hook it to run()."""
        action = window.createAction('convolution_lab_action', 'ConvolutionLab', MENU_PATH)
        action.triggered.connect(self.run)

    def run(self):
        """Show the dialog modally, parented to Krita's active main window."""
        window = Krita.instance().activeWindow()
        # A parent makes the dialog owned by, and kept in front of, Krita's
        # window rather than opening as an independent top-level window.
        parent = window.qwindow() if window is not None else None
        dlg = ConvolutionDialog(parent)
        dlg.exec_()
