from krita import Extension, Krita
from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
                             QComboBox, QLineEdit, QSpinBox, QFileDialog, QCheckBox,
                             QTextEdit, QProgressDialog, QTabWidget, QWidget, QFormLayout, QApplication)
from PyQt5.QtCore import Qt, QSettings
from PyQt5.QtGui import QImage, QPixmap
import os, math, time

# Plugin identifier string.
# NOTE(review): not referenced elsewhere in this file; presumably kept in
# sync with the plugin's registration metadata — confirm.
EXTENSION_ID: str = 'pykrita_batch_lab'


def selected_layers(doc):
    """Return the layers (Krita ``Node`` objects) the user has selected.

    Prefers the active view's ``selectedNodes()`` (multi-selection) when the
    running Krita exposes it, and otherwise falls back to ``doc.activeNode()``,
    the single current layer.

    Args:
        doc: the Krita ``Document`` whose layers are being exported.

    Returns:
        A list of nodes; empty when no layer can be determined.
    """
    window = Krita.instance().activeWindow()
    view = window.activeView() if window is not None else None
    # NOTE(review): assumes the active view displays *doc*; callers pass
    # Krita.instance().activeDocument(), which is normally that view's document.
    if view is not None and hasattr(view, 'selectedNodes'):
        nodes = list(view.selectedNodes())
        if nodes:
            return nodes
    node = doc.activeNode()
    return [node] if node is not None else []


def top_level_layers(doc):
    """Return a fresh list holding the direct children of *doc*'s root node."""
    root = doc.rootNode()
    return list(root.childNodes())


def ensure_dir(path):
    """Create directory *path* plus any missing parents; no-op if it exists."""
    os.makedirs(path, mode=0o777, exist_ok=True)


def compute_alpha_bbox(bytes_data, w, h):
    """Return the bounding box of all pixels with non-zero alpha.

    Args:
        bytes_data: packed 4-byte pixels (BGRA; only byte 3, alpha, is read),
            row-major, ``w * h * 4`` bytes. Any bytes-like object or a PyQt
            ``QByteArray`` is accepted.
        w, h: image width and height in pixels.

    Returns:
        ``(x, y, width, height)`` of the opaque area, or ``(0, 0, 0, 0)`` when
        every pixel is fully transparent.
    """
    # bytes() normalises QByteArray, whose indexing yields 1-byte ``bytes``
    # rather than ints, which made every alpha compare unequal to 0.
    data = bytes(bytes_data)
    alpha = data[3::4]  # one alpha byte per pixel, row-major
    x0, x1 = w, -1
    y0 = y1 = 0
    found = False
    for y in range(h):
        row = alpha[y * w:(y + 1) * w]
        tail = row.rstrip(b'\x00')
        if not tail:
            continue  # fully transparent row
        left = len(row) - len(row.lstrip(b'\x00'))
        right = len(tail) - 1
        if not found:
            y0 = y
            found = True
        y1 = y
        x0 = min(x0, left)
        x1 = max(x1, right)
    if not found:
        return 0, 0, 0, 0
    return x0, y0, x1 - x0 + 1, y1 - y0 + 1


def bgra_to_rgba_bytes(bytes_data, w, h):
    """Convert packed 8-bit BGRA pixels to RGBA by swapping B and R.

    Args:
        bytes_data: BGRA pixels, 4 bytes each. Any bytes-like object or a
            PyQt ``QByteArray`` (as returned by Krita's ``pixelData()``).
        w, h: image size; kept for API symmetry but unused — the pixel count
            is taken from ``len(bytes_data)``.

    Returns:
        ``bytes`` in RGBA order, same length as the input.

    Raises:
        ValueError: if the data length is not a multiple of 4.
    """
    src = bytes(bytes_data)  # normalise QByteArray/memoryview/bytearray
    if len(src) % 4:
        raise ValueError(f'BGRA data length {len(src)} is not a multiple of 4')
    out = bytearray(src)  # G and A are already in place
    # Extended-slice assignment swaps every pixel's B and R bytes at C speed.
    out[0::4] = src[2::4]
    out[2::4] = src[0::4]
    return bytes(out)


def qimage_from_bgra(bytes_data, w, h, bit_depth=8):
    """Build a QImage from packed 8-bit BGRA pixels.

    Args:
        bytes_data: ``w * h * 4`` bytes of BGRA pixels (bytes-like or a PyQt
            ``QByteArray``).
        w, h: image width and height in pixels.
        bit_depth: 8 for an 8-bit-per-channel image. Any other value requests
            a 16-bit-per-channel image (``Format_RGBA64``) when this Qt build
            provides it, falling back to 8-bit otherwise.

    Returns:
        A QImage that owns its pixel buffer, so it stays valid after the
        temporary byte buffers built here are released.
    """
    data = bytes(bytes_data)
    rgba = bgra_to_rgba_bytes(data, w, h)
    if bit_depth != 8 and hasattr(QImage, 'Format_RGBA64'):
        # Widen each 8-bit channel v to the 16-bit value (v << 8) | v. Both
        # bytes of that value equal v, so interleaving rgba with itself yields
        # the buffer regardless of host endianness.
        wide = bytearray(len(rgba) * 2)
        wide[0::2] = rgba
        wide[1::2] = rgba
        # .copy(): a QImage built on a caller buffer does not own it; detach
        # before the local bytes object can be freed.
        return QImage(bytes(wide), w, h, w * 8, QImage.Format_RGBA64).copy()
    if hasattr(QImage, 'Format_RGBA8888'):
        return QImage(rgba, w, h, w * 4, QImage.Format_RGBA8888).copy()
    # Very old Qt without RGBA8888: ARGB32 holds native-endian 0xAARRGGBB
    # words, i.e. B,G,R,A bytes on little-endian hosts — the original data.
    # NOTE(review): channel order would be wrong on a big-endian host.
    return QImage(data, w, h, w * 4, QImage.Format_ARGB32).copy()


def save_qimage(img, path, fmt, quality=90):
    """Write *img* to *path* through Qt's image writers.

    *fmt* is one of 'png', 'webp', 'jpeg', 'tiff' (any case); any other value
    is passed upper-cased to Qt unchanged. *quality* is truncated to int.
    Returns the bool result of ``img.save``.
    """
    known_writers = {'png': 'PNG', 'webp': 'WEBP', 'jpeg': 'JPEG', 'tiff': 'TIFF'}
    key = fmt.lower()
    writer = known_writers[key] if key in known_writers else fmt.upper()
    return img.save(path, writer, int(quality))


class BatchLabDialog(QDialog):
    """BatchLab main dialog with a 'Batch Export' and a 'Sprite-Sheet' tab.

    Batch Export renders every layer of the active document in the chosen
    scope to an image file, optionally trimmed to its opaque area, scaled and
    padded. The Sprite-Sheet tab currently only simulates a layout. The chosen
    folders and the TIFF bit depth are remembered in QSettings('PGA', 'BatchLab').
    """

    # Scale-combo entries with a fixed factor; any other entry ('Custom%')
    # takes its percentage from custom_scale.
    _FIXED_SCALES = {'1x': 1.0, '2x': 2.0, '0.5x': 0.5}
    # Characters replaced in doc/layer names so each stays one path component.
    _UNSAFE_FILENAME_CHARS = '<>:"/\\|?*'

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle('BatchLab')
        self.settings = QSettings('PGA','BatchLab')

        self.tabs = QTabWidget()
        self.batch_tab = QWidget()
        self.sheet_tab = QWidget()
        self.setup_batch_tab()
        self.setup_sheet_tab()
        self.tabs.addTab(self.batch_tab, 'Batch Export')
        self.tabs.addTab(self.sheet_tab, 'Sprite-Sheet')

        layout = QVBoxLayout()
        layout.addWidget(self.tabs)
        # Read-only log shared by both tabs.
        self.log = QTextEdit()
        self.log.setReadOnly(True)
        layout.addWidget(self.log)
        self.setLayout(layout)

    def setup_batch_tab(self):
        """Build the 'Batch Export' tab widgets and connect their signals."""
        outer = QVBoxLayout()
        form = QFormLayout()
        self.scope_combo = QComboBox()
        self.scope_combo.addItems(['Selected Layers','Visible Layers','Top-level Groups','All'])
        form.addRow('Scope', self.scope_combo)

        self.format_combo = QComboBox()
        self.format_combo.addItems(['png','webp','jpeg','tiff'])
        form.addRow('Format', self.format_combo)

        self.quality_spin = QSpinBox()
        self.quality_spin.setRange(1,100)
        self.quality_spin.setValue(90)
        form.addRow('Quality', self.quality_spin)

        # TIFF bit depth; persisted as soon as the selection changes.
        self.tiff_depth = QComboBox()
        self.tiff_depth.addItems(['8','16'])
        self.tiff_depth.setCurrentText(str(self.settings.value('tiff_bit_depth','8')))
        form.addRow('TIFF bit depth', self.tiff_depth)
        self.tiff_depth.currentTextChanged.connect(lambda v: self.settings.setValue('tiff_bit_depth', v))

        self.scale_combo = QComboBox()
        self.scale_combo.addItems(['1x','2x','0.5x','Custom%'])
        self.scale_combo.currentIndexChanged.connect(self.on_scale_mode)
        form.addRow('Scale', self.scale_combo)
        self.custom_scale = QSpinBox()
        self.custom_scale.setRange(1,1000)
        self.custom_scale.setValue(100)
        form.addRow('Custom %', self.custom_scale)

        self.trim_cb = QCheckBox('Trim transparency')
        form.addRow(self.trim_cb)
        self.padding_spin = QSpinBox()
        self.padding_spin.setRange(0,500)
        self.padding_spin.setValue(0)
        form.addRow('Padding (px)', self.padding_spin)

        self.template_edit = QLineEdit('{doc}_{layer}_{v:03d}@{scale}x.{ext}')
        form.addRow('Filename template', self.template_edit)

        self.dest_btn = QPushButton('Choose target folder')
        self.dest_btn.clicked.connect(self.choose_folder)
        self.dest_label = QLabel(self.settings.value('last_dest',''))
        form.addRow(self.dest_btn, self.dest_label)

        self.subfolder_cb = QCheckBox('Create subfolder per doc')
        form.addRow(self.subfolder_cb)

        outer.addLayout(form)
        btn_h = QHBoxLayout()
        self.preview_btn = QPushButton('Preview')
        self.run_btn = QPushButton('Run')
        self.preview_btn.clicked.connect(self.on_batch_preview)
        self.run_btn.clicked.connect(self.on_batch_run)
        btn_h.addWidget(self.preview_btn)
        btn_h.addWidget(self.run_btn)
        outer.addLayout(btn_h)
        self.batch_tab.setLayout(outer)
        # currentIndexChanged only fires on later changes, so sync the
        # custom-% spin box with the initial '1x' selection explicitly.
        self.on_scale_mode()

    def setup_sheet_tab(self):
        """Build the 'Sprite-Sheet' tab widgets and connect their signals."""
        outer = QVBoxLayout()
        form = QFormLayout()
        self.sheet_source_combo = QComboBox()
        self.sheet_source_combo.addItems(['Selected Layers','Directory (PNG)'])
        form.addRow('Input', self.sheet_source_combo)
        self.grid_rb = QComboBox()
        self.grid_rb.addItems(['Grid','Tight pack'])
        form.addRow('Layout', self.grid_rb)
        self.tile_size_edit = QLineEdit('auto')
        form.addRow('Tile size (W×H or auto)', self.tile_size_edit)
        self.padding_spin2 = QSpinBox()
        self.padding_spin2.setRange(0,256)
        self.padding_spin2.setValue(2)
        form.addRow('Padding', self.padding_spin2)
        self.extrude_cb = QCheckBox('Extrude 1px')
        form.addRow(self.extrude_cb)
        self.pow2_cb = QCheckBox('Power-of-two canvas')
        form.addRow(self.pow2_cb)
        self.sheet_dest_btn = QPushButton('Choose output folder')
        self.sheet_dest_btn.clicked.connect(self.choose_folder_sheet)
        self.sheet_dest_label = QLabel(self.settings.value('last_sheet_dest',''))
        form.addRow(self.sheet_dest_btn, self.sheet_dest_label)
        outer.addLayout(form)
        btn_h = QHBoxLayout()
        self.sheet_preview_btn = QPushButton('Preview layout')
        self.sheet_run_btn = QPushButton('Run')
        self.sheet_preview_btn.clicked.connect(self.on_sheet_preview)
        self.sheet_run_btn.clicked.connect(self.on_sheet_run)
        btn_h.addWidget(self.sheet_preview_btn)
        btn_h.addWidget(self.sheet_run_btn)
        outer.addLayout(btn_h)
        self.sheet_tab.setLayout(outer)

    def on_scale_mode(self):
        """Enable the custom-% spin box only while 'Custom%' is selected."""
        idx = self.scale_combo.currentIndex()
        self.custom_scale.setEnabled(idx == 3)

    def choose_folder(self):
        """Ask for the batch-export folder (starting at the last one) and remember it."""
        start = self.dest_label.text() or os.path.expanduser('~')
        d = QFileDialog.getExistingDirectory(self, 'Choose folder', start)
        if d:
            self.dest_label.setText(d)
            self.settings.setValue('last_dest', d)

    def choose_folder_sheet(self):
        """Ask for the sprite-sheet folder (starting at the last one) and remember it."""
        start = self.sheet_dest_label.text() or os.path.expanduser('~')
        d = QFileDialog.getExistingDirectory(self, 'Choose folder', start)
        if d:
            self.sheet_dest_label.setText(d)
            self.settings.setValue('last_sheet_dest', d)

    @classmethod
    def _safe_name(cls, name, fallback):
        """Return *name* with path-unsafe/control characters replaced by '_'.

        Leading/trailing whitespace and dots are stripped (so '..' cannot walk
        up a directory); *fallback* is returned when nothing is left.
        """
        cleaned = ''.join(
            '_' if ch in cls._UNSAFE_FILENAME_CHARS or ord(ch) < 32 else ch
            for ch in (name or '')
        ).strip().strip('.')
        return cleaned or fallback

    def build_target_filename(self, doc_name, layer_name, idx, scale, ext):
        """Expand the filename template for one layer.

        Placeholders: {doc}, {layer} (both sanitised into single path
        components), {v} (1-based layer index), {scale} and {ext}.

        Raises:
            KeyError, IndexError, ValueError: if the template is malformed.
        """
        tpl = self.template_edit.text()
        return tpl.format(doc=self._safe_name(doc_name, 'untitled'),
                          layer=self._safe_name(layer_name, 'layer'),
                          v=idx, scale=scale, ext=ext)

    def _scale_factor(self):
        """Return the export scale factor chosen in the Scale controls."""
        mode = self.scale_combo.currentText()
        if mode in self._FIXED_SCALES:
            return self._FIXED_SCALES[mode]
        return self.custom_scale.value() / 100.0

    @staticmethod
    def _scale_label(factor):
        """Format *factor* for the {scale} placeholder: 2.0 -> '2', 0.5 -> '0.5'.

        The default template already appends 'x', so no suffix is added here.
        """
        return f'{factor:g}'

    def _output_path(self, dest, doc_name, layer_name, index, scale_factor):
        """Return the full path one layer is exported to (directories not created)."""
        fname = self.build_target_filename(doc_name, layer_name, index,
                                           self._scale_label(scale_factor),
                                           self.format_combo.currentText())
        if self.subfolder_cb.isChecked():
            dest = os.path.join(dest, self._safe_name(doc_name, 'untitled'))
        return os.path.join(dest, fname)

    def collect_layers_for_scope(self, doc):
        """Return the nodes selected by the Scope combo.

        Visible/group/'All' scopes only look at the root's direct children.
        """
        s = self.scope_combo.currentText()
        if s == 'Selected Layers':
            return selected_layers(doc)
        elif s == 'Visible Layers':
            return [n for n in doc.rootNode().childNodes() if n.visible()]
        elif s == 'Top-level Groups':
            return [n for n in doc.rootNode().childNodes() if n.type() == 'grouplayer']
        else:
            return top_level_layers(doc)

    def _render_layer(self, doc, layer, scale_factor):
        """Render one layer to a QImage ready for saving.

        Reads the layer's projection over the whole canvas, optionally trims
        it to the non-transparent area, scales it by *scale_factor* and then
        adds ``padding_spin`` transparent pixels on every side.

        Returns:
            The QImage, or None when the canvas has zero width or height.

        Raises:
            ValueError: if the layer is not 8-bit RGBA.
        """
        # Read the full canvas from the origin so untrimmed exports keep each
        # layer's position. The projection also covers group layers (which
        # have no own pixels) and applied masks.
        w, h = doc.width(), doc.height()
        if w <= 0 or h <= 0:
            return None
        model, depth = layer.colorModel(), layer.colorDepth()
        if (model, depth) != ('RGBA', 'U8'):
            raise ValueError(f'unsupported color space {model}/{depth}; only RGBA/U8 is supported')
        data = bytes(layer.projectionPixelData(0, 0, w, h))
        if len(data) != w * h * 4:
            raise ValueError(f'unexpected pixel data size {len(data)} for {w}x{h}')

        bx, by, bw, bh = 0, 0, w, h
        if self.trim_cb.isChecked():
            tx, ty, tw, th = compute_alpha_bbox(data, w, h)
            if tw > 0 and th > 0:  # fully transparent layers stay uncropped
                bx, by, bw, bh = tx, ty, tw, th
        if (bw, bh) != (w, h):
            # Copy whole row slices rather than individual pixels.
            stride = w * 4
            data = b''.join(
                data[(by + row) * stride + bx * 4:(by + row) * stride + (bx + bw) * 4]
                for row in range(bh))

        bit_depth = 8
        if self.format_combo.currentText() == 'tiff':
            bit_depth = int(self.tiff_depth.currentText())
        img = qimage_from_bgra(data, bw, bh, bit_depth=bit_depth)
        if scale_factor != 1.0:
            out_w = max(1, int(bw * scale_factor))
            out_h = max(1, int(bh * scale_factor))
            img = img.scaled(out_w, out_h, Qt.IgnoreAspectRatio, Qt.SmoothTransformation)

        pad = self.padding_spin.value()
        if pad > 0:
            # QImage.copy() fills areas outside the source with 0, which is
            # transparent black for the RGBA formats produced above.
            img = img.copy(-pad, -pad, img.width() + 2 * pad, img.height() + 2 * pad)
        return img

    def on_batch_preview(self):
        """Log the file names a batch run would produce, without rendering."""
        doc = Krita.instance().activeDocument()
        if doc is None:
            self.log.append('No document')
            return
        layers = self.collect_layers_for_scope(doc)
        scale = self._scale_label(self._scale_factor())
        ext = self.format_combo.currentText()
        self.log.append('Previewing export:')
        for i, layer in enumerate(layers, start=1):
            try:
                fname = self.build_target_filename(doc.name(), layer.name(), i, scale, ext)
            except (KeyError, IndexError, ValueError) as e:
                self.log.append(f'Invalid filename template: {e}')
                return
            self.log.append(f'  -> {fname}')

    def on_batch_run(self):
        """Export every layer in the chosen scope, with a cancelable progress dialog."""
        doc = Krita.instance().activeDocument()
        if doc is None:
            self.log.append('No document')
            return
        layers = self.collect_layers_for_scope(doc)
        dest = self.dest_label.text() or self.settings.value('last_dest','')
        if not dest:
            self.log.append('No destination chosen')
            return
        ensure_dir(dest)
        fmt = self.format_combo.currentText()
        scale_factor = self._scale_factor()
        total = len(layers)
        progress = QProgressDialog('Exporting...', 'Cancel', 0, total, self)
        progress.setWindowModality(Qt.WindowModal)
        canceled = False
        for i, layer in enumerate(layers, start=1):
            progress.setValue(i-1)
            if progress.wasCanceled():
                canceled = True
                break
            try:
                img = self._render_layer(doc, layer, scale_factor)
                if img is None:
                    self.log.append(f'Skipped {layer.name()}: empty canvas')
                    continue
                path = self._output_path(dest, doc.name(), layer.name(), i, scale_factor)
                # The subfolder option or the template itself may add directories.
                ensure_dir(os.path.dirname(path))
                ok = save_qimage(img, path, fmt, quality=self.quality_spin.value())
                if ok:
                    self.log.append(f'Exported {path}')
                else:
                    self.log.append(f'Failed to save {path}')
            except Exception as e:
                # Per-layer boundary: log and keep exporting the remaining layers.
                self.log.append(f'Error exporting {getattr(layer,"name",lambda: "layer")()}: {e}')
            finally:
                # Keep the UI (and the Cancel button) responsive between layers.
                QApplication.processEvents()
        progress.setValue(total)
        if canceled:
            self.log.append('Export canceled')
        else:
            self.log.append('Export finished')

    def on_sheet_preview(self):
        """Log a simulated sprite-sheet layout for the chosen input."""
        self.log.append('Previewing sprite sheet layout (simulated)')
        # Collect sources: selected layer names, or PNG file names in a folder.
        mode = self.sheet_source_combo.currentText()
        items = []
        if mode == 'Selected Layers':
            doc = Krita.instance().activeDocument()
            if doc:
                items = [n.name() for n in selected_layers(doc)]
        else:
            # NOTE(review): 'Directory (PNG)' scans the sheet *output* folder,
            # as the tab has no separate input-folder picker — confirm intent.
            d = self.sheet_dest_label.text() or self.settings.value('last_sheet_dest','')
            if d and os.path.isdir(d):
                items = [f for f in os.listdir(d) if f.lower().endswith('.png')]
        self.log.append(f'Found {len(items)} items')
        # Simulated layout: a square-ish grid needs ceil(sqrt(n)) columns.
        if self.grid_rb.currentText() == 'Grid':
            cols = int(math.ceil(math.sqrt(len(items)))) if items else 0
            self.log.append(f'Grid layout: {cols} cols')
        else:
            self.log.append('Tight pack (simulated)')

    def on_sheet_run(self):
        """Placeholder: sprite-sheet export is not implemented yet."""
        self.log.append('Running sprite-sheet export (placeholder)')


class BatchLab(Extension):
    """Krita extension that adds the 'BatchLab' action and opens its dialog."""

    def __init__(self, parent=None):
        super().__init__(parent)

    def setup(self):
        # Nothing to initialise at start-up; the dialog is built on demand.
        pass

    def createActions(self, window):
        """Register the 'BatchLab' action under the tools/scripts menu of *window*."""
        action = window.createAction('batch_lab_action', 'BatchLab', 'tools/scripts')
        action.triggered.connect(self.run)

    def run(self):
        """Show the BatchLab dialog modally, parented to Krita's main window."""
        # Parenting keeps the modal dialog stacked above and centred on Krita
        # instead of becoming a free-floating top-level window.
        window = Krita.instance().activeWindow()
        parent = window.qwindow() if window is not None else None
        dlg = BatchLabDialog(parent)
        dlg.exec_()
