Commit 9cb585e4 authored by Sparkf's avatar Sparkf 🏙️
Browse files

use circle curve to fit the lane

parent 6c50dff6
Loading
Loading
Loading
Loading
+65 −58
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ from scipy.signal import find_peaks
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import circle_fit
import os
import json

input_img = cv2.imread('1492627187263118217/1.jpg')

@@ -19,18 +21,17 @@ def transform(input_img, num):
    M = cv2.getPerspectiveTransform(pts1, pts2)

    dst = cv2.warpPerspective(input_img, M, (600, 600))
    dst_extract = abs_thresh(dst, sobel_kernel=3, mag_thresh=(35, 210), direction='x')
    dst_extract = abs_thresh(dst, sobel_kernel=3, mag_thresh=(45, 210), direction='x')
    # dst_extract = np.where(dst_extract>0.5,1,0)
    # histogram = np.sum(dst_extract[dst_extract.shape[0] // 2:, :], axis=0)

    ##add line mark
    dst_mark = add_line_mark(dst_extract)

    plt.subplot(131), plt.imshow(input_img), plt.title('Input')
    plt.subplot(132), plt.imshow(dst), plt.title('Output')
    plt.subplot(133), plt.imshow(dst_extract), plt.title('op_extracted')
    # plt.subplot(131), plt.imshow(input_img), plt.title('Input')
    # plt.subplot(132), plt.imshow(dst), plt.title('Output')
    # plt.subplot(133), plt.imshow(dst_extract), plt.title('op_extracted')
    # plt.show()
    plt.savefig("op/op_reg" + str(num) + ".png", dpi=300)


def abs_thresh(img, sobel_kernel=9, mag_thresh=(0, 255), return_grad=False, direction='x'):
@@ -85,7 +86,7 @@ def add_line_mark(input_img):
        window_peak = find_peaks(histogram_window, distance=100)[0]
        return_mid = 0
        if window_peak.size > 0:
            if window_peak > 10:
            if histogram_window(window_peak) > 10 and (window_peak - (W / 2) < 100 ):
                return_mid = midpoint_W - (W / 2) + window_peak
        else:
            return_mid = midpoint_W
@@ -105,11 +106,12 @@ def add_line_mark(input_img):

    def find_hist_iteration(baseline=500, start_W=0):

        input_img_hist_iter = input_img[baseline:baseline + 99, start_W:start_W + 300]
        input_img_hist_iter = input_img[baseline:baseline + 149, start_W:start_W + 300]
        find_hist_iteration_iter = np.sum(input_img_hist_iter[input_img_hist_iter.shape[0] // 2:, :], axis=0)
        peak_hist = find_peaks(find_hist_iteration_iter, distance=180)[0]

        if find_hist_iteration_iter[peak_hist] < 10 or peak_hist.size == 0:
        peak_hist = find_peaks(find_hist_iteration_iter, distance=150)[0] or [0]

        if (find_hist_iteration_iter[peak_hist[0]] < 10) or peak_hist.size == 0:
            # print(find_peaks(find_hist_iteration_iter, distance=180))
            # print(baseline)
            find_hist_iteration(baseline - 50, start_W)
@@ -128,7 +130,7 @@ def add_line_mark(input_img):
    def find_point(start_W):
        midpoint_arr_R = a = np.zeros((0, 2)).astype(int)
        # print(midpoint_arr_R)
        midpoint_tmp_R, bottom_R = find_hist_iteration(500, start_W)
        midpoint_tmp_R, bottom_R = find_hist_iteration(450, start_W)
        # print(midpoint_tmp_R)
        while bottom_R > 0:
            window_H, window_WM, find_indicator = gen_windows(bottom_R, midpoint_tmp_R, 30, 100)
@@ -141,32 +143,29 @@ def add_line_mark(input_img):
            bottom_R = bottom_R - 30
        return midpoint_arr_R

    midpoint_arr_R = find_point(300)
    print(midpoint_arr_R.T)
    plt.scatter(x=midpoint_arr_R.T[0], y=midpoint_arr_R.T[1])
    ycl, xcl, rcl, _ = circle_fit.hyper_fit(midpoint_arr_R)
    plt.gca().add_artist(plt.Circle((ycl, xcl), rcl, color='b', fill=False))

    try:
        midpoint_arr_L = find_point(0)
        print(midpoint_arr_L.T)
        plt.scatter(x=midpoint_arr_L.T[0], y=midpoint_arr_L.T[1])
        ycl, xcl, rcl, _ = circle_fit.hyper_fit(midpoint_arr_L)
        if rcl<300:
            accuracy=0
        plt.gca().add_artist(plt.Circle((ycl, xcl), rcl, color='r', fill=False))
    except Exception:
        accuracy = 0
        pass

    # midpoint_tmp_L, bottom_L = find_hist_iteration(500, 0)
    # midpoint_arr_L = np.zeros((0,2)).astype(int)
    # # print(midpoint_tmp_L)
    # while bottom_L > 0:
    #     window_H, window_WM, find_indicator = gen_windows(bottom_L, midpoint_tmp_L, 30, 100)
    #     midpoint_tmp_L = int(window_WM)
    #
    #     if find_indicator == 1:
    #         tmp_list = np.array([[midpoint_tmp_L, bottom_L]]).astype(int)
    #
    #         midpoint_arr_L = np.vstack([midpoint_arr_L,tmp_list])
    #         # print(midpoint_tmp_L)
    #         plt.scatter(midpoint_tmp_L, bottom_L)
    #     bottom_L = bottom_L - 30
    try:
        midpoint_arr_R = find_point(300)
        print(midpoint_arr_R.T)
        plt.scatter(x=midpoint_arr_R.T[0], y=midpoint_arr_R.T[1])
        ycl, xcl, rcl, _ = circle_fit.hyper_fit(midpoint_arr_R)
        if rcl<300:
            accuracy=0
        plt.gca().add_artist(plt.Circle((ycl, xcl), rcl, color='b', fill=False))
    except Exception:
        accuracy = 0
        pass

    # np.polyfit

@@ -175,29 +174,37 @@ def add_line_mark(input_img):
    plt.ylim(top=599)  # ymax is your value
    plt.ylim(bottom=0)  # ymin is your value
    plt.gca().invert_yaxis()
    plt.show()

    # ##RIGHT
    # histogram = np.sum(input_img_hist_shape[input_img_hist_shape.shape[0] // 2:, :], axis=0)
    # peak_R = find_peaks(histogram[300:599], distance=150)[0][0]+299
    # if peak_R.size > 0 and histogram[peak_R] > 50:
    #     pass
    # else:
    #     peak_R = 400
    # # iteraton(L)
    # print(peak_R)
    # midpoint_tmp_R = peak_R
    # for bottom in range(599, 29, -30):
    #     window_H, window_WM = gen_windows(bottom, midpoint_tmp_R, 30, 50)
    #     midpoint_tmp_R = int(window_WM)
    #     plt.scatter(midpoint_tmp_R, bottom)
    #
    # plt.show()
    # plt.title(str(i))
    # plt.savefig("op/op_reg" + str(i) + ".png", dpi=300)
    plt.title(foldername + filename)
    plt.savefig("op/" + foldername + "_" + filename + ".png", dpi=300)
    plt.close()

    return 0


for i in range(10, 11):
    input_img = cv2.imread('1492626270684175793/' + str(i) + '.jpg')
rootdir = 'D:/train_set/clips/0531'
foldername = "null"
accuracy_arr = []
for subdir, dirs, files in os.walk(rootdir):
    for file in files:
        impath = os.path.join(subdir, file)
        input_img = cv2.imread(impath)
        input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
        foldername = subdir[-19:]
        filename = file.split(".")[0]
        i = 0
        accuracy = 1
        transform(input_img, i)
        # print(accuracy)
        accuracy_arr.append(accuracy)


jsonString = json.dumps(accuracy)
with open('accuracy.json', 'w') as outfile:
    json.dump(jsonString, outfile)

# for i in range(1, 20):
#     input_img = cv2.imread('1492626805094402903/' + str(i) + '.jpg')
#     input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
#     transform(input_img, i)