from krita import Extension, Krita
from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
                             QComboBox, QSpinBox, QCheckBox)
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import Qt
import math

# Safe numpy import — Krita may or may not have NumPy installed
try:
    import numpy as np
except Exception:
    np = None

# Identifier string for this extension.
# NOTE(review): not referenced by any code visible in this file — confirm
# whether something else (e.g. a .desktop/registration file) relies on it.
EXTENSION_ID = "pykrita_freq_lab"


def active_doc_node_roi():
    """Return the active document, its active node and the region to process.

    Returns:
        ``(doc, node, (x, y, w, h))``. The region is the bounding box of the
        current selection when the selection has a non-zero size, otherwise
        the whole document canvas ``(0, 0, doc.width(), doc.height())``.
        When there is no active document or no active node, returns
        ``(None, None, (0, 0, 0, 0))`` so callers can bail out on ``doc is None``.
    """
    app = Krita.instance()
    doc = app.activeDocument()
    if doc is None:
        return None, None, (0, 0, 0, 0)
    node = doc.activeNode()
    if node is None:
        # Nothing to read pixels from; report "no document" so callers stop early.
        return None, None, (0, 0, 0, 0)
    sel = doc.selection()
    # Read the selection's bounding box through the libkis Selection accessors
    # (a QRect-style value cannot be tuple-unpacked into x, y, w, h).
    if sel is not None and sel.width() > 0 and sel.height() > 0:
        return doc, node, (sel.x(), sel.y(), sel.width(), sel.height())
    # No usable selection: process the full canvas of the document.
    return doc, node, (0, 0, doc.width(), doc.height())


def bgra_to_luma_array(bgra, w, h):
    """Convert packed 8-bit BGRA pixels to Rec. 601 luma.

    Args:
        bgra: bytes-like buffer of ``w * h * 4`` bytes, 4 bytes per pixel in
            B, G, R, A order (alpha is ignored).
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        A ``(h, w)`` ``numpy.uint8`` array when NumPy is available, otherwise
        a list of ``h`` rows, each a list of ``w`` ints. Fractional luma
        values are truncated toward zero in both cases.
    """
    if np is None:
        def luma_at(offset):
            blue, green, red = bgra[offset], bgra[offset + 1], bgra[offset + 2]
            return int(0.299 * red + 0.587 * green + 0.114 * blue)

        return [[luma_at((row * w + col) * 4) for col in range(w)] for row in range(h)]

    pixels = np.frombuffer(bgra, dtype=np.uint8).reshape((h, w, 4))
    blue = pixels[:, :, 0]
    green = pixels[:, :, 1]
    red = pixels[:, :, 2]
    # Rec. 601 weights, same summation order as the pure-Python path.
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return luma.astype(np.uint8)


def try_import_numpy():
    """Report the outcome of the module-level NumPy import.

    Returns:
        The ``numpy`` module if it imported successfully when this file was
        loaded, otherwise ``None``.
    """
    numpy_module = np
    return numpy_module


def fft_preview_numpy(gray_u8, window='none'):
    # gray_u8: 2D numpy uint8 array
    f = np.fft.fft2(gray_u8.astype(np.float32))
    f = np.fft.fftshift(f)
    mag = np.log1p(np.abs(f))
    # normalize to 0..255
    mag = mag - mag.min()
    mag = mag / (mag.max() + 1e-9) * 255.0
    return mag.astype(np.uint8)


def make_circular_mask(h, w, mode, r_in_px=0, r_out_px=10):
    """Build a binary radial mask centred on the fftshift'ed zero frequency.

    Distances are measured from pixel ``(h // 2, w // 2)``.

    Args:
        h: mask height.
        w: mask width.
        mode: ``'low'`` keeps distances ``<= r_out_px``; ``'high'`` keeps
            distances ``>= r_out_px``; anything else is a band keeping
            ``r_in_px <= distance <= r_out_px``.
        r_in_px: inner radius (band mode only).
        r_out_px: cut-off radius for low/high, outer radius for band.

    Returns:
        ``(h, w)`` float32 array of 0.0/1.0, or ``None`` when NumPy is missing.
    """
    if np is None:
        return None
    rows, cols = np.ogrid[:h, :w]
    dx = cols - w // 2
    dy = rows - h // 2
    radius = np.sqrt(dx ** 2 + dy ** 2)
    selectors = {
        'low': lambda: radius <= r_out_px,
        'high': lambda: radius >= r_out_px,
    }
    band = lambda: (radius >= r_in_px) & (radius <= r_out_px)  # noqa: E731
    keep = selectors.get(mode, band)()
    return keep.astype(np.float32)


def apply_fft_filter_numpy(gray_u8, mode, r1, r2, window='none'):
    """Filter a grayscale image with an ideal circular mask in the frequency domain.

    Args:
        gray_u8: 2D NumPy array of 8-bit luma values.
        mode: ``'low'`` keeps frequencies within ``r1`` px of the centred zero
            frequency; ``'high'`` keeps those at ``r1`` px or more; any other
            value is a band-pass keeping radii between ``r1`` and ``r2``
            inclusive (the radii may be given in either order).
        r1: cut-off radius (low/high) or one band edge, in frequency-plane px.
        r2: the other band edge; unused for low/high.
        window: accepted for signature parity with the preview function but
            deliberately not applied. Windowing the image before the inverse
            transform would attenuate the borders of the filtered result.

    Returns:
        ``uint8`` array of the same shape. The filtered result is min-max
        scaled to 0..255. If the result is numerically flat (e.g. low-pass
        of a uniform image), its values are clipped to 0..255 instead, so a
        flat gray input does not turn black.
    """
    # float64 keeps inverse-transform round-off far below one gray level.
    f = np.fft.fft2(gray_u8.astype(np.float64))
    fshift = np.fft.fftshift(f)
    h, w = gray_u8.shape
    if mode == 'low':
        mask = make_circular_mask(h, w, 'low', 0, r1)
    elif mode == 'high':
        mask = make_circular_mask(h, w, 'high', 0, r1)
    else:
        # A reversed band (inner > outer) would otherwise select nothing.
        r_in, r_out = sorted((r1, r2))
        mask = make_circular_mask(h, w, 'band', r_in, r_out)
    img_back = np.real(np.fft.ifft2(np.fft.ifftshift(fshift * mask)))
    lo = img_back.min()
    hi = img_back.max()
    if hi - lo < 1e-6:
        # Stretching a flat image maps every pixel to 0; keep its real level.
        return np.clip(np.rint(img_back), 0, 255).astype(np.uint8)
    # Normalize to 0..255 (high/band results are signed, so a stretch is needed).
    img_back = (img_back - lo) / (hi - lo + 1e-9) * 255.0
    return img_back.astype(np.uint8)


def tiny_dft_2d(gray_u8, n):
    """Compute a small log-magnitude DFT preview without NumPy.

    Only the top-left ``n x n`` pixels of ``gray_u8`` are transformed. The 2D
    DFT is evaluated separably (rows, then columns) with a precomputed
    twiddle table. This is O(n**3) instead of the O(n**4) direct double sum.

    Args:
        gray_u8: 2D grid indexable as ``gray_u8[y][x]`` (a list of rows or a
            2D NumPy array) with at least ``n`` rows and ``n`` columns.
        n: side length of the transformed square.

    Returns:
        ``n x n`` list of lists of ints in 0..255. The values are
        ``log(1 + |F|)`` min-max scaled, with the zero frequency moved to
        index ``n // 2`` on both axes (the layout ``np.fft.fftshift``
        produces). Returns an empty list when ``n <= 0``.
    """
    if n <= 0:
        return []
    # twiddle[k] = exp(-2*pi*i*k/n); exponents are reduced mod n below.
    twiddle = [complex(math.cos(2 * math.pi * k / n), -math.sin(2 * math.pi * k / n))
               for k in range(n)]
    # Index as [y][x] so both nested lists and NumPy arrays are accepted.
    rows = [[float(gray_u8[y][x]) for x in range(n)] for y in range(n)]
    # 1D DFT of every row: R[y][u] = sum_x f[y][x] * w^(u*x).
    row_dft = [[sum(row[x] * twiddle[(u * x) % n] for x in range(n)) for u in range(n)]
               for row in rows]
    half = n // 2
    mags = [[0.0] * n for _ in range(n)]
    for v in range(n):
        for u in range(n):
            # 1D DFT down each column of R: F[v][u] = sum_y R[y][u] * w^(v*y).
            coeff = sum(row_dft[y][u] * twiddle[(v * y) % n] for y in range(n))
            # Keep the float log magnitude (no early int truncation) and
            # store it with the zero frequency shifted to the centre.
            mags[(v + half) % n][(u + half) % n] = math.log1p(abs(coeff))
    lo = min(min(r) for r in mags)
    hi = max(max(r) for r in mags)
    span = hi - lo if hi != lo else 1
    return [[int((m - lo) / span * 255) for m in r] for r in mags]


class FreqDialog(QDialog):
    """Dialog for previewing a layer's frequency spectrum and applying an FFT filter.

    "Preview" shows the log-magnitude spectrum of the region reported by
    ``active_doc_node_roi()``, scaled into a 256x256 label. "Apply" (NumPy
    only) replaces that region with the filtered luma written back as gray
    RGB, keeping the original alpha. Pixel data is handled as 4 bytes per
    pixel in B, G, R, A order, so only 8-bit RGBA layers are accepted.
    """

    # Combo-box labels -> mode names understood by apply_fft_filter_numpy.
    _FILTER_MODES = {'low-pass': 'low', 'high-pass': 'high', 'band-pass': 'band'}

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle('FreqLab')
        self.window_combo = QComboBox()
        self.window_combo.addItems(['none', 'hann'])
        self.filter_combo = QComboBox()
        self.filter_combo.addItems(['low-pass', 'high-pass', 'band-pass'])
        self.radius_spin = QSpinBox()
        self.radius_spin.setRange(1, 4096)
        self.radius_spin.setValue(32)
        self.radius_spin2 = QSpinBox()
        self.radius_spin2.setRange(1, 4096)
        self.radius_spin2.setValue(64)
        self.preview_label = QLabel()
        self.preview_label.setFixedSize(256, 256)
        # The label also displays status messages, so centre and wrap text.
        self.preview_label.setAlignment(Qt.AlignCenter)
        self.preview_label.setWordWrap(True)
        self.preview_btn = QPushButton('Preview')
        self.apply_btn = QPushButton('Apply')
        self.preview_btn.clicked.connect(self.on_preview)
        self.apply_btn.clicked.connect(self.on_apply)

        lay = QVBoxLayout()
        hl = QHBoxLayout()
        hl.addWidget(QLabel('Window'))
        hl.addWidget(self.window_combo)
        hl.addWidget(QLabel('Filter'))
        hl.addWidget(self.filter_combo)
        hl.addWidget(QLabel('R1'))
        hl.addWidget(self.radius_spin)
        hl.addWidget(QLabel('R2'))
        hl.addWidget(self.radius_spin2)
        lay.addLayout(hl)
        lay.addWidget(self.preview_label)
        btn_h = QHBoxLayout()
        btn_h.addWidget(self.preview_btn)
        btn_h.addWidget(self.apply_btn)
        lay.addLayout(btn_h)
        self.setLayout(lay)

    def _read_roi(self):
        """Fetch the active region's BGRA bytes, or return None if unusable.

        Returns:
            ``(doc, node, (x, y, w, h), src)`` where ``src`` is a bytearray of
            exactly ``w * h * 4`` bytes. Returns None if there is no document
            or node, the region is empty, or the layer is not 8-bit RGBA. A
            status message is shown in the last two cases.
        """
        doc, node, (x, y, w, h) = active_doc_node_roi()
        if doc is None or node is None or w <= 0 or h <= 0:
            return None
        # The byte arithmetic in this dialog assumes 4 x 8-bit channels per pixel.
        if node.colorModel() != 'RGBA' or node.colorDepth() != 'U8':
            self.preview_label.setText('FreqLab only supports 8-bit RGBA layers.')
            return None
        src = bytearray(node.pixelData(x, y, w, h))
        if len(src) != w * h * 4:
            self.preview_label.setText('Unexpected pixel data size.')
            return None
        return doc, node, (x, y, w, h), src

    @staticmethod
    def _downsample_nearest(src, w, h, sw, sh, scale):
        """Nearest-neighbour downsample of a ``w x h`` BGRA buffer to ``sw x sh``."""
        small = bytearray(sw * sh * 4)
        for yy in range(sh):
            # Clamp in case float rounding lands exactly on the edge.
            sy = min(h - 1, int(yy / scale))
            for xx in range(sw):
                sx = min(w - 1, int(xx / scale))
                sidx = (sy * w + sx) * 4
                didx = (yy * sw + xx) * 4
                small[didx:didx + 4] = src[sidx:sidx + 4]
        return small

    def _show_gray(self, data, width, height):
        """Display ``width x height`` 8-bit gray bytes in the preview label."""
        # bytesPerLine is passed explicitly: without it QImage assumes 32-bit
        # aligned scanlines and skews rows whose width is not a multiple of 4.
        img = QImage(data, width, height, width, QImage.Format_Grayscale8)
        # fromImage copies the pixels, so ``data`` only needs to outlive this call.
        pix = QPixmap.fromImage(img).scaled(self.preview_label.size(), Qt.KeepAspectRatio)
        self.preview_label.setPixmap(pix)

    def on_preview(self):
        """Render the spectrum of the (downsampled) active region."""
        roi = self._read_roi()
        if roi is None:
            return
        _doc, _node, (_x, _y, w, h), src = roi
        # The pure-Python DFT is slow, so preview a much smaller image without NumPy.
        max_side = 256 if np is not None else 64
        scale = 1
        if max(w, h) > max_side:
            scale = max_side / max(w, h)
        sw = max(1, int(w * scale))
        sh = max(1, int(h * scale))
        small = self._downsample_nearest(src, w, h, sw, sh, scale)
        gray = bgra_to_luma_array(small, sw, sh)
        if np is not None:
            mag = fft_preview_numpy(gray, window=self.window_combo.currentText())
            self._show_gray(mag.tobytes(), sw, sh)
        else:
            n = min(sw, sh, 64)
            arr = tiny_dft_2d(gray, n)
            self._show_gray(bytes(v for row in arr for v in row), n, n)

    def on_apply(self):
        """Filter the active region's luma and write it back as gray RGB."""
        if np is None:
            self.preview_label.setText('Apply requires NumPy.')
            return
        roi = self._read_roi()
        if roi is None:
            return
        doc, node, (x, y, w, h), src = roi
        gray = bgra_to_luma_array(src, w, h)
        # Unknown labels fall back to band-pass, matching the original else-branch.
        mode = self._FILTER_MODES.get(self.filter_combo.currentText(), 'band')
        r1 = int(self.radius_spin.value())
        r2 = int(self.radius_spin2.value())
        if mode == 'band' and r1 > r2:
            # Let the user enter the band edges in either order.
            r1, r2 = r2, r1
        filtered = apply_fft_filter_numpy(gray, mode, r1, r2,
                                          window=self.window_combo.currentText())
        # Merge luma back into BGRA as grayscale RGB and keep the source alpha
        # (vectorized; a per-pixel Python loop is very slow on full images).
        src_arr = np.frombuffer(src, dtype=np.uint8).reshape((h, w, 4))
        out = np.empty_like(src_arr)
        out[:, :, :3] = filtered[:, :, np.newaxis]
        out[:, :, 3] = src_arr[:, :, 3]
        node.setPixelData(out.tobytes(), x, y, w, h)
        doc.refreshProjection()


class FreqLab(Extension):
    """Krita extension that registers the 'FreqLab' action and opens FreqDialog."""

    def __init__(self, parent=None):
        super().__init__(parent)

    def setup(self):
        # No global initialisation needed; the action is created per window.
        pass

    def createActions(self, window):
        """Register the 'FreqLab' action in the 'tools/scripts' menu location."""
        action = window.createAction('freq_lab_action', 'FreqLab', 'tools/scripts')
        action.triggered.connect(self.run)

    def run(self):
        """Show the FreqLab dialog modally."""
        # Parent the modal dialog to Krita's active main window (when there is
        # one) so it stays on top of it instead of being an orphan top-level window.
        active_window = Krita.instance().activeWindow()
        parent = active_window.qwindow() if active_window is not None else None
        dlg = FreqDialog(parent)
        dlg.exec_()
