Commit 47580489 authored by Sparkf's avatar Sparkf 🏙️
Browse files

refine accuracy

parent 9cb585e4
Loading
Loading
Loading
Loading
+40 −20
Original line number Diff line number Diff line
@@ -6,13 +6,14 @@ import matplotlib.pyplot as plt
import circle_fit
import os
import json
from skimage.filters import threshold_otsu

input_img = cv2.imread('1492627187263118217/1.jpg')


def transform(input_img, num):
    rows, cols, ch = input_img.shape
    print(input_img.shape)
    # rows, cols, ch = input_img.shape
    # print(input_img.shape)

    # pts1 = np.float32([[629,250],[0,445],[705,250],[1279,430]])
    pts1 = np.float32([[450, 300], [-851, 720], [870, 300], [2057, 720]])
@@ -21,12 +22,16 @@ def transform(input_img, num):
    M = cv2.getPerspectiveTransform(pts1, pts2)

    dst = cv2.warpPerspective(input_img, M, (600, 600))
    dst_extract = abs_thresh(dst, sobel_kernel=3, mag_thresh=(45, 210), direction='x')
    dst_extract = abs_thresh(dst, sobel_kernel=3, mag_thresh=(35, 210), direction='x')
    # dst_extract = abs_otsu(dst)

    # dst_extract = np.where(dst_extract>0.5,1,0)
    # histogram = np.sum(dst_extract[dst_extract.shape[0] // 2:, :], axis=0)

    # plt.imshow(dst), plt.title(foldername + filename), plt.savefig("ware_dataset/op/" + foldername + "_" + filename + "_transform.png", dpi=300)

    ##add line mark
    dst_mark = add_line_mark(dst_extract)
    dst_mark = add_line_mark(dst_extract,dst)

    # plt.subplot(131), plt.imshow(input_img), plt.title('Input')
    # plt.subplot(132), plt.imshow(dst), plt.title('Output')
@@ -34,6 +39,7 @@ def transform(input_img, num):
    # plt.show()



def abs_thresh(img, sobel_kernel=9, mag_thresh=(0, 255), return_grad=False, direction='x'):
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
@@ -58,8 +64,13 @@ def abs_thresh(img, sobel_kernel=9, mag_thresh=(0, 255), return_grad=False, dire

    return grad_binary

def abs_otsu(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    thresh = threshold_otsu(gray)
    binary = gray > thresh
    return binary

def add_line_mark(input_img):
def add_line_mark(input_img,dst):
    # determined start point

    # delete triangle
@@ -71,12 +82,10 @@ def add_line_mark(input_img):
    # fill blank
    input_img_hist_shape = input_img[400:599, 0:599]

    # plt.plot(histogram)
    # plt.show()

    # use window to find points (windows size: H30 W50)

    def gen_windows(bottom_H, midpoint_W, H, W):

        find_indicator = 1
        window = input_img[bottom_H - 30:bottom_H, midpoint_W - 50:midpoint_W + 50]
        H = 30
@@ -84,12 +93,17 @@ def add_line_mark(input_img):
        # window_arr = input_img[]
        histogram_window = np.sum(window[window.shape[0] // 2:, :], axis=0)
        window_peak = find_peaks(histogram_window, distance=100)[0]

        return_mid = 0

        if window_peak.size > 0:
            if histogram_window(window_peak) > 10 and (window_peak - (W / 2) < 100 ):
            if histogram_window[window_peak] > 5 and np.abs(window_peak - (W / 2)) < 50:
                return_mid = midpoint_W - (W / 2) + window_peak
            else:
                return_mid = midpoint_W
        else:
            return_mid = midpoint_W

            find_indicator = 0
        # plt.plot(histogram_window)
        # plt.show()
@@ -100,6 +114,12 @@ def add_line_mark(input_img):
    # gen_windows(299,peak[0],30,50)

    plt.imshow(input_img)
    #plt.imshow(dst)


    input_img_hist_all = np.sum(input_img[input_img.shape[0] // 2:, :], axis=0)
    peak_hist_all = find_peaks(input_img_hist_all, distance=150)
    print(peak_hist_all)

    # iteraton(L)
    # print(peak_L)
@@ -145,14 +165,12 @@ def add_line_mark(input_img):

    try:
        midpoint_arr_L = find_point(0)

        print(midpoint_arr_L.T)
        plt.scatter(x=midpoint_arr_L.T[0], y=midpoint_arr_L.T[1])
        ycl, xcl, rcl, _ = circle_fit.hyper_fit(midpoint_arr_L)
        if rcl<300:
            accuracy=0
        plt.gca().add_artist(plt.Circle((ycl, xcl), rcl, color='r', fill=False))
    except Exception:
        accuracy = 0
        pass

    try:
@@ -177,32 +195,34 @@ def add_line_mark(input_img):
    # plt.title(str(i))
    # plt.savefig("op/op_reg" + str(i) + ".png", dpi=300)
    plt.title(foldername + filename)
    plt.savefig("op/" + foldername + "_" + filename + ".png", dpi=300)
    # plt.savefig("op/" + foldername + "_" + filename + ".png", dpi=300)
    plt.savefig("ware_dataset/op/" + foldername + "_" + filename + ".png", dpi=300)
    plt.close()

    return 0


rootdir = 'D:/train_set/clips/0531'
# rootdir = 'D:/train_set/clips/0531'
rootdir = 'D:/train_set/clips/examine'
foldername = "null"
accuracy_arr = []

for subdir, dirs, files in os.walk(rootdir):

    for file in files:
        impath = os.path.join(subdir, file)
        input_img = cv2.imread(impath)
        input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
        input_img = input_img[:,:,0]
        foldername = subdir[-19:]
        filename = file.split(".")[0]
        i = 0
        accuracy = 1
        transform(input_img, i)
        # print(accuracy)
        accuracy_arr.append(accuracy)


jsonString = json.dumps(accuracy)
with open('accuracy.json', 'w') as outfile:
    json.dump(jsonString, outfile)
# jsonString = json.dumps(accuracy)
# with open('accuracy.json', 'w') as outfile:
#     json.dump(jsonString, outfile)

# for i in range(1, 20):
#     input_img = cv2.imread('1492626805094402903/' + str(i) + '.jpg')