Commit e4303443 authored by Aarni Koskela's avatar Aarni Koskela
Browse files

API: use finally: for state.end()

parent f44feb6a
Loading
Loading
Loading
Loading
+14 −15
Original line number Diff line number Diff line
@@ -602,21 +602,22 @@ class Api:
            shared.state.begin(job="create_embedding")
            filename = create_embedding(**args) # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
            shared.state.end()
            return models.CreateResponse(info=f"create embedding filename: {filename}")
        except AssertionError as e:
            shared.state.end()
            return models.TrainResponse(info=f"create embedding error: {e}")
        finally:
            shared.state.end()


    def create_hypernetwork(self, args: dict):
        try:
            shared.state.begin(job="create_hypernetwork")
            filename = create_hypernetwork(**args) # create empty embedding
            shared.state.end()
            return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
        except AssertionError as e:
            shared.state.end()
            return models.TrainResponse(info=f"create hypernetwork error: {e}")
        finally:
            shared.state.end()

    def preprocess(self, args: dict):
        try:
@@ -625,14 +626,11 @@ class Api:
            shared.state.end()
            return models.PreprocessResponse(info='preprocess complete')
        except KeyError as e:
            shared.state.end()
            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
        except AssertionError as e:
            shared.state.end()
        except Exception as e:
            return models.PreprocessResponse(info=f"preprocess error: {e}")
        except FileNotFoundError as e:
        finally:
            shared.state.end()
            return models.PreprocessResponse(info=f'preprocess error: {e}')

    def train_embedding(self, args: dict):
        try:
@@ -649,11 +647,11 @@ class Api:
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
        except AssertionError as msg:
            shared.state.end()
        except Exception as msg:
            return models.TrainResponse(info=f"train embedding error: {msg}")
        finally:
            shared.state.end()

    def train_hypernetwork(self, args: dict):
        try:
@@ -675,9 +673,10 @@ class Api:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
        except AssertionError:
        except Exception as exc:
            return models.TrainResponse(info=f"train embedding error: {exc}")
        finally:
            shared.state.end()
            return models.TrainResponse(info=f"train embedding error: {error}")

    def get_memory(self):
        try: