Unverified Commit 6657a463 authored by GaoYunshu1's avatar GaoYunshu1 Committed by GitHub
Browse files

Add files via upload

parent ffb97889
Loading
Loading
Loading
Loading
+15 −6
Original line number Diff line number Diff line
@@ -13,19 +13,28 @@ def process(file_path, model_name):
    """
    # TODO: fill this
    image = Image.open(file_path)
    image = image.resize((512, 512))
    file_name = file_path.split('/')[-1].split('.')[0]
    image.save('data/dataset/target/image1.png')
    preprocess('data/dataset/target/image1.png', 'data/dataset/target_mask/image1.png')
    if model_name == 'RCDG_model':

        predict('RCDG_model')
        img = Image.open('results/RCDG_drive/test_latest/images/image1_fake_TB.png')
    img = img = img.resize((512, 512))
        img = img.resize((512, 512))
        img.save(str(f'data/processed/{file_name}.png'))
    # os.remove('data/dataset/target/image1.png')
    # os.remove('data/dataset/target_mask/image1.png')

    elif model_name == 'ArcNet':
        predict('ArcNet')
        img = Image.open('results/arcnet/test_latest/images/image1_fake_TB.png')
        # img = img.resize((512, 512))
        img.save(str(f'data/processed/{file_name}.png'))
    else:
        return 'No such model'
    return 'Success'


if __name__ == '__main__':
    process('data/unprocessed/new_test.jpg', "RCDG_model")
    process('data/unprocessed/new_test.jpg', "AecNet")
    # image = Image.open('../data/unprocessed/test.jpg')
+7 −1
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import torch
# import model run in server
# from predict_models.res_dg.models.RCDG_model import RCDGModel
from predict_models.res_dg.test import test
from predict_models.res.test import test_ArcNet

import argparse

@@ -17,6 +18,11 @@ torch.cuda.empty_cache()


def predict(model):
    # TODO: 模型预测,结果储存到特定文件夹中
    if model == 'RCDG_model':
        test()
    elif model == 'ArcNet':
        test_ArcNet()