Commit 210cb4c1 authored by Tim Patton
Browse files

Use GPU for loading safetensors, disable export

parent e134b74c
Loading
Loading
Loading
Loading
+3 −2
Original line number | Original line | Diff line number | Diff line
@@ -147,8 +147,9 @@ def torch_load(model_filename, model_info, map_override=None):
    map_override=shared.weight_load_location if not map_override else map_override
    map_override=shared.weight_load_location if not map_override else map_override
    if(checkpoint_types[model_info.exttype] == 'safetensors'):
    if(checkpoint_types[model_info.exttype] == 'safetensors'):
        # safely load weights
        # safely load weights
        # TODO: safetensors supports zero copy fast load to gpu, see issue #684
        # TODO: safetensors supports zero copy fast load to gpu, see issue #684.  
        return load_file(model_filename, device=map_override)
        # GPU only for now, see https://github.com/huggingface/safetensors/issues/95
        return load_file(model_filename, device='cuda')
    else:
    else:
        return torch.load(model_filename, map_location=map_override)
        return torch.load(model_filename, map_location=map_override)


+2 −1
Original line number | Original line | Diff line number | Diff line
@@ -1187,7 +1187,8 @@ def create_ui(wrap_gradio_gpu_call):
                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
                interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
                interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
                save_as_half = gr.Checkbox(value=False, label="Save as float16")
                save_as_half = gr.Checkbox(value=False, label="Save as float16")
                save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format")
                # invisible until feature can be verified
                save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format", visible=False)
                modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
                modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')


            with gr.Column(variant='panel'):
            with gr.Column(variant='panel'):