# MainWindow.py
from ui import MainWindowUI as UI
from PyQt5.QtWidgets import QFileDialog,QMainWindow
from PyQt5.QtGui import QPixmap, QImage
import cv2
from PyQt5.QtCore import QTimer
from random import randint

class MainWindow(QMainWindow, UI):
    """Main window that plays a video file and runs KCF single- or multi-object tracking.

    The widgets used here (``btn_video``, ``btn_stop_video``, ``btn_select``,
    ``btn_begin``, ``btn_select_multi``, ``btn_begin_multi``, ``label``) are created
    by the generated ``MainWindowUI`` mixin through ``setupUi``.
    """

    # Interval between displayed frames, in milliseconds (timer period).
    FRAME_INTERVAL_MS = 20

    def __init__(self):
        super(MainWindow, self).__init__()
        self.setupUi(self)

        # Video source and tracker state. They stay None until a video is loaded or a
        # target is selected, so every entry point can check them instead of crashing.
        self.cap = None
        self.tracker = None
        self.multiTracker = None
        self.colors = []

        self.btn_video.clicked.connect(self.load_video)
        self.btn_stop_video.clicked.connect(self.stop_video)
        self.btn_select.clicked.connect(self.select)
        self.btn_begin.clicked.connect(self.begin_track)

        self.track_flag = False

        # Timer that drives playback: each timeout reads, processes and shows one frame.
        self.timer = QTimer()
        self.timer.timeout.connect(self.updateFrame)

        self.btn_select_multi.clicked.connect(self.select_multi)
        self.btn_begin_multi.clicked.connect(self.begin_multi_track)
        self.multi_track_flag = False

    # ------------------------------------------------------------------ helpers

    def _has_video(self):
        """Return True if a video capture is loaded and still open."""
        return self.cap is not None and self.cap.isOpened()

    def _pause_video(self):
        """Stop the playback timer and show the 'play' label on the toggle button."""
        self.timer.stop()
        self.btn_stop_video.setText('播放视频')

    def _play_video(self):
        """Start the playback timer and show the 'stop' label on the toggle button."""
        self.btn_stop_video.setText('停止视频')
        self.timer.start(self.FRAME_INTERVAL_MS)

    def _reset_tracking(self):
        """Disable both tracking modes and forget any existing trackers."""
        self.track_flag = False
        self.multi_track_flag = False
        self.tracker = None
        self.multiTracker = None
        self.colors = []
        self.btn_begin.setText('开始跟踪')
        self.btn_begin_multi.setText('开始多目标跟踪')

    # ------------------------------------------------------------------ display

    def display_image(self, img):
        """Show an OpenCV image (3-channel BGR or single-channel) in ``self.label``."""
        if len(img.shape) == 3:  # colour image
            height, width, channel = img.shape
            bytes_per_line = channel * width
            q_img = QImage(img.data, width, height, bytes_per_line, QImage.Format_BGR888)
        else:  # grayscale or binary image
            height, width = img.shape
            bytes_per_line = width
            q_img = QImage(img.data, width, height, bytes_per_line, QImage.Format_Grayscale8)

        self.label.setPixmap(QPixmap.fromImage(q_img))

    # ------------------------------------------------------------------ playback

    def load_video(self):
        """Let the user choose a video file, open it and display its first frame."""
        options = QFileDialog.Options()
        fileName, _ = QFileDialog.getOpenFileName(
            self, "Open Video File", "",
            "Video Files (*.mp4 *.avi *.mov);;All Files (*)", options=options)
        if not fileName:
            return

        # Drop everything belonging to the previous video: otherwise the old capture
        # leaks, and trackers initialised on it would keep running on unrelated frames.
        self._pause_video()
        if self.cap is not None:
            self.cap.release()
        self._reset_tracking()

        self.cap = cv2.VideoCapture(fileName)
        if not self.cap.isOpened():
            print("Error: Could not open video.")
            return
        ret, frame = self.cap.read()
        if not ret:
            print("无法读取视频帧")
            return
        # Show the first frame so the user can pick a target to track.
        self.frame = frame
        self.display_image(frame)

    def stop_video(self):
        """Toggle playback between playing and paused, based on the button's current label."""
        if self.btn_stop_video.text() == '停止视频':
            self._pause_video()
        elif self._has_video():
            self._play_video()
        else:
            # Starting the timer without a capture would make updateFrame fail on every tick.
            print("Error: no video loaded.")

    def updateFrame(self):
        """Timer slot: read the next frame, apply the active tracking mode and display it."""
        if not self._has_video():
            self._pause_video()
            return

        ret, frame = self.cap.read()
        if not ret:
            # End of the video or a read error: stop playback and free the capture.
            print("Error: Could not read frame from camera.")
            self._pause_video()
            self.cap.release()
            return

        if self.track_flag:
            frame = self.track(frame)
        if self.multi_track_flag:
            frame = self.multi_track(frame)

        self.display_image(frame)

    # ------------------------------------------------------------------ single-object tracking

    def select(self):
        """Pause playback, let the user draw one ROI on the next frame and start KCF tracking."""
        if not self._has_video():
            print("Error: no video loaded.")
            return
        self._pause_video()
        ret, frame = self.cap.read()
        if not ret:
            print("无法读取视频帧")
            return

        bbox = cv2.selectROI("Frame", frame, fromCenter=False, showCrosshair=True)
        cv2.destroyWindow("Frame")
        # selectROI returns a zero-sized box when the selection is cancelled; a tracker
        # cannot be initialised on it.
        if bbox[2] <= 0 or bbox[3] <= 0:
            print("Error: no target selected.")
            return

        self.bbox = bbox
        self.tracker = cv2.TrackerKCF_create()
        self.tracker.init(frame, self.bbox)

        # Only one tracking mode runs at a time (select_multi likewise disables this one).
        self.multi_track_flag = False
        self.btn_begin_multi.setText('开始多目标跟踪')
        self.track_flag = True
        self.btn_begin.setText('停止跟踪')
        self._play_video()

    def begin_track(self):
        """Toggle single-object tracking; it can only be enabled after ``select`` created a tracker."""
        if self.track_flag:
            self.btn_begin.setText('开始跟踪')
            self.track_flag = False
        elif self.tracker is None:
            # Enabling the flag without a tracker would crash track() on the next frame.
            print("Error: no target selected.")
        else:
            self.btn_begin.setText('停止跟踪')
            self.track_flag = True

    def track(self, frame):
        """Update the single tracker on *frame* and draw its box (or a failure message) in place."""
        success, bbox = self.tracker.update(frame)

        if success:
            p1 = (int(bbox[0]), int(bbox[1]))
            p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
            cv2.rectangle(frame, p1, p2, (255, 0, 0), 2, 1)
        else:
            cv2.putText(frame, "Fail to track", (100, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)

        return frame

    # ------------------------------------------------------------------ multi-object tracking

    def _select_rois(self, frame):
        """Interactively collect target boxes on a copy of *frame*.

        Returns ``(bboxes, colors)``, two lists of equal length. All drawing happens on
        a copy so that *frame*, which is later used to initialise the trackers, stays
        free of overlay text and rectangles.
        """
        canvas = frame.copy()
        bboxes = []
        colors = []

        # cv2.selectROI selects one box per call, so call it in a loop until the user quits.
        while True:
            cv2.putText(canvas, "Press space key to confirm object", (100, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
            bbox = cv2.selectROI('MultiTracker', canvas)
            # A zero-sized box means the selection was cancelled; do not track it.
            if bbox[2] > 0 and bbox[3] > 0:
                color = (randint(0, 255), randint(0, 255), randint(0, 255))
                bboxes.append(bbox)
                colors.append(color)
                p1 = (int(bbox[0]), int(bbox[1]))
                p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
                cv2.rectangle(canvas, p1, p2, color, 2, 1)

            cv2.putText(canvas, "Press any other key to select next object", (100, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
            cv2.putText(canvas, "Press q or Esc to quit selecting boxes and start tracking", (100, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
            # Refresh the window so the new box and instructions are visible while waiting.
            cv2.imshow('MultiTracker', canvas)

            k = cv2.waitKey(0) & 0xFF
            if k in (27, ord('q')):  # Esc or q ends the selection
                break

        cv2.destroyWindow('MultiTracker')
        return bboxes, colors

    def select_multi(self):
        """Pause playback, let the user draw several ROIs on the next frame and start multi-object tracking."""
        if not self._has_video():
            print("Error: no video loaded.")
            return
        self._pause_video()
        ret, frame = self.cap.read()
        if not ret:
            print("无法读取视频帧")
            return

        bboxes, colors = self._select_rois(frame)
        print('Selected bounding boxes {}'.format(bboxes))
        if not bboxes:
            return

        # One legacy KCF tracker per box, initialised on the unannotated frame.
        self.multiTracker = cv2.legacy.MultiTracker.create()
        for bbox in bboxes:
            self.multiTracker.add(cv2.legacy.TrackerKCF_create(), frame, bbox)
        self.colors = colors

        # Only one tracking mode runs at a time; keep the single-tracking button label in sync.
        self.track_flag = False
        self.btn_begin.setText('开始跟踪')
        self.multi_track_flag = True
        self.btn_begin_multi.setText('停止多目标跟踪')
        self._play_video()

    def begin_multi_track(self):
        """Toggle multi-object tracking; it can only be enabled after ``select_multi`` created trackers."""
        if self.multi_track_flag:
            self.btn_begin_multi.setText('开始多目标跟踪')
            self.multi_track_flag = False
        elif self.multiTracker is None:
            # Enabling the flag without trackers would crash multi_track() on the next frame.
            print("Error: no target selected.")
        else:
            self.btn_begin_multi.setText('停止多目标跟踪')
            self.multi_track_flag = True

    def multi_track(self, frame):
        """Update all trackers on *frame* and draw each box in its own colour, in place."""
        success, boxes = self.multiTracker.update(frame)

        for newbox, color in zip(boxes, self.colors):
            p1 = (int(newbox[0]), int(newbox[1]))
            p2 = (int(newbox[0] + newbox[2]), int(newbox[1] + newbox[3]))
            cv2.rectangle(frame, p1, p2, color, 2, 1)

        if not success:
            # Same failure indicator as the single-object tracker.
            cv2.putText(frame, "Fail to track", (100, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)

        return frame

