
[mmdetection] How to change the bbox title

dohyeon2 2022. 7. 20. 18:07


    Today in the lab, while running inference with the Mask R-CNN provided by the mmdetection toolbox, I was given the task of changing the bbox title drawn on the inference result. This post summarizes how I solved it.

     

    The inference result from the original code is shown above.


    The mmdetection Mask R-CNN code runs in the following order.

    1. Load the config and checkpoint files.

    2. Build the detector model with init_detector.

    3. Load the image to run inference on and pass it to inference_detector(model, image_name).

    4. Use the show_result function to draw the inferred result on the original image and save it.

    import cv2
    import matplotlib.pyplot as plt
    from mmdet.apis import init_detector, inference_detector

    # Set the config file and load the checkpoint trained with the Mask R-CNN model.
    config_file = '/scratch/dohyeon/mmdetection/custom_config.py'
    checkpoint_file = '/scratch/dohyeon/mmdetection/tutorial_exps/epoch_12.pth'

    # Build the detector model from the config file and the checkpoint.
    model = init_detector(config_file, checkpoint_file, device='cuda:0')

    # path_dir, file_list, save_dir1 and num_chicken are defined earlier in the script (not shown).
    for i in range(1):
        img_name = path_dir + '/' + file_list[i]
        img_arr = cv2.imread(img_name, cv2.IMREAD_COLOR)
        img_arr_rgb = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
        # cv2.imshow('img',img)
        fig = plt.figure(figsize=(12, 12))
        plt.imshow(img_arr_rgb)

        # inference_detector accepts a file path (str) or an ndarray, either singly or as a list.
        results = inference_detector(model, img_arr)

        # Save the result to the output directory (only instances with a confidence score above 0.7 are drawn).
        model.show_result(img_arr, results, score_thr=0.7, bbox_color=(0, 0, 0), thickness=0.5, font_size=6, out_file=f'{save_dir1}{file_list[i]}')

        # Add up the number of inferred instances per image (a single image here).
        # results is (bbox_result, segm_result); results[1][0] holds the masks of class 0.
        number = len(results[1][0])
        num_chicken = num_chicken + number
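    Note that the count on the last line takes every mask of class 0, while the saved image only draws boxes whose score exceeds score_thr. If the two numbers should agree, the count can be filtered with the same threshold. A minimal sketch, assuming plain Python lists stand in for the per-class (n, 5) arrays that mmdetection returns; count_instances is a hypothetical helper, not part of mmdetection:

```python
def count_instances(bbox_result, score_thr=0.0):
    """Count detections whose score is strictly above score_thr.

    bbox_result: one list per class; each row is [x1, y1, x2, y2, score].
    """
    return sum(
        1
        for class_rows in bbox_result
        for row in class_rows
        if row[4] > score_thr
    )


# Hypothetical detections for a single class.
bbox_result = [[
    [10, 20, 50, 80, 0.95],
    [30, 40, 90, 120, 0.65],
    [5, 5, 25, 25, 0.71],
]]
print(count_instances(bbox_result, score_thr=0.7))
```

    With real results, the first element of the tuple returned by inference_detector would play the role of bbox_result.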

     

    By default, each box in the inference result is labeled as class_name | confidence_score, as in the picture below.

    I wanted to display Instance_id | weight instead, so I modified the show_result function.

    show_result is defined in mmdetection/mmdet/models/detectors/base.py,

    and its contents are as follows.

    def show_result(self,
                        img,
                        result,
                        score_thr=0.3,
                        bbox_color=(72, 101, 241),
                        text_color=(72, 101, 241),
                        mask_color=None,
                        thickness=2,
                        font_size=13,
                        win_name='',
                        show=False,
                        wait_time=0,
                        out_file=None):
            """Draw `result` over `img`.
    
            Args:
                img (str or Tensor): The image to be displayed.
                result (Tensor or tuple): The results to draw over `img`
                    bbox_result or (bbox_result, segm_result).
                score_thr (float, optional): Minimum score of bboxes to be shown.
                    Default: 0.3.
                bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
                   The tuple of color should be in BGR order. Default: 'green'
                text_color (str or tuple(int) or :obj:`Color`):Color of texts.
                   The tuple of color should be in BGR order. Default: 'green'
                mask_color (None or str or tuple(int) or :obj:`Color`):
                   Color of masks. The tuple of color should be in BGR order.
                   Default: None
                thickness (int): Thickness of lines. Default: 2
                font_size (int): Font size of texts. Default: 13
                win_name (str): The window name. Default: ''
                wait_time (float): Value of waitKey param.
                    Default: 0.
                show (bool): Whether to show the image.
                    Default: False.
                out_file (str or None): The filename to write the image.
                    Default: None.
    
            Returns:
                img (Tensor): Only if not `show` or `out_file`
            """
            img = mmcv.imread(img)
            img = img.copy()
            if isinstance(result, tuple):
                bbox_result, segm_result = result
                if isinstance(segm_result, tuple):
                    segm_result = segm_result[0]  # ms rcnn
            else:
                bbox_result, segm_result = result, None
            bboxes = np.vstack(bbox_result)
            labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
            labels = np.concatenate(labels)
            # draw segmentation masks
            segms = None
            if segm_result is not None and len(labels) > 0:  # non empty
                segms = mmcv.concat_list(segm_result)
                if isinstance(segms[0], torch.Tensor):
                    segms = torch.stack(segms, dim=0).detach().cpu().numpy()
                else:
                    segms = np.stack(segms, axis=0)
            # if out_file specified, do not show image in window
            if out_file is not None:
                show = False
            # draw bounding boxes
            img = imshow_det_bboxes(
                img,
                bboxes,
                labels,
                segms,
                class_names=self.CLASSES,
                score_thr=score_thr,
                bbox_color=bbox_color,
                text_color=text_color,
                mask_color=mask_color,
                thickness=thickness,
                font_size=font_size,
                win_name=win_name,
                show=show,
                wait_time=wait_time,
                out_file=out_file)
    
            if not (show or out_file):
                return img

    Right below the # draw bounding boxes comment, show_result hands the work off to the imshow_det_bboxes function (img = imshow_det_bboxes(...)).

    (That function lives in the mmdetection/mmdet/core/visualization/image.py module.)
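    Before that call, show_result flattens the per-class bbox_result into one array of boxes and one array of class indices (np.vstack plus np.full in the source above). A rough pure-Python equivalent, with nested lists standing in for the NumPy arrays; flatten_bbox_result is a name made up for this sketch:

```python
def flatten_bbox_result(bbox_result):
    """Return (bboxes, labels) roughly the way show_result prepares them.

    bbox_result: one list per class; each row is [x1, y1, x2, y2, score].
    """
    bboxes, labels = [], []
    for class_index, class_rows in enumerate(bbox_result):
        for row in class_rows:
            bboxes.append(row)
            labels.append(class_index)  # the label is the class position
    return bboxes, labels


# Hypothetical result: two boxes for class 0, none for class 1, one for class 2.
bbox_result = [
    [[0, 0, 10, 10, 0.9], [5, 5, 20, 20, 0.8]],
    [],
    [[1, 2, 3, 4, 0.5]],
]
bboxes, labels = flatten_bbox_result(bbox_result)
print(labels)
```

    Row i of bboxes and entry i of labels describe the same detection, which is why any extra per-detection list has to follow the same order.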

     

    Here I need to pass in the weight (broiler body weight) computed by a module I wrote myself, so

    # draw bounding boxes
            img = imshow_det_bboxes(
                img,
                bboxes,
                labels,
                predict_weight_list,
                segms,
                class_names=self.CLASSES,
                score_thr=score_thr,
                bbox_color=bbox_color,
                text_color=text_color,
                mask_color=mask_color,
                thickness=thickness,
                font_size=font_size,
                win_name=win_name,
                show=show,
                wait_time=wait_time,
                out_file=out_file)

    I added predict_weight_list as an argument, as shown above.

     

    <Caution>

    When defining a Python function, the parameters must follow a certain order:

    a positional parameter with a default value must always come after the ones without a default.

    def order(no1 = "Breakfast", no2, no3):
        print(f"{no1} is eaten at 9, {no2} at 12, and {no3} at 6.")

    order("Lunch", "Dinner")

    # Output
    SyntaxError: non-default argument follows default argument

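    In the example above, no1 has a default value but comes before no2 and no3, which do not,

    so Python raises SyntaxError: non-default argument follows default argument (the exact wording varies by Python version).

    In short, parameters with default values always go after the parameters without them.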
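    For comparison, a version that Python accepts: the parameter with a default simply moves to the end. This sketch returns the string instead of printing it so the result is easy to check.

```python
def order(no2, no3, no1="Breakfast"):
    """Parameters without defaults come first; the defaulted one goes last."""
    return f"{no1} is eaten at 9, {no2} at 12, and {no3} at 6."


print(order("Lunch", "Dinner"))
```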

     

    Back to the main thread: the imshow_det_bboxes function in mmdetection/mmdet/core/visualization/image.py

    looks like this.

    def imshow_det_bboxes(img,
                          bboxes,
                          labels,
                          segms=None,
                          class_names=None,
                          score_thr=0,
                          bbox_color='green',
                          text_color='green',
                          mask_color=None,
                          thickness=2,
                          font_size=13,
                          win_name='',
                          show=True,
                          wait_time=0,
                          out_file=None):
        """Draw bboxes and class labels (with scores) on an image.
    
        Args:
            img (str or ndarray): The image to be displayed.
            bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
                (n, 5).
            labels (ndarray): Labels of bboxes.
            segms (ndarray or None): Masks, shaped (n,h,w) or None
            class_names (list[str]): Names of each classes.
            score_thr (float): Minimum score of bboxes to be shown.  Default: 0
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
               The tuple of color should be in BGR order. Default: 'green'
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
               The tuple of color should be in BGR order. Default: 'green'
            mask_color (str or tuple(int) or :obj:`Color`, optional):
               Color of masks. The tuple of color should be in BGR order.
               Default: None
            thickness (int): Thickness of lines. Default: 2
            font_size (int): Font size of texts. Default: 13
            show (bool): Whether to show the image. Default: True
            win_name (str): The window name. Default: ''
            wait_time (float): Value of waitKey param. Default: 0.
            out_file (str, optional): The filename to write the image.
                Default: None
    
        Returns:
            ndarray: The image with bboxes drawn on it.
        """
        assert bboxes.ndim == 2, \
            f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
        assert labels.ndim == 1, \
            f' labels ndim should be 1, but its ndim is {labels.ndim}.'
        assert bboxes.shape[0] == labels.shape[0], \
            'bboxes.shape[0] and labels.shape[0] should have the same length.'
        assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
            f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'
        img = mmcv.imread(img).astype(np.uint8)
    
        if score_thr > 0:
            assert bboxes.shape[1] == 5
            scores = bboxes[:, -1]
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            labels = labels[inds]
            if segms is not None:
                segms = segms[inds, ...]
    
        mask_colors = []
        if labels.shape[0] > 0:
            if mask_color is None:
                # Get random state before set seed, and restore random state later.
                # Prevent loss of randomness.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                # random color
                np.random.seed(42)
                mask_colors = [
                    np.random.randint(0, 256, (1, 3), dtype=np.uint8)
                    for _ in range(max(labels) + 1)
                ]
                np.random.set_state(state)
            else:
                # specify  color
                mask_colors = [
                    np.array(mmcv.color_val(mask_color)[::-1], dtype=np.uint8)
                ] * (
                    max(labels) + 1)
    
        bbox_color = color_val_matplotlib(bbox_color)
        text_color = color_val_matplotlib(text_color)
    
        img = mmcv.bgr2rgb(img)
        width, height = img.shape[1], img.shape[0]
        img = np.ascontiguousarray(img)
    
        fig = plt.figure(win_name, frameon=False)
        plt.title(win_name)
        canvas = fig.canvas
        dpi = fig.get_dpi()
        # add a small EPS to avoid precision lost due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
    
        # remove white edges by set subplot margin
        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
        ax = plt.gca()
        ax.axis('off')
    
        polygons = []
        color = []
        for i, (bbox, label) in enumerate(zip(bboxes, labels)):  # i is the index, bbox is one row of bboxes, label is the matching entry of labels
            bbox_int = bbox.astype(np.int32)
            poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
                    [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
            np_poly = np.array(poly).reshape((4, 2))
            polygons.append(Polygon(np_poly))
            color.append(bbox_color)
            label_text = class_names[
                label] if class_names is not None else f'class {label}'
            if len(bbox) > 4:
                label_text += f'|{bbox[-1]:.02f}'
            ax.text(
                bbox_int[0],
                bbox_int[1],
                f'{label_text}',
                bbox={
                    'facecolor': 'black',
                    'alpha': 0.8,
                    'pad': 0.7,
                    'edgecolor': 'none'
                },
                color=text_color,
                fontsize=font_size,
                verticalalignment='top',
                horizontalalignment='left')
            if segms is not None:
                color_mask = mask_colors[labels[i]]
                mask = segms[i].astype(bool)
                img[mask] = img[mask] * 0.5 + color_mask * 0.5
    
        plt.imshow(img)
    
        p = PatchCollection(
            polygons, facecolor='none', edgecolors=color, linewidths=thickness)
        ax.add_collection(p)
    
        stream, _ = canvas.print_to_buffer()
        buffer = np.frombuffer(stream, dtype='uint8')
        img_rgba = buffer.reshape(height, width, 4)
        rgb, alpha = np.split(img_rgba, [3], axis=2)
        img = rgb.astype('uint8')
        img = mmcv.rgb2bgr(img)
    
        if show:
            # We do not use cv2 for display because in some cases, opencv will
            # conflict with Qt, it will output a warning: Current thread
            # is not the object's thread. You can refer to
            # https://github.com/opencv/opencv-python/issues/46 for details
            if wait_time == 0:
                plt.show()
            else:
                plt.show(block=False)
                plt.pause(wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)
    
        plt.close()
    
        return img

    Since this function also has to receive the weight values as an argument, I changed its signature as follows.

    def imshow_det_bboxes(img,
                          bboxes,
                          labels,
                          predict_weight_list,
                          segms=None,
                          class_names=None,
                          score_thr=0,
                          bbox_color='green',
                          text_color='green',
                          mask_color=None,
                          thickness=2,
                          font_size=13,
                          win_name='',
                          show=True,
                          wait_time=0,
                          out_file=None):

    I added predict_weight_list as the fourth parameter of imshow_det_bboxes().
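    Because predict_weight_list has no default, it has to sit before segms, and every existing call that passes segms positionally must be updated (show_result is one such caller). An alternative, sketched here with a stub rather than the real mmdetection code, is to append the new parameter at the end with a default of None so that old calls keep working. The stub name and its return value are illustrative only:

```python
def imshow_det_bboxes_stub(img, bboxes, labels, segms=None,
                           class_names=None, predict_weight_list=None):
    """Stub that only reports which optional arguments it received."""
    return {
        "has_segms": segms is not None,
        "has_weights": predict_weight_list is not None,
    }


# An existing call with four positional arguments still binds segms correctly.
old_call = imshow_det_bboxes_stub("img", [], [], ["mask"])
# A new call passes the weights by keyword.
new_call = imshow_det_bboxes_stub("img", [], [], ["mask"],
                                  predict_weight_list=[[1500.0]])
print(old_call, new_call)
```

    Either layout works; the positional version used in this post is fine as long as every caller is updated together.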

    After that, I modified the body as follows.

    # List of instance ids (1, 2, 3, ...).
    id =[]
    for i in range(len(labels)):
        id.append(i+1)
    
    # Body-weight list passed in by the caller.
    weight = predict_weight_list
    
    polygons = []
    color = []
    for i, (bbox, label) in enumerate(zip(bboxes, labels)):  # i is the index, bbox is one row of bboxes, label is the matching entry of labels
        bbox_int = bbox.astype(np.int32)
        poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
                [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
        np_poly = np.array(poly).reshape((4, 2))
        polygons.append(Polygon(np_poly))
        color.append(bbox_color)
        label_text = f'id:{id[i]}' if id is not None else f'class {label}'
        if len(bbox) > 4:
            label_text += f'|{weight[i][0]:.02f}g'

    The text shown in the bbox title is whatever is assigned in the

    label_text =    line, so

    I assigned the id number to label_text, and

    if len(bbox) > 4:
        label_text += f'|{weight[i][0]:.02f}g'

    means that the weight[i][0] value is appended right after the id number.

     

    With these changes,

    the title shows id_num | weight, as in the result above.
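    One thing to watch: when score_thr is above 0, imshow_det_bboxes drops low-score rows from bboxes and labels before the loop, so weight[i] only lines up if predict_weight_list is filtered the same way. A minimal sketch of building the titles with one filter applied to both, assuming one [weight] row per detection as the weight[i][0] indexing suggests; build_titles is a hypothetical helper, not part of mmdetection:

```python
def build_titles(bboxes, weights, score_thr=0.0):
    """Return 'id:N|W.WWg' titles for boxes scoring above score_thr.

    bboxes:  rows of [x1, y1, x2, y2, score]
    weights: rows of [weight_in_grams], one per row of bboxes
    """
    kept = [
        (bbox, weight)
        for bbox, weight in zip(bboxes, weights)
        if bbox[4] > score_thr
    ]
    # Ids are assigned after filtering, starting from 1.
    return [
        f"id:{i + 1}|{weight[0]:.02f}g"
        for i, (_, weight) in enumerate(kept)
    ]


# Hypothetical boxes and weights.
bboxes = [[0, 0, 10, 10, 0.9], [0, 0, 5, 5, 0.3], [2, 2, 8, 8, 0.8]]
weights = [[1500.0], [900.0], [1720.5]]
print(build_titles(bboxes, weights, score_thr=0.7))
```

    If the weights are already computed only for the boxes that pass the threshold, no extra filtering is needed.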