Commit 76d876f8 authored by Richard Berger's avatar Richard Berger
Browse files

Allow detection of MPI_Abort condition in library call

The return value of `lammps_get_last_error_message` now encodes if the last
error was recoverable or should cause an `MPI_Abort`. The driving code is
responsible of reacting to the error and calling `MPI_Abort` on the
communicator it passed to the LAMMPS instance.
parent 2fb666dc
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ class lammps(object):
  # create instance of LAMMPS

  def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
    self.comm = comm

    # determine module location

@@ -152,8 +153,15 @@ class lammps(object):

    if self.lib.lammps_has_error(self.lmp):
      sb = create_string_buffer(100)
      self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
      raise Exception(sb.value.decode().strip())
      error_type = self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
      error_msg = sb.value.decode().strip()

      if error_type == 2 and lammps.has_mpi4py_v2 and self.comm != None and self.comm.Get_size() > 1:
        print(error_msg, file=sys.stderr)
        print("Aborting...", file=sys.stderr)
        sys.stderr.flush()
        self.comm.Abort()
      raise Exception(error_msg)

  def extract_global(self,name,type):
    if name: name = name.encode()
+14 −4
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ using namespace LAMMPS_NS;

/* ---------------------------------------------------------------------- */

Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL) {}
Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL), last_error_type(ERROR_NONE) {}

/* ----------------------------------------------------------------------
   called by all procs in universe
@@ -208,13 +208,22 @@ char * Error::get_last_error() const
  return last_error_message;
}

/* ----------------------------------------------------------------------
   return the type of the last error reported by LAMMPS (only used if
   compiled with -DLAMMPS_EXCEPTIONS)
------------------------------------------------------------------------- */

ErrorType Error::get_last_error_type() const
{
  return last_error_type;
}

/* ----------------------------------------------------------------------
   set the last error message (only used if compiled with
   -DLAMMPS_EXCEPTIONS)
   set the last error message and error type
   (only used if compiled with -DLAMMPS_EXCEPTIONS)
------------------------------------------------------------------------- */

void Error::set_last_error(const char * msg)
void Error::set_last_error(const char * msg, ErrorType type)
{
  delete [] last_error_message;

@@ -224,4 +233,5 @@ void Error::set_last_error(const char * msg)
  } else {
    last_error_message = NULL;
  }
  last_error_type = type;
}
+10 −2
Original line number Diff line number Diff line
@@ -47,8 +47,15 @@ public:
  }
};

enum ErrorType {
   ERROR_NONE   = 0,
   ERROR_NORMAL = 1,
   ERROR_ABORT  = 2
};

class Error : protected Pointers {
  char * last_error_message;
  ErrorType last_error_type;

 public:
  Error(class LAMMPS *);
@@ -64,7 +71,8 @@ class Error : protected Pointers {
  void done(int = 0); // 1 would be fully backwards compatible

  char *    get_last_error() const;
  void   set_last_error(const char * msg);
  ErrorType get_last_error_type() const;
  void   set_last_error(const char * msg, ErrorType type = ERROR_NORMAL);
};

}
+17 −3
Original line number Diff line number Diff line
@@ -113,8 +113,18 @@ char *lammps_command(void *ptr, char *str)

  try {
    return lmp->input->one(str);
  } catch(LAMMPSAbortException & ae) {
    int nprocs = 0;
    MPI_Comm_size(ae.universe, &nprocs );

    if (nprocs > 1) {
      error->set_last_error(ae.message.c_str(), ERROR_ABORT);
    } else {
      error->set_last_error(ae.message.c_str(), ERROR_NORMAL);
    }
    return NULL;
  } catch(LAMMPSException & e) {
    error->set_last_error(e.message.c_str());
    error->set_last_error(e.message.c_str(), ERROR_NORMAL);
    return NULL;
  }
}
@@ -613,6 +623,9 @@ int lammps_has_error(void *ptr) {

/* ----------------------------------------------------------------------
   Copy the last error message of LAMMPS into a character buffer
   The return value encodes which type of error it is.
   1 = normal error (recoverable)
   2 = abort error (non-recoverable)
------------------------------------------------------------------------- */

int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
@@ -620,9 +633,10 @@ int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
  Error * error = lmp->error;

  if(error->get_last_error()) {
    int error_type = error->get_last_error_type();
    strncpy(buffer, error->get_last_error(), buffer_size-1);
    error->set_last_error(NULL);
    return 1;
    error->set_last_error(NULL, ERROR_NONE);
    return error_type;
  }
  return 0;
}