# MainWindow.py
from ui import MainWindowUI as UI
from PyQt5.QtWidgets import QFileDialog,QMainWindow
from PyQt5.QtGui import QPixmap, QImage
import cv2
import numpy as np


class MainWindow(QMainWindow, UI):
    """Main window of the lane-detection demo.

    The three buttons connected in ``__init__`` drive the workflow:
    ``btn_load`` opens an image, ``btn_detect`` runs Canny edge detection and
    a probabilistic Hough transform inside a triangular region of interest,
    and ``btn_draw`` fits one line per lane side to the detected segments and
    overlays the lane area in green.
    """

    # Top row (pixels) of the filled lane polygon drawn by ``laneDetect``.
    # Hard-coded in the original implementation; presumably tuned for the
    # sample images -- TODO confirm / adjust for other resolutions.
    LANE_TOP_Y = 325

    def __init__(self):
        super(MainWindow, self).__init__()
        # State shared between the slots. Initialised up front so the buttons
        # can be clicked in any order without raising AttributeError.
        self.image = None  # last loaded image, BGR as decoded by OpenCV
        self.lines = None  # cv2.HoughLinesP result for self.image, or None
        self.setupUi(self)

        # NOTE: keep these slots free of optional parameters --
        # QAbstractButton.clicked carries a ``checked`` bool that PyQt may
        # pass into them.
        self.btn_load.clicked.connect(self.load_image)  # open-file slot
        self.btn_detect.clicked.connect(self.lineDetection)
        self.btn_draw.clicked.connect(self.laneDetect)

    @staticmethod
    def _read_image(path):
        """Return the image at *path* as a BGR array, or None if unreadable."""
        image = cv2.imread(path)
        if image is not None:
            return image
        # cv2.imread fails on some paths (notably non-ASCII paths on Windows);
        # fall back to reading the raw bytes and decoding them in memory.
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)

    def load_image(self):
        """Let the user pick an image file, load it and display it.

        Cancelling the dialog or choosing an unreadable file leaves the
        current image untouched.
        """
        options = QFileDialog.Options()
        fileName, _ = QFileDialog.getOpenFileName(self, "QFileDialog.getOpenFileName()", "", "Image Files (*.png *.jpg *.jpeg)", options=options)
        if not fileName:
            return  # dialog cancelled
        image = self._read_image(fileName)
        if image is None:
            return  # unreadable / unsupported file
        self.image = image
        # Any previous detection result belongs to the old image.
        self.lines = None
        self.display_image(self.image)

    def display_image(self, img):
        """Show *img* in ``self.label``.

        3-D arrays are shown as BGR colour images; 2-D arrays as 8-bit
        grayscale (or binary) images.
        """
        # QImage wraps the buffer without copying it and expects rows laid
        # out contiguously; use the real row stride as bytes-per-line.
        img = np.ascontiguousarray(img)
        height, width = img.shape[:2]
        bytes_per_line = img.strides[0]
        if img.ndim == 3:  # colour image
            fmt = QImage.Format_BGR888
        else:  # grayscale or binary image
            fmt = QImage.Format_Grayscale8
        q_img = QImage(img.data, width, height, bytes_per_line, fmt)
        self.label.setPixmap(QPixmap.fromImage(q_img))

    def lineDetection(self):
        """Detect line segments in the loaded image and draw them in red.

        The raw ``cv2.HoughLinesP`` result (None when nothing is found) is
        stored in ``self.lines`` for ``laneDetect``. Does nothing if no image
        has been loaded.
        """
        if self.image is None:
            return
        img = self.image.copy()

        edges = self.preprocess(img)  # grayscale + blur + Canny
        roi_edges = self.get_ROI(edges)  # keep only the region of interest

        # Probabilistic Hough transform parameters
        rho = 1  # distance resolution of the accumulator, pixels
        theta = np.pi / 180  # angular resolution of the accumulator, radians
        threshold = 15  # minimum accumulator votes
        min_line_len = 40  # shorter segments are rejected
        max_line_gap = 20  # max gap between points joined into one segment
        lines = cv2.HoughLinesP(roi_edges, rho, theta, threshold,
                                minLineLength=min_line_len,
                                maxLineGap=max_line_gap)
        self.lines = lines

        # HoughLinesP returns None when it finds no segment.
        if lines is not None:
            for x1, y1, x2, y2 in lines.reshape(-1, 4):
                cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)),
                         (0, 0, 255), 1)

        self.display_image(img)  # show the image with detected segments

    def preprocess(self, img):
        """Return the Canny edge map of the BGR image *img*."""
        # Images come from OpenCV decoding, which yields BGR channel order.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        blur_gray = cv2.GaussianBlur(gray, (5, 5), 1)  # Gaussian denoising
        edges = cv2.Canny(blur_gray, 50, 150)  # edge detection
        return edges

    def get_ROI(self, edges):
        """Mask *edges* to a triangle spanning the bottom edge up to the centre.

        The triangle's vertices are the bottom-left corner, the image centre
        and the bottom-right corner; everything outside is zeroed.
        """
        rows, cols = edges.shape
        # fillPoly requires int32 (CV_32S) vertices; NumPy's default integer
        # type is int64 on many platforms.
        points = np.array([[(0, rows), (cols // 2, rows // 2), (cols, rows)]],
                          dtype=np.int32)

        mask = np.zeros_like(edges)
        cv2.fillPoly(mask, points, 255)

        roi_edges = cv2.bitwise_and(edges, mask)
        return roi_edges

    def laneDetect(self):
        """Fit left/right lane lines to ``self.lines`` and overlay the lane.

        Does nothing if no image is loaded, line detection has not produced
        any segments, or one side ends up without usable segments.
        """
        if self.image is None or self.lines is None:
            return
        img = np.zeros((self.image.shape[0], self.image.shape[1], 3), dtype=np.uint8)

        # a. Split segments into left/right lanes by slope sign (y grows
        #    downwards, so the left lane rises to the right => negative slope).
        left_lines, right_lines = [], []
        for line in self.lines:
            x1, y1, x2, y2 = line[0]
            if x2 == x1:
                # Vertical segment: slope is undefined, and an inf slope would
                # turn the mean in clean_lines into nan and stop the cleaning.
                continue
            k = (y2 - y1) / (x2 - x1)
            if k < 0:
                left_lines.append(line)
            else:
                right_lines.append(line)
        if not left_lines or not right_lines:
            return

        # b. Remove segments whose slope deviates from the side's mean slope.
        self.clean_lines(left_lines, 0.1)
        self.clean_lines(right_lines, 0.1)

        # c. Fit one line per side through both end points of every segment.
        left_points = np.reshape(left_lines, (-1, 2))
        right_points = np.reshape(right_lines, (-1, 2))
        rows = img.shape[0]
        left_top, left_bottom = self.least_squares_fit(left_points, self.LANE_TOP_Y, rows)
        right_top, right_bottom = self.least_squares_fit(right_points, self.LANE_TOP_Y, rows)

        # d. Fill the lane area. The vertex order walks the outline
        #    (bottom-left, top-left, top-right, bottom-right) so the polygon
        #    does not self-intersect; fillPoly requires int32 vertices.
        vtxs = np.array([[left_bottom, left_top, right_top, right_bottom]],
                        dtype=np.int32)
        cv2.fillPoly(img, vtxs, (0, 255, 0))

        result = cv2.addWeighted(self.image, 0.9, img, 0.2, 0)
        self.display_image(result)  # show the lane overlay

    def clean_lines(self, lines, threshold):
        """Iteratively remove slope outliers from *lines*, in place.

        Each pass removes the segment whose slope is furthest from the current
        mean slope, as long as that distance exceeds *threshold*.

        Args:
            lines: mutable list of segments shaped like HoughLinesP entries
                (one ``[x1, y1, x2, y2]`` row each); assumed non-vertical.
            threshold: maximum tolerated |slope - mean slope|.
        """
        slopes = [(y2 - y1) / (x2 - x1) for line in lines for x1, y1, x2, y2 in line]
        while lines:
            mean = np.mean(slopes)
            diff = [abs(s - mean) for s in slopes]
            idx = int(np.argmax(diff))
            if diff[idx] > threshold:
                slopes.pop(idx)
                lines.pop(idx)
            else:
                break

    def least_squares_fit(self, point_list, ymin, ymax):
        """Least-squares fit a straight line and evaluate it at two rows.

        The line is fitted as x = a*y + b (x as a function of y), which copes
        with steep, near-vertical lane lines.

        Args:
            point_list: sequence of (x, y) points; should contain at least two
                distinct y values for a meaningful fit.
            ymin: first row at which to evaluate the line.
            ymax: second row at which to evaluate the line.

        Returns:
            ``[(xmin, ymin), (xmax, ymax)]`` with x values truncated to int.
        """
        x = [p[0] for p in point_list]
        y = [p[1] for p in point_list]
        fit = np.polyfit(y, x, 1)  # degree 1 => linear fit
        fit_fn = np.poly1d(fit)
        xmin = int(fit_fn(ymin))
        xmax = int(fit_fn(ymax))
        return [(xmin, ymin), (xmax, ymax)]