import math
import time

from krita import Extension, Krita
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtWidgets import (QComboBox, QDialog, QDoubleSpinBox, QHBoxLayout,
                             QLabel, QMessageBox, QPushButton, QSpinBox,
                             QTabWidget, QTextEdit, QVBoxLayout, QWidget)

# Identifier string for this plugin (not referenced by the code shown in this file).
EXTENSION_ID = 'pykrita_fast_filters'


def active_doc_node_roi():
    """Return the active document, its active node and the region to filter.

    The region is the bounding box of the current selection when there is a
    non-empty one, otherwise the whole canvas.

    Returns:
        ``(doc, node, (x, y, w, h))``. When there is no active document, no
        active node, or the region would be empty, returns
        ``(None, None, (0, 0, 0, 0))`` so callers can bail out with a single
        ``doc is None`` check.
    """
    empty = (None, None, (0, 0, 0, 0))
    doc = Krita.instance().activeDocument()
    if doc is None:
        return empty
    node = doc.activeNode()
    if node is None:
        return empty
    sel = doc.selection()
    if sel is not None and sel.width() > 0 and sel.height() > 0:
        # Selection exposes its bounding box through x()/y()/width()/height().
        x, y, w, h = sel.x(), sel.y(), sel.width(), sel.height()
    else:
        # Take the canvas size from the Document: libkis Node offers bounds()
        # but no width()/height().
        x, y, w, h = 0, 0, doc.width(), doc.height()
    if w <= 0 or h <= 0:
        return empty
    return doc, node, (x, y, w, h)


def clamp(i, lo, hi):
    return lo if i < lo else (hi if i > hi else i)


def mirror(i, lo, hi):
    """Reflect index *i* into the half-open range ``[lo, hi)``.

    Uses symmetric (edge-repeating) reflection: for ``lo=0, hi=4`` the indices
    ``-2, -1, 0, 1, 2, 3, 4, 5`` map to ``1, 0, 0, 1, 2, 3, 3, 2``. The pattern
    repeats with period ``2 * (hi - lo)``, so offsets larger than the range
    (e.g. a blur radius wider than a small preview image) still land inside
    it instead of producing out-of-range or negative indices.

    Raises:
        ValueError: if the range is empty (``hi <= lo``).
    """
    n = hi - lo
    if n <= 0:
        raise ValueError(f"empty range [{lo}, {hi})")
    period = 2 * n
    # Position within one forward+backward sweep of the range.
    k = (i - lo) % period
    if k >= n:
        # Second half of the sweep runs backwards, repeating the edge pixel.
        k = period - 1 - k
    return lo + k


def gaussian_kernel_1d(sigma, radius=None):
    """Return a normalized 1-D Gaussian kernel and its radius.

    Args:
        sigma: Standard deviation in pixels. Must be > 0.
        radius: Half-width in taps. ``None`` uses ``ceil(3 * sigma)``
            (three standard deviations on each side).

    Returns:
        ``(kernel, radius)`` where ``kernel`` is a list of ``2 * radius + 1``
        weights summing to 1, centred on index ``radius``.

    Raises:
        ValueError: if ``sigma`` is not positive (including NaN) or
            ``radius`` is negative. Without this check ``sigma == 0`` divides
            by zero and a negative sigma yields an empty kernel, which blurs
            every pixel to 0.
    """
    # `not sigma > 0` also rejects NaN.
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")
    if radius is None:
        radius = int(math.ceil(3 * sigma))
    if radius < 0:
        raise ValueError(f"radius must be >= 0, got {radius!r}")
    two_sigma_sq = 2 * sigma * sigma
    weights = [math.exp(-(x * x) / two_sigma_sq) for x in range(-radius, radius + 1)]
    total = sum(weights)
    return [v / total for v in weights], radius


def convolve_rows_bgr(src_bytes, w, h, kernel, radius, border='clamp'):
    """Convolve every row of a 4-byte-per-pixel image with a 1-D kernel.

    Args:
        src_bytes: ``w * h * 4`` bytes, row-major. Bytes 0-2 of each pixel are
            filtered; byte 3 (alpha) is copied through unchanged.
        w, h: Image size in pixels.
        kernel: Sequence of tap weights; tap ``i`` samples column
            ``x + i - radius``.
        radius: Index of the centre tap in ``kernel``.
        border: ``'clamp'`` repeats the edge column; any other value reflects
            with :func:`mirror`.

    Returns:
        A new ``bytearray`` the same length as ``src_bytes``. Filtered values
        are rounded with ``round()`` and clamped to 0-255.
    """
    out = bytearray(len(src_bytes))
    offsets = [i - radius for i in range(len(kernel))]
    # The border-resolved source column for each (output column, tap) pair
    # does not depend on the row, so build the table once instead of
    # re-resolving it for every pixel of every row.
    if border == 'clamp':
        taps_for_col = [[min(max(x + d, 0), w - 1) for d in offsets]
                        for x in range(w)]
    else:
        taps_for_col = [[mirror(x + d, 0, w) for d in offsets]
                        for x in range(w)]
    for y in range(h):
        row = y * w
        for x in range(w):
            a0 = a1 = a2 = 0.0
            # Same tap order as the kernel, so float sums are reproducible.
            for sx, kv in zip(taps_for_col[x], kernel):
                idx = (row + sx) * 4
                a0 += src_bytes[idx] * kv
                a1 += src_bytes[idx + 1] * kv
                a2 += src_bytes[idx + 2] * kv
            dst = (row + x) * 4
            out[dst] = min(255, max(0, round(a0)))
            out[dst + 1] = min(255, max(0, round(a1)))
            out[dst + 2] = min(255, max(0, round(a2)))
            out[dst + 3] = src_bytes[dst + 3]
    return out


def convolve_cols_bgr(src_bytes, w, h, kernel, radius, border='clamp'):
    """Convolve every column of a 4-byte-per-pixel image with a 1-D kernel.

    Args:
        src_bytes: ``w * h * 4`` bytes, row-major. Bytes 0-2 of each pixel are
            filtered; byte 3 (alpha) is copied through unchanged.
        w, h: Image size in pixels.
        kernel: Sequence of tap weights; tap ``i`` samples row
            ``y + i - radius``.
        radius: Index of the centre tap in ``kernel``.
        border: ``'clamp'`` repeats the edge row; any other value reflects
            with :func:`mirror`.

    Returns:
        A new ``bytearray`` the same length as ``src_bytes``. Filtered values
        are rounded with ``round()`` and clamped to 0-255.
    """
    out = bytearray(len(src_bytes))
    offsets = [i - radius for i in range(len(kernel))]
    # For each output row, the pixel offset of the start of every tap's source
    # row. This depends only on y, so it is resolved once per row.
    if border == 'clamp':
        bases_for_row = [[min(max(y + d, 0), h - 1) * w for d in offsets]
                         for y in range(h)]
    else:
        bases_for_row = [[mirror(y + d, 0, h) * w for d in offsets]
                         for y in range(h)]
    for y in range(h):
        bases = bases_for_row[y]
        row = y * w
        for x in range(w):
            a0 = a1 = a2 = 0.0
            # Same tap order as the kernel, so float sums are reproducible.
            for base, kv in zip(bases, kernel):
                idx = (base + x) * 4
                a0 += src_bytes[idx] * kv
                a1 += src_bytes[idx + 1] * kv
                a2 += src_bytes[idx + 2] * kv
            dst = (row + x) * 4
            out[dst] = min(255, max(0, round(a0)))
            out[dst + 1] = min(255, max(0, round(a1)))
            out[dst + 2] = min(255, max(0, round(a2)))
            out[dst + 3] = src_bytes[dst + 3]
    return out


def gaussian_blur_separable(src_bytes, w, h, sigma, radius=None, border='clamp'):
    """Gaussian-blur a w x h 4-byte-per-pixel buffer with two 1-D passes.

    A horizontal pass (convolve_rows_bgr) is followed by a vertical pass
    (convolve_cols_bgr) using the same kernel from gaussian_kernel_1d. This
    costs O(radius) per pixel instead of O(radius**2) for a full 2-D kernel.
    The intermediate result is rounded to bytes between passes.

    Returns:
        A new bytearray; ``src_bytes`` itself is not modified.
    """
    weights, half_width = gaussian_kernel_1d(sigma, radius)
    horizontal = convolve_rows_bgr(src_bytes, w, h, weights, half_width, border)
    return convolve_cols_bgr(horizontal, w, h, weights, half_width, border)


# 3x3 Sobel kernels flattened row-major: element ky * 3 + kx weights the
# neighbour at offset (kx - 1, ky - 1).
SOBEL_X = [
    -1, 0, 1,
    -2, 0, 2,
    -1, 0, 1,
]
SOBEL_Y = [
    -1, -2, -1,
    0, 0, 0,
    1, 2, 1,
]


def rgb_to_luma(r, g, b):
    """Return the Rec. 601 luma (weights 0.299, 0.587, 0.114) of r, g, b.

    The weights sum to 1, so the result is on the same scale as the inputs
    (0-255 for 8-bit channels).
    """
    wr, wg, wb = 0.299, 0.587, 0.114
    return wr * r + wg * g + wb * b


def sobel_luma_bgra(src_bytes, w, h, border='clamp', want='magnitude'):
    """Apply a 3x3 Sobel operator to the luma of a BGRA8 buffer.

    Args:
        src_bytes: ``w * h * 4`` bytes in B, G, R, A order, row-major.
        w, h: Image size in pixels.
        border: ``'clamp'`` repeats edge pixels; any other value reflects with
            :func:`mirror`.
        want: ``'gx'`` or ``'gy'`` for a single directional response; any
            other value gives the gradient magnitude ``hypot(gx, gy)``.

    Returns:
        A new ``bytearray`` of grey pixels (B = G = R = response, rounded and
        clamped to 0-255) with alpha copied from ``src_bytes``. Negative
        ``gx``/``gy`` responses therefore clamp to 0.
    """
    out = bytearray(len(src_bytes))
    # Luma of every pixel, computed once rather than once per neighbouring tap.
    luma = [rgb_to_luma(src_bytes[p + 2], src_bytes[p + 1], src_bytes[p])
            for p in range(0, w * h * 4, 4)]
    steps = (-1, 0, 1)
    # Border-resolved neighbour columns per x, and neighbour row starts per y.
    if border == 'clamp':
        cols = [[min(max(x + d, 0), w - 1) for d in steps] for x in range(w)]
        rows = [[min(max(y + d, 0), h - 1) * w for d in steps] for y in range(h)]
    else:
        cols = [[mirror(x + d, 0, w) for d in steps] for x in range(w)]
        rows = [[mirror(y + d, 0, h) * w for d in steps] for y in range(h)]
    for y in range(h):
        row_bases = rows[y]
        for x in range(w):
            col_idx = cols[x]
            gx = 0.0
            gy = 0.0
            for ky in range(3):
                base = row_bases[ky]
                for kx in range(3):
                    lum = luma[base + col_idx[kx]]
                    k = ky * 3 + kx
                    gx += lum * SOBEL_X[k]
                    gy += lum * SOBEL_Y[k]
            if want == 'gx':
                val = gx
            elif want == 'gy':
                val = gy
            else:
                val = math.hypot(gx, gy)
            v = min(255, max(0, round(val)))
            dst = (y * w + x) * 4
            out[dst] = out[dst + 1] = out[dst + 2] = v
            out[dst + 3] = src_bytes[dst + 3]
    return out


class FastFiltersDialog(QDialog):
    """Tabbed dialog that previews, applies and benchmarks the module's filters.

    Tabs: 'Gaussian' (separable blur), 'Sobel' (luma edge detection) and
    'Benchmark' (timings on resampled copies of the region). The filters index
    pixels as 4 bytes in B, G, R, A order, so Preview/Apply refuse layers that
    are not 8-bit RGBA.
    """

    # Longest side, in pixels, of the downsampled copy that previews run on;
    # also the fixed size of the preview labels.
    PREVIEW_MAX_SIDE = 256

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle('FastFilters')
        self.tabs = QTabWidget()

        self.gauss_tab = QWidget()
        self.sobel_tab = QWidget()
        self.bench_tab = QWidget()

        self.setup_gauss()
        self.setup_sobel()
        self.setup_bench()

        self.tabs.addTab(self.gauss_tab, 'Gaussian')
        self.tabs.addTab(self.sobel_tab, 'Sobel')
        self.tabs.addTab(self.bench_tab, 'Benchmark')

        layout = QVBoxLayout()
        layout.addWidget(self.tabs)
        self.setLayout(layout)

    # ----- UI construction -------------------------------------------------

    def setup_gauss(self):
        """Build the Gaussian tab: sigma/radius/border controls, buttons, preview."""
        outer = QVBoxLayout()
        params = QHBoxLayout()
        params.addWidget(QLabel('Sigma'))
        self.sigma_spin = QDoubleSpinBox()
        self.sigma_spin.setRange(0.1, 100.0)
        self.sigma_spin.setValue(1.6)
        params.addWidget(self.sigma_spin)
        params.addWidget(QLabel('Radius (0=auto)'))
        self.radius_spin = QSpinBox()
        self.radius_spin.setRange(0, 64)
        self.radius_spin.setValue(0)
        params.addWidget(self.radius_spin)
        self.border_combo = QComboBox()
        self.border_combo.addItems(['clamp', 'mirror'])
        params.addWidget(QLabel('Border'))
        params.addWidget(self.border_combo)
        outer.addLayout(params)
        buttons = QHBoxLayout()
        self.gauss_preview_btn = QPushButton('Preview')
        self.gauss_apply_btn = QPushButton('Apply')
        self.gauss_preview_btn.clicked.connect(self.on_gauss_preview)
        self.gauss_apply_btn.clicked.connect(self.on_gauss_apply)
        buttons.addWidget(self.gauss_preview_btn)
        buttons.addWidget(self.gauss_apply_btn)
        outer.addLayout(buttons)
        self.gauss_preview_label = QLabel()
        self.gauss_preview_label.setFixedSize(self.PREVIEW_MAX_SIDE, self.PREVIEW_MAX_SIDE)
        outer.addWidget(self.gauss_preview_label)
        self.gauss_tab.setLayout(outer)

    def setup_sobel(self):
        """Build the Sobel tab: output selector, buttons, preview."""
        outer = QVBoxLayout()
        params = QHBoxLayout()
        params.addWidget(QLabel('Output'))
        self.sobel_combo = QComboBox()
        self.sobel_combo.addItems(['magnitude', 'gx', 'gy'])
        params.addWidget(self.sobel_combo)
        outer.addLayout(params)
        buttons = QHBoxLayout()
        self.sobel_preview_btn = QPushButton('Preview')
        self.sobel_apply_btn = QPushButton('Apply')
        self.sobel_preview_btn.clicked.connect(self.on_sobel_preview)
        self.sobel_apply_btn.clicked.connect(self.on_sobel_apply)
        buttons.addWidget(self.sobel_preview_btn)
        buttons.addWidget(self.sobel_apply_btn)
        outer.addLayout(buttons)
        self.sobel_preview_label = QLabel()
        self.sobel_preview_label.setFixedSize(self.PREVIEW_MAX_SIDE, self.PREVIEW_MAX_SIDE)
        outer.addWidget(self.sobel_preview_label)
        self.sobel_tab.setLayout(outer)

    def setup_bench(self):
        """Build the Benchmark tab: run button and read-only CSV output."""
        outer = QVBoxLayout()
        self.bench_text = QTextEdit()
        self.bench_text.setReadOnly(True)
        self.bench_run_btn = QPushButton('Run benchmark')
        self.bench_run_btn.clicked.connect(self.on_bench_run)
        outer.addWidget(self.bench_run_btn)
        outer.addWidget(self.bench_text)
        self.bench_tab.setLayout(outer)

    # ----- helpers -----------------------------------------------------------

    def _target(self):
        """Return ``(doc, node, (x, y, w, h))`` if the active layer can be filtered.

        Returns None silently when there is no document, no active node or an
        empty region. Shows a warning and returns None when the layer is not
        8-bit RGBA, because the filters read and write 4 bytes per pixel and
        would otherwise corrupt the layer.
        """
        doc, node, roi = active_doc_node_roi()
        if doc is None or node is None:
            return None
        _, _, w, h = roi
        if w <= 0 or h <= 0:
            return None
        if node.colorModel() != 'RGBA' or node.colorDepth() != 'U8':
            QMessageBox.warning(
                self, 'FastFilters',
                'FastFilters only supports 8-bit RGBA layers '
                f'(active layer is {node.colorModel()} {node.colorDepth()}).')
            return None
        return doc, node, roi

    def _gauss_params(self):
        """Return ``(sigma, radius)`` from the UI; radius is None when set to 0 (auto)."""
        sigma = float(self.sigma_spin.value())
        radius = int(self.radius_spin.value()) or None
        return sigma, radius

    @staticmethod
    def _downsample(src, w, h, max_side):
        """Nearest-neighbour shrink of a w x h 4-byte-per-pixel buffer.

        Returns ``(small, sw, sh, scale)`` where ``scale`` (<= 1) is the factor
        applied so that the longer side fits in ``max_side``; images already
        within the limit are copied at scale 1.
        """
        scale = 1
        if max(w, h) > max_side:
            scale = max_side / max(w, h)
        sw = max(1, int(w * scale))
        sh = max(1, int(h * scale))
        small = bytearray(sw * sh * 4)
        for yy in range(sh):
            # min() guards against float rounding stepping past the last row/col.
            row = min(h - 1, int(yy / scale)) * w
            for xx in range(sw):
                sidx = (row + min(w - 1, int(xx / scale))) * 4
                didx = (yy * sw + xx) * 4
                small[didx:didx + 4] = src[sidx:sidx + 4]
        return small, sw, sh, scale

    @staticmethod
    def _show_preview(label, data, w, h):
        """Display a w x h BGRA8 buffer in *label*, scaled to fit."""
        # Keep a named reference: QImage wraps this buffer without copying it,
        # so it must stay alive until fromImage() has converted the pixels.
        raw = bytes(data)
        # The pixels are in B, G, R, A byte order (as sobel_luma_bgra reads
        # them). Format_ARGB32 has that memory layout on little-endian hosts;
        # Format_RGBA8888 would swap red and blue.
        img = QImage(raw, w, h, w * 4, QImage.Format_ARGB32)
        pix = QPixmap.fromImage(img).scaled(label.size(), Qt.KeepAspectRatio)
        label.setPixmap(pix)

    # ----- slots -------------------------------------------------------------

    def on_gauss_preview(self):
        """Blur a downsampled copy of the region and show it in the preview."""
        target = self._target()
        if target is None:
            return
        _, node, (x, y, w, h) = target
        src = bytearray(node.pixelData(x, y, w, h))
        small, sw, sh, scale = self._downsample(src, w, h, self.PREVIEW_MAX_SIDE)
        sigma, radius = self._gauss_params()
        # The preview is rendered at `scale` times full resolution, so shrink
        # the blur footprint by the same factor; otherwise the preview looks
        # much blurrier than what Apply produces.
        sigma *= scale
        if radius is not None:
            radius = max(1, round(radius * scale))
        out = gaussian_blur_separable(small, sw, sh, sigma, radius=radius,
                                      border=self.border_combo.currentText())
        self._show_preview(self.gauss_preview_label, out, sw, sh)

    def on_gauss_apply(self):
        """Blur the region at full resolution and write it back to the layer.

        NOTE(review): confirm whether Node.setPixelData is undoable in the
        targeted Krita version.
        """
        target = self._target()
        if target is None:
            return
        doc, node, (x, y, w, h) = target
        src = bytearray(node.pixelData(x, y, w, h))
        sigma, radius = self._gauss_params()
        out = gaussian_blur_separable(src, w, h, sigma, radius=radius,
                                      border=self.border_combo.currentText())
        node.setPixelData(bytes(out), x, y, w, h)
        doc.refreshProjection()

    def on_sobel_preview(self):
        """Run Sobel on a downsampled copy of the region and show it."""
        target = self._target()
        if target is None:
            return
        _, node, (x, y, w, h) = target
        src = bytearray(node.pixelData(x, y, w, h))
        small, sw, sh, _ = self._downsample(src, w, h, self.PREVIEW_MAX_SIDE)
        out = sobel_luma_bgra(small, sw, sh, border='clamp',
                              want=self.sobel_combo.currentText())
        self._show_preview(self.sobel_preview_label, out, sw, sh)

    def on_sobel_apply(self):
        """Run Sobel on the region at full resolution and write it back."""
        target = self._target()
        if target is None:
            return
        doc, node, (x, y, w, h) = target
        src = bytearray(node.pixelData(x, y, w, h))
        out = sobel_luma_bgra(src, w, h, border='clamp',
                              want=self.sobel_combo.currentText())
        node.setPixelData(bytes(out), x, y, w, h)
        doc.refreshProjection()

    def on_bench_run(self):
        """Time both filters on resampled copies of the region; append CSV rows."""
        doc, node, (x, y, w, h) = active_doc_node_roi()
        if doc is None or node is None:
            self.bench_text.append('No document')
            return
        if w <= 0 or h <= 0:
            self.bench_text.append('Empty region')
            return
        src = bytearray(node.pixelData(x, y, w, h))
        results = []
        for bw, bh in ((640, 360), (1280, 720), (1920, 1080)):
            # Nearest-neighbour resample of the region to bw x bh.
            small = bytearray(bw * bh * 4)
            for yy in range(bh):
                row = int(yy * h / bh) * w
                for xx in range(bw):
                    sidx = (row + int(xx * w / bw)) * 4
                    didx = (yy * bw + xx) * 4
                    small[didx:didx + 4] = src[sidx:sidx + 4]
            # perf_counter is monotonic and high-resolution; time() is neither.
            t0 = time.perf_counter()
            gaussian_blur_separable(small, bw, bh, 1.6, radius=None, border='clamp')
            t1 = time.perf_counter()
            sobel_luma_bgra(small, bw, bh, border='clamp', want='magnitude')
            t2 = time.perf_counter()
            results.append((bw, bh, t1 - t0, t2 - t1))
        self.bench_text.append('w,h,gauss_s,sobel_s')
        for bw, bh, gauss_s, sobel_s in results:
            self.bench_text.append(f"{bw},{bh},{gauss_s:.4f},{sobel_s:.4f}")


class FastFilters(Extension):
    """Krita extension that adds a 'FastFilters' action under tools/scripts."""

    def __init__(self, parent=None):
        super().__init__(parent)

    def setup(self):
        """No start-up work is needed; everything happens when the action fires."""

    def createActions(self, window):
        """Register the menu action for *window* and connect it to :meth:`run`."""
        action = window.createAction('fast_filters_action', 'FastFilters', 'tools/scripts')
        action.triggered.connect(self.run)

    def run(self):
        """Open the FastFilters dialog modally, owned by the active main window."""
        # Parenting the dialog keeps it centred over and stacked above Krita's
        # window; without a parent it is a free-floating top-level window.
        window = Krita.instance().activeWindow()
        parent = window.qwindow() if window is not None else None
        dlg = FastFiltersDialog(parent)
        dlg.exec_()
