Commit 265d2d71 authored by Bucky Lee's avatar Bucky Lee
Browse files

feat: update path

parent fbab58df
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ import os
from core.predict import predict
from core.process import preprocess
from PIL import Image
import shutil


def process(file_path, model_name):
@@ -14,13 +13,13 @@ def process(file_path, model_name):
    """
    # TODO: fill this
    image = Image.open(file_path)
    image.save('data/dataset/target/image1.png')
    file_name = file_path.split('/')[-1].split('.')[0]
    image.save('data/dataset/target/image1.png')
    preprocess('data/dataset/target/image1.png', 'data/dataset/target_mask/image1.png')
    predict('RCDG_model')
    img = Image.open('core/results/RCDG_drive/test_latest/images/image1_fake_TB.png')
    img.resize((512, 512))
    img.save(str(f'./data/processed/{file_name}.png'))
    img = img = img.resize((512, 512))
    img.save(str(f'data/processed/{file_name}.png'))
    os.remove('data/dataset/target/image1.png')
    os.remove('data/dataset/target_mask/image1.png')

@@ -28,5 +27,5 @@ def process(file_path, model_name):


if __name__ == '__main__':
    process('../data/unprocessed/test.jpg', "RCDG_model")
    process('data/unprocessed/new_test.jpg', "RCDG_model")
    # image = Image.open('../data/unprocessed/test.jpg')
+21 −12
Original line number Diff line number Diff line
import os
import sys
import cv2

import torch

# import model run in server
import numpy as np
# from predict_models.res.predict_models import arcnet_model
from predict_models.res.options import test_options
# from predict_models.res_dg.models.RCDG_model import RCDGModel
from predict_models.res_dg.test import test

import argparse

# from predict_models.res.test import test
# from models.res.test import test

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.set_num_threads(4)
@@ -19,16 +20,24 @@ def predict(model):
    # TODO: 模型预测,结果储存到特定文件夹中
    # test(file_path, '00_low_quality_dir')
    # opt = test_options.TestOptions()
    os.system(
        'python predict_models/res_dg/test.py --dataroot ./data/dataset/'
        ' --name RCDG_drive --model RCDG --dataset_mode cataract_guide_padding --eval')
    # os.system(
    #     'python ../predict_models/res/test.py --dataroot ../data/unprocessed/'
    #     'python ../models/res_dg/test.py --dataroot ../data/dataset/'
    #     ' --name RCDG_drive --model RCDG --dataset_mode cataract_guide_padding --eval')
    # os.system(
    #     'python ../models/res/test.py --dataroot ../data/unprocessed/'
    #     ' --name arcnet --model arcnet --dataset_mode cataract_guide_padding --eval')

    # model = arcnet_model.ArcNetModel()
    # dic = dict()
    # dic['dataroot'] = '../data/dataset'
    # dic['name'] = 'RCDG_drive'
    # dic['model'] = 'RCDG'
    # dic['dataset_mode'] = 'cataract_guide_padding'
    # dic['eval'] = True
    # parser = argparse.ArgumentParser()
    test()

    # model = RCDGModel()
    # model.set_input()
    # model.test()

# predict('arcnet')
+235 KiB
Loading image diff...
+2.39 KiB
Loading image diff...
+50.6 KiB
Loading image diff...
Loading