Unverified Commit 2f629db3 authored by Richard Berger's avatar Richard Berger
Browse files

Refactor Zstd dump styles

parent ced78a72
Loading
Loading
Loading
Loading
+32 −99
Original line number Diff line number Diff line
@@ -11,29 +11,24 @@
   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing author: Richard Berger (Temple U)
------------------------------------------------------------------------- */

#include "dump_atom_zstd.h"
#include "domain.h"
#include "error.h"
#include "update.h"
#include "force.h"

#include <fmt/format.h>
#include <cstring>
#include <fmt/format.h>

using namespace LAMMPS_NS;

DumpAtomZstd::DumpAtomZstd(LAMMPS *lmp, int narg, char **arg) :
  DumpAtom(lmp, narg, arg)
{
  cctx = nullptr;
  zstdFp = nullptr;
  fp = nullptr;
  out_buffer_size = ZSTD_CStreamOutSize();
  out_buffer = new char[out_buffer_size];

  checksum_flag = 1;
  compression_level = 0; // = default

  if (!compressed)
    error->all(FLERR,"Dump atom/zstd only writes compressed files");
}
@@ -42,11 +37,6 @@ DumpAtomZstd::DumpAtomZstd(LAMMPS *lmp, int narg, char **arg) :

DumpAtomZstd::~DumpAtomZstd()
{
  if(cctx && zstdFp) zstd_close();

  delete [] out_buffer;
  out_buffer = nullptr;
  out_buffer_size = 0;
}

/* ----------------------------------------------------------------------
@@ -101,19 +91,15 @@ void DumpAtomZstd::openfile()

  if (filewriter) {
    if (append_flag) {
      zstdFp = fopen(filecurrent,"ab");
    } else {
      zstdFp = fopen(filecurrent,"wb");
      error->one(FLERR, "dump/zstd currently doesn't support append");
    }

    if (zstdFp == nullptr) error->one(FLERR,"Cannot open dump file");

    cctx = ZSTD_createCCtx();
    ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, compression_level);
    ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, checksum_flag);

    if (cctx == nullptr) error->one(FLERR,"Cannot create Zstd context");
  } else zstdFp = nullptr;
    try {
      writer.open(filecurrent);
    } catch (FileWriterException & e) {
      error->one(FLERR, e.what());
    }
  }

  // delete string with timestep replaced

@@ -151,7 +137,7 @@ void DumpAtomZstd::write_header(bigint ndump)
    }
    header += fmt::format("ITEM: ATOMS {}\n", columns);

    zstd_write(header.c_str(), header.length());
    writer.write(header.c_str(), header.length());
  }
}

@@ -159,14 +145,7 @@ void DumpAtomZstd::write_header(bigint ndump)

void DumpAtomZstd::write_data(int n, double *mybuf)
{
  ZSTD_inBuffer input = { mybuf, (size_t)n, 0 };
  ZSTD_EndDirective mode = ZSTD_e_continue;

  do {
      ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
      size_t const remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
      fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(input.pos < input.size);
  writer.write(mybuf, n);
}

/* ---------------------------------------------------------------------- */
@@ -176,11 +155,10 @@ void DumpAtomZstd::write()
  DumpAtom::write();
  if (filewriter) {
    if (multifile) {
      zstd_close();
      writer.close();
    } else {
      if (flush_flag && zstdFp) {
        zstd_flush();
        fflush(zstdFp);
      if (flush_flag && writer.isopen()) {
        writer.flush();
      }
    }
  }
@@ -192,67 +170,22 @@ int DumpAtomZstd::modify_param(int narg, char **arg)
{
  int consumed = DumpAtom::modify_param(narg, arg);
  if(consumed == 0) {
    try {
      if (strcmp(arg[0],"checksum") == 0) {
        if (narg < 2) error->all(FLERR,"Illegal dump_modify command");
      if (strcmp(arg[1],"yes") == 0) checksum_flag = 1;
      else if (strcmp(arg[1],"no") == 0) checksum_flag = 0;
        if (strcmp(arg[1],"yes") == 0) writer.setChecksum(true);
        else if (strcmp(arg[1],"no") == 0) writer.setChecksum(false);
        else error->all(FLERR,"Illegal dump_modify command");
        return 2;
      } else if (strcmp(arg[0],"compression_level") == 0) {
        if (narg < 2) error->all(FLERR,"Illegal dump_modify command");
      compression_level = force->inumeric(FLERR,arg[1]);
      int min_level = ZSTD_minCLevel();
      int max_level = ZSTD_maxCLevel();
      if (compression_level < min_level || compression_level > max_level)
        error->all(FLERR, fmt::format("Illegal dump_modify command: compression level must in the range of [{}, {}]", min_level, max_level));
        int compression_level = force->inumeric(FLERR,arg[1]);
        writer.setCompressionLevel(compression_level);
        return 2;
      }
    } catch (FileWriterException & e) {
      error->one(FLERR, e.what());
    }
  return consumed;
  }

/* ---------------------------------------------------------------------- */

void DumpAtomZstd::zstd_write(const void * buffer, size_t length)
{
  ZSTD_inBuffer input = { buffer, length, 0 };
  ZSTD_EndDirective mode = ZSTD_e_continue;
  
  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    size_t const remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(input.pos < input.size);
}

void DumpAtomZstd::zstd_flush() {
  size_t remaining;
  ZSTD_inBuffer input = { nullptr, 0, 0 };
  ZSTD_EndDirective mode = ZSTD_e_flush;

  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(remaining);
}

/* ---------------------------------------------------------------------- */

void DumpAtomZstd::zstd_close()
{
  size_t remaining;
  ZSTD_inBuffer input = { nullptr, 0, 0 };
  ZSTD_EndDirective mode = ZSTD_e_end;

  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(remaining);

  ZSTD_freeCCtx(cctx);
  cctx = nullptr;
  if (zstdFp) fclose(zstdFp);
  zstdFp = nullptr;
  return consumed;
}
+6 −13
Original line number Diff line number Diff line
@@ -11,6 +11,10 @@
   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing author: Richard Berger (Temple U)
------------------------------------------------------------------------- */

#ifdef DUMP_CLASS

DumpStyle(atom/zstd,DumpAtomZstd)
@@ -21,8 +25,7 @@ DumpStyle(atom/zstd,DumpAtomZstd)
#define LMP_DUMP_ATOM_ZSTD_H

#include "dump_atom.h"
#include <zstd.h>
#include <stdio.h>
#include "zstd_file_writer.h"

namespace LAMMPS_NS {

@@ -32,13 +35,7 @@ class DumpAtomZstd : public DumpAtom {
  virtual ~DumpAtomZstd();

 protected:
  int compression_level;
  int checksum_flag;

  ZSTD_CCtx * cctx;
  FILE * zstdFp;
  char * out_buffer;
  size_t out_buffer_size;
  ZstdFileWriter writer;

  virtual void openfile();
  virtual void write_header(bigint);
@@ -46,10 +43,6 @@ class DumpAtomZstd : public DumpAtom {
  virtual void write();

  virtual int modify_param(int, char **);

  void zstd_write(const void * buffer, size_t length);
  void zstd_flush();
  void zstd_close();
};

}
+33 −100
Original line number Diff line number Diff line
@@ -11,6 +11,10 @@
   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing author: Richard Berger (Temple U)
------------------------------------------------------------------------- */

#include "dump_custom_zstd.h"
#include "domain.h"
#include "error.h"
@@ -25,32 +29,16 @@ using namespace LAMMPS_NS;
DumpCustomZstd::DumpCustomZstd(LAMMPS *lmp, int narg, char **arg) :
  DumpCustom(lmp, narg, arg)
{
  cctx = nullptr;
  zstdFp = nullptr;
  fp = nullptr;
  out_buffer_size = ZSTD_CStreamOutSize();
  out_buffer = new char[out_buffer_size];

  checksum_flag = 1;
  compression_level = 0; // = default

  if (!compressed)
    error->all(FLERR,"Dump custom/zstd only writes compressed files");
}


/* ---------------------------------------------------------------------- */

DumpCustomZstd::~DumpCustomZstd()
{
  if(cctx && zstdFp) zstd_close();

  delete [] out_buffer;
  out_buffer = nullptr;
  out_buffer_size = 0;
}


/* ----------------------------------------------------------------------
   generic opening of a dump file
   ASCII or binary or gzipped
@@ -103,25 +91,23 @@ void DumpCustomZstd::openfile()

  if (filewriter) {
    if (append_flag) {
      zstdFp = fopen(filecurrent,"ab");
    } else {
      zstdFp = fopen(filecurrent,"wb");
      error->one(FLERR, "dump/zstd currently doesn't support append");
    }

    if (zstdFp == nullptr) error->one(FLERR,"Cannot open dump file");

    cctx = ZSTD_createCCtx();
    ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, compression_level);
    ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, checksum_flag);

    if (cctx == nullptr) error->one(FLERR,"Cannot create Zstd context");
  } else zstdFp = nullptr;
    try {
      writer.open(filecurrent);
    } catch (FileWriterException & e) {
      error->one(FLERR, e.what());
    }
  }

  // delete string with timestep replaced

  if (multifile) delete [] filecurrent;
}

/* ---------------------------------------------------------------------- */

void DumpCustomZstd::write_header(bigint ndump)
{
  std::string header;
@@ -151,7 +137,7 @@ void DumpCustomZstd::write_header(bigint ndump)
    }
    header += fmt::format("ITEM: ATOMS {}\n", columns);

    zstd_write(header.c_str(), header.length());
    writer.write(header.c_str(), header.length());
  }
}

@@ -159,14 +145,7 @@ void DumpCustomZstd::write_header(bigint ndump)

void DumpCustomZstd::write_data(int n, double *mybuf)
{
  ZSTD_inBuffer input = { mybuf, (size_t)n, 0 };
  ZSTD_EndDirective mode = ZSTD_e_continue;

  do {
      ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
      size_t const remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
      fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(input.pos < input.size);
  writer.write(mybuf, n);
}

/* ---------------------------------------------------------------------- */
@@ -176,11 +155,10 @@ void DumpCustomZstd::write()
  DumpCustom::write();
  if (filewriter) {
    if (multifile) {
      zstd_close();
      writer.close();
    } else {
      if (flush_flag && zstdFp) {
        zstd_flush();
        fflush(zstdFp);
      if (flush_flag && writer.isopen()) {
        writer.flush();
      }
    }
  }
@@ -192,67 +170,22 @@ int DumpCustomZstd::modify_param(int narg, char **arg)
{
  int consumed = DumpCustom::modify_param(narg, arg);
  if(consumed == 0) {
    try {
      if (strcmp(arg[0],"checksum") == 0) {
        if (narg < 2) error->all(FLERR,"Illegal dump_modify command");
      if (strcmp(arg[1],"yes") == 0) checksum_flag = 1;
      else if (strcmp(arg[1],"no") == 0) checksum_flag = 0;
        if (strcmp(arg[1],"yes") == 0) writer.setChecksum(true);
        else if (strcmp(arg[1],"no") == 0) writer.setChecksum(false);
        else error->all(FLERR,"Illegal dump_modify command");
        return 2;
      } else if (strcmp(arg[0],"compression_level") == 0) {
        if (narg < 2) error->all(FLERR,"Illegal dump_modify command");
      compression_level = force->inumeric(FLERR,arg[1]);
      int min_level = ZSTD_minCLevel();
      int max_level = ZSTD_maxCLevel();
      if (compression_level < min_level || compression_level > max_level)
        error->all(FLERR, fmt::format("Illegal dump_modify command: compression level must in the range of [{}, {}]", min_level, max_level));
        int compression_level = force->inumeric(FLERR,arg[1]);
        writer.setCompressionLevel(compression_level);
        return 2;
      }
    } catch (FileWriterException & e) {
      error->one(FLERR, e.what());
    }
  return consumed;
}

/* ---------------------------------------------------------------------- */

void DumpCustomZstd::zstd_write(const void * buffer, size_t length)
{
  ZSTD_inBuffer input = { buffer, length, 0 };
  ZSTD_EndDirective mode = ZSTD_e_continue;
  
  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    size_t const remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(input.pos < input.size);
}

void DumpCustomZstd::zstd_flush() {
  size_t remaining;
  ZSTD_inBuffer input = { nullptr, 0, 0 };
  ZSTD_EndDirective mode = ZSTD_e_flush;

  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(remaining);
  }

/* ---------------------------------------------------------------------- */

void DumpCustomZstd::zstd_close()
{
  size_t remaining;
  ZSTD_inBuffer input = { nullptr, 0, 0 };
  ZSTD_EndDirective mode = ZSTD_e_end;

  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, zstdFp);
  } while(remaining);

  ZSTD_freeCCtx(cctx);
  cctx = nullptr;
  if (zstdFp) fclose(zstdFp);
  zstdFp = nullptr;
  return consumed;
}
+6 −12
Original line number Diff line number Diff line
@@ -11,6 +11,10 @@
   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing author: Richard Berger (Temple U)
------------------------------------------------------------------------- */

#ifdef DUMP_CLASS

DumpStyle(custom/zstd,DumpCustomZstd)
@@ -21,7 +25,7 @@ DumpStyle(custom/zstd,DumpCustomZstd)
#define LMP_DUMP_CUSTOM_ZSTD_H

#include "dump_custom.h"
#include <zstd.h>
#include "zstd_file_writer.h"
#include <stdio.h>

namespace LAMMPS_NS {
@@ -32,13 +36,7 @@ class DumpCustomZstd : public DumpCustom {
  virtual ~DumpCustomZstd();

 protected:
  int compression_level;
  int checksum_flag;

  ZSTD_CCtx * cctx;
  FILE * zstdFp;
  char * out_buffer;
  size_t out_buffer_size;
  ZstdFileWriter writer;

  virtual void openfile();
  virtual void write_header(bigint);
@@ -46,10 +44,6 @@ class DumpCustomZstd : public DumpCustom {
  virtual void write();

  virtual int modify_param(int, char **);

  void zstd_write(const void * buffer, size_t length);
  void zstd_flush();
  void zstd_close();
};

}
+158 −0
Original line number Diff line number Diff line
/* -*- c++ -*- ----------------------------------------------------------
   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
   http://lammps.sandia.gov, Sandia National Laboratories
   Steve Plimpton, sjplimp@sandia.gov

   Copyright (2003) Sandia Corporation.  Under the terms of Contract
   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
   certain rights in this software.  This software is distributed under
   the GNU General Public License.

   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing author: Richard Berger (Temple U)
------------------------------------------------------------------------- */

#include "zstd_file_writer.h"
#include <stdio.h>
#include <fmt/format.h>

using namespace LAMMPS_NS;

ZstdFileWriter::ZstdFileWriter() : FileWriter(),
    fp(nullptr),
    cctx(nullptr),
    compression_level(0),
    checksum_flag(1)
{
  out_buffer_size = ZSTD_CStreamOutSize();
  out_buffer = new char[out_buffer_size];
}

/* ---------------------------------------------------------------------- */

ZstdFileWriter::~ZstdFileWriter()
{
  close();

  delete [] out_buffer;
  out_buffer = nullptr;
  out_buffer_size = 0;
}

/* ---------------------------------------------------------------------- */

void ZstdFileWriter::open(const std::string & path)
{
    if(isopen()) return;

    fp = fopen(path.c_str(), "wb");

    if (!fp) {
        throw FileWriterException(fmt::format("Could not open file '{}'", path));
    }

    cctx = ZSTD_createCCtx();

    if (!cctx) {
        fclose(fp);
        fp = nullptr;
        throw FileWriterException("Could not create Zstd context");
    }

    ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, compression_level);
    ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, checksum_flag);
}

/* ---------------------------------------------------------------------- */

size_t ZstdFileWriter::write(const void * buffer, size_t length)
{
  if(!isopen()) return 0;

  ZSTD_inBuffer input = { buffer, length, 0 };
  ZSTD_EndDirective mode = ZSTD_e_continue;

  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    size_t const remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, fp);
  } while(input.pos < input.size);

  return length;
}

/* ---------------------------------------------------------------------- */

void ZstdFileWriter::flush()
{
  if(!isopen()) return;

  size_t remaining;
  ZSTD_inBuffer input = { nullptr, 0, 0 };
  ZSTD_EndDirective mode = ZSTD_e_flush;

  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, fp);
  } while(remaining);

  fflush(fp);
}

/* ---------------------------------------------------------------------- */

void ZstdFileWriter::close()
{
  if(!isopen()) return;

  size_t remaining;
  ZSTD_inBuffer input = { nullptr, 0, 0 };
  ZSTD_EndDirective mode = ZSTD_e_end;

  do {
    ZSTD_outBuffer output = { out_buffer, out_buffer_size, 0 };
    remaining = ZSTD_compressStream2(cctx, &output, &input, mode);
    fwrite(out_buffer, sizeof(char), output.pos, fp);
  } while(remaining);

  ZSTD_freeCCtx(cctx);
  cctx = nullptr;
  fclose(fp);
  fp = nullptr;
}

/* ---------------------------------------------------------------------- */

bool ZstdFileWriter::isopen() const
{
  return fp && cctx;
}

/* ---------------------------------------------------------------------- */

void ZstdFileWriter::setCompressionLevel(int level)
{
  if (isopen())
    throw FileWriterException("Compression level can not be changed while file is open");

  const int min_level = ZSTD_minCLevel();
  const int max_level = ZSTD_maxCLevel();

  if(level < min_level || level > max_level)
    throw FileWriterException(fmt::format("Compression level must in the range of [{}, {}]", min_level, max_level));

  compression_level = level;
}

/* ---------------------------------------------------------------------- */

void ZstdFileWriter::setChecksum(bool enabled)
{
  if (isopen())
    throw FileWriterException("Checksum flag can not be changed while file is open");
  checksum_flag = enabled ? 1 : 0;
}
Loading