# MainWindow.py
import math
import os

import cv2
import numpy as np
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtWidgets import QFileDialog, QMainWindow

from ui import ImageBrowserUI as UI


class MainWindow(QMainWindow, UI):
    """Main window of the image browser: open/save an image and apply simple transforms.

    Every transform (zoom, rotate, shear, HSV conversion, histogram
    equalisation) starts from ``self.original_image`` and stores its result
    in ``self.image``, which is what ``save_image`` writes. Transforms
    therefore do not accumulate.
    """

    def __init__(self):
        super(MainWindow, self).__init__()
        self.setupUi(self)

        # Initialise state before wiring signals so no slot can observe a
        # missing attribute.
        self.image = None  # result of the most recent operation (what gets saved)
        self.original_image = None  # image as loaded from disk; every operation starts from it
        self.original = None  # kept for backward compatibility; mirrors original_image

        self.open_button.clicked.connect(self.load_image)  # open an image
        self.save_button.clicked.connect(self.save_image)  # save the current result
        self.rotate_button.clicked.connect(self.rotate_image)  # rotate
        self.shear_button.clicked.connect(self.shear_image)  # horizontal shear
        self.zoom_slider.valueChanged.connect(self.zoom_image)  # zoom

        self.color_button.clicked.connect(self.rgb2hsv)  # BGR -> HSV conversion
        self.histeq_button.clicked.connect(self.histeq)  # histogram equalisation

    # ------------------------------------------------------------------ helpers

    def _warn(self, message):
        """Show *message* in the main window's status bar for five seconds."""
        self.statusBar().showMessage(message, 5000)

    def _source_image(self):
        """Return the loaded original image, or None (with a status hint) if none is loaded."""
        if self.original_image is None:
            self._warn("Please open an image first.")
        return self.original_image

    def _parse_angle(self, line_edit):
        """Parse *line_edit*'s text as a finite angle in degrees.

        Returns the angle as a float, or None (after showing a status hint)
        when the text is not a finite number. Integer input parses exactly as
        before; decimals such as "12.5" are now accepted too.
        """
        text = line_edit.text().strip()
        try:
            angle = float(text)
        except ValueError:
            angle = None
        if angle is None or not math.isfinite(angle):
            self._warn("Invalid angle: {!r}".format(text))
            return None
        return angle

    @staticmethod
    def _read_image(path):
        """Decode the file at *path* as a 3-channel BGR image; return None on failure.

        The bytes are read with NumPy and decoded with ``cv2.imdecode`` rather
        than calling ``cv2.imread(path)``, so paths containing non-ASCII
        characters also work on Windows.
        """
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:  # imdecode rejects an empty buffer
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)  # None if the bytes are not a decodable image

    @staticmethod
    def _write_image(path, img):
        """Encode *img* in the format implied by *path*'s extension and write it.

        Returns True on success and False on any failure: a missing or
        unsupported extension, an encoder failure, or an I/O error. Uses
        ``cv2.imencode`` + ``ndarray.tofile`` so non-ASCII paths work on
        Windows.
        """
        ext = os.path.splitext(path)[1]
        if not ext:
            return False
        try:
            ok, buf = cv2.imencode(ext, img)
        except cv2.error:  # no encoder registered for this extension
            return False
        if not ok:
            return False
        try:
            buf.tofile(path)
        except OSError:
            return False
        return True

    # -------------------------------------------------------------------- slots

    def load_image(self):
        """Ask for an image file, load it as the new original and display it."""
        options = QFileDialog.Options()
        fileName, _ = QFileDialog.getOpenFileName(self, "QFileDialog.getOpenFileName()", "",
                                                  "Image Files (*.png *.jpg *.jpeg)", options=options)
        if not fileName:
            return  # dialog cancelled

        img = self._read_image(fileName)
        if img is None:
            # Keep the previously loaded image rather than clobbering state with None.
            self._warn("Could not open image: " + fileName)
            return

        self.original_image = img
        self.original = img
        self.image = img
        self.display_image(img)

    def display_image(self, img):
        """Show *img* in ``image_label``.

        A 3-D array is displayed as 8-bit BGR (``Format_BGR888``); any other
        array is treated as 8-bit single-channel grayscale.
        """
        # QImage wraps the buffer without copying and expects row-major rows,
        # so guarantee a C-contiguous array and take the row stride from it.
        img = np.ascontiguousarray(img)
        if img.ndim == 3:  # colour image
            height, width = img.shape[:2]
            fmt = QImage.Format_BGR888
        else:  # grayscale or binary image
            height, width = img.shape
            fmt = QImage.Format_Grayscale8
        q_img = QImage(img.data, width, height, img.strides[0], fmt)

        # QPixmap.fromImage copies the pixels, so `img` may be freed afterwards.
        self.image_label.setPixmap(QPixmap.fromImage(q_img))

    def save_image(self):
        """Ask for a destination and write the current result image there."""
        if self.image is None:
            self._warn("There is no image to save.")
            return

        options = QFileDialog.Options()
        file_name, _ = QFileDialog.getSaveFileName(self, "Save Image File", "",
                                                   "Images (*.png *.jpg *.bmp *.jpeg);;All Files (*)", options=options)
        if file_name and not self._write_image(file_name, self.image):
            self._warn("Failed to save image: " + file_name)

    def zoom_image(self):
        """Resize the original image by the slider value, read as a percentage."""
        src = self._source_image()
        if src is None:
            return

        zoom_factor = self.zoom_slider.value() / 100.0
        h, w = src.shape[:2]
        # cv2.resize rejects a zero-sized destination, so clamp to at least 1x1.
        new_dim = (max(1, int(w * zoom_factor)), max(1, int(h * zoom_factor)))
        # INTER_AREA gives the best result when shrinking; for enlarging,
        # OpenCV recommends INTER_LINEAR (or the slower INTER_CUBIC).
        interpolation = cv2.INTER_AREA if zoom_factor < 1.0 else cv2.INTER_LINEAR
        resized = cv2.resize(src, new_dim, interpolation=interpolation)

        self.image = resized
        self.display_image(resized)

    def rotate_image(self):
        """Rotate the original image by the angle in ``rotate_edit``, in degrees.

        The angle is negated before ``cv2.getRotationMatrix2D`` (whose positive
        angles are counter-clockwise), so positive input rotates clockwise. The
        canvas is enlarged so that no corner of the image is cut off.
        """
        src = self._source_image()
        if src is None:
            return
        angle = self._parse_angle(self.rotate_edit)
        if angle is None:
            return

        # Image size and rotation centre.
        h, w = src.shape[:2]
        cX, cY = w // 2, h // 2

        # Rotation matrix about the centre, scale 1.
        M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
        cos = np.abs(M[0, 0])
        sin = np.abs(M[0, 1])

        # Bounding box of the rotated image.
        nW = int((h * sin) + (w * cos))
        nH = int((h * cos) + (w * sin))

        # Shift so the rotated image is centred on the new canvas.
        M[0, 2] += (nW / 2) - cX
        M[1, 2] += (nH / 2) - cY

        rot = cv2.warpAffine(src, M, (nW, nH))

        self.image = rot
        self.display_image(rot)

    def shear_image(self):
        """Apply a horizontal shear using the angle in ``shear_edit``, in degrees.

        The angle must lie strictly between -90 and 90. A positive angle shifts
        the top edge to the right; a negative angle shifts the bottom edge to
        the right. The canvas is widened so the sheared image fits.
        """
        src = self._source_image()
        if src is None:
            return
        angle = self._parse_angle(self.shear_edit)
        if angle is None:
            return
        if not -90.0 < angle < 90.0:
            # tan() diverges at +/-90 degrees, which would request an absurdly wide canvas.
            self._warn("Shear angle must be between -90 and 90 degrees.")
            return

        rows, cols = src.shape[:2]  # image height and width
        ratio = math.tan(math.radians(angle))  # horizontal offset per row

        # Reference points as (x, y): top-left, bottom-left and top-right corners.
        pts1 = np.float32([[0, 0], [0, rows], [cols, 0]])

        # Where those three points land after shearing.
        if angle > 0:
            # Top edge moves right by rows*ratio; the bottom-left corner stays put.
            pts2 = np.float32([[rows * ratio, 0], [0, rows], [cols + rows * ratio, 0]])
        else:
            # Top edge stays put; bottom edge moves right by rows*|ratio| (identity at 0).
            pts2 = np.float32([[0, 0], [rows * (-ratio), rows], [cols, 0]])

        # Affine shear matrix from the three point pairs.
        M = cv2.getAffineTransform(pts1, pts2)

        # Canvas widened by the maximum horizontal displacement.
        dst_cols = int(cols + rows * math.fabs(ratio))
        dst_rows = rows

        sheared = cv2.warpAffine(src, M, (dst_cols, dst_rows))

        self.image = sheared
        self.display_image(sheared)

    def rgb2hsv(self):
        """Convert the original BGR image to HSV and display the raw channels.

        ``display_image`` shows any 3-channel array as BGR, so the result
        appears as a false-colour view. ``save_image`` writes the H, S and V
        bytes in the B, G and R positions.
        """
        src = self._source_image()
        if src is None:
            return

        hsv = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)
        self.image = hsv
        self.display_image(hsv)

    def histeq(self):
        """Convert the original image to grayscale and equalise its histogram."""
        src = self._source_image()
        if src is None:
            return

        gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
        eq = cv2.equalizeHist(gray)

        self.image = eq
        self.display_image(eq)
