Unverified Commit c7a6c148 authored by micimize's avatar micimize
Browse files

fix pickle typing for CI, small refactor into load_pickle_file

parent 67bf271e
Loading
Loading
Loading
Loading
+24 −16
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import tarfile
import zipfile
import logging
from urllib.request import urlretrieve
from typing import Any, Iterator, List, Optional, Tuple, Union
from typing import Any, Iterator, List, Optional, Tuple, Union, cast, IO

import pandas as pd
import numpy as np
@@ -322,9 +322,29 @@ def load_json_files(input_files: List[str],
        shard_num += 1
        yield df

def load_pickle_file(input_file: str) -> Any:
  """Load from single, possibly gzipped, pickle file.

  Parameters
  ----------
  input_file: str
    The filename of pickle file. This function can load from
    gzipped pickle file like `XXXX.pkl.gz`.

  Returns
  -------
  Any
    The object which is loaded from the pickle file.
  """
  if ".gz" in input_file:
    with gzip.open(input_file, "rb") as unzipped_file:
      return pickle.load(cast(IO[bytes], unzipped_file))
  else:
    with open(input_file, "rb") as opened_file:
      return pickle.load(opened_file)

def load_pickle_files(input_files: List[str]) -> Iterator[Any]:
  """Load dataset from pickle file.
  """Load dataset from pickle files.

  Parameters
  ----------
@@ -338,13 +358,7 @@ def load_pickle_files(input_files: List[str]) -> Iterator[Any]:
    Generator which yields the objects which is loaded from each pickle file.
  """
  for input_file in input_files:
    if ".gz" in input_file:
      with gzip.open(input_file, "rb") as f:
        df = pickle.load(f)
    else:
      with open(input_file, "rb") as f:
        df = pickle.load(f)
    yield df
    yield load_pickle_file(input_file)


def load_data(input_files: List[str],
@@ -442,13 +456,7 @@ def load_from_disk(filename: str) -> Any:
    name = os.path.splitext(name)[0]
  extension = os.path.splitext(name)[1]
  if extension == ".pkl":
    if ".gz" in filename:
      with gzip.open(filename, "rb") as f:
        df = pickle.load(f)
    else:
      with open(filename, "rb") as f:
        df = pickle.load(f)
    return df
    return load_pickle_file(filename)
  elif extension == ".joblib":
    return joblib.load(filename)
  elif extension == ".csv":