Source code for claf.data.data_handler


import logging
import pickle
import os
from pathlib import Path, PosixPath
import shutil
import tempfile

import msgpack
import requests
from tqdm import tqdm

from claf import nsml

logger = logging.getLogger(__name__)


class CachePath:
    if nsml.IS_ON_NSML:
        ROOT = Path("./claf_cache")
    else:
        ROOT = Path.home() / ".claf_cache"

    DATASET = ROOT / "dataset"
    MACHINE = ROOT / "machine"
    PRETRAINED_VECTOR = ROOT / "pretrained_vector"
    TOKEN_COUNTER = ROOT / "token_counter"
    VOCAB = ROOT / "vocab"
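
# Illustration only (not part of the original module): outside NSML these
# constants resolve under the user's home directory, following directly from
# the definitions above.
#   CachePath.ROOT     -> ~/.claf_cache
#   CachePath.DATASET  -> ~/.claf_cache/dataset
#   CachePath.VOCAB    -> ~/.claf_cache/vocab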
class DataHandler:
    """
    DataHandler with CachePath

    - read (from_path, from_http)
    - dump (.msgpack or .pkl (pickle))
    - load
    """

    def __init__(self, cache_path=CachePath.ROOT):
        if type(cache_path) != PosixPath:
            raise ValueError(
                f"cache_path must be a PosixPath (use pathlib.Path), not {type(cache_path)}"
            )

        self.cache_path = cache_path
        cache_path.mkdir(parents=True, exist_ok=True)
    def convert_cache_path(self, path):
        cache_data_path = self.cache_path / Path(path)
        return cache_data_path
    def read_embedding(self, file_path):
        raise NotImplementedError()
    def read(self, file_path, encoding="utf-8", return_path=False):
        # HTTP(S) URLs are downloaded into the cache first, then read from there.
        if file_path.startswith("http"):
            file_path = self._read_from_http(file_path, encoding)

        path = Path(file_path)
        if path.exists():
            if return_path:
                return path
            return path.read_bytes().decode(encoding)

        # On NSML, fall back to the dataset directory (and its "train" subdirectory).
        if nsml.IS_ON_NSML:
            dataset_path = Path(nsml.DATASET_PATH)
            path = dataset_path / file_path
            if not path.exists():
                path = dataset_path / "train" / file_path
                if not path.exists():
                    raise FileNotFoundError(path)

            if path.exists():
                if return_path:
                    return path
                return path.read_bytes().decode(encoding)
        else:
            raise FileNotFoundError(f"{file_path} is not found.")
    def _read_from_http(self, file_path, encoding, return_path=False):
        cache_data_path = self.cache_path / Path(file_path).name
        if cache_data_path.exists():
            logger.info(f"'{file_path}' is already downloaded.")
        else:
            with tempfile.TemporaryFile() as temp_file:
                self._download_from_http(temp_file, file_path)
                temp_file.flush()
                temp_file.seek(0)

                with open(cache_data_path, "wb") as cache_file:
                    shutil.copyfileobj(temp_file, cache_file)

        return cache_data_path

    def _download_from_http(self, temp_file, url):
        req = requests.get(url, stream=True)
        content_length = req.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None

        with tqdm(total=total, unit="B", unit_scale=True, desc="download...") as pbar:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    temp_file.write(chunk)
                    pbar.update(len(chunk))
    def cache_token_counter(self, data_reader_config, tokenizer_name, obj=None):
        data_paths = os.path.basename(data_reader_config.train_file_path)
        if getattr(data_reader_config, "valid_file_path", None):
            data_paths += "#" + os.path.basename(data_reader_config.valid_file_path)

        path = self.cache_path / data_reader_config.dataset / data_paths
        path.mkdir(parents=True, exist_ok=True)
        path = path / tokenizer_name

        if obj:
            self.dump(path, obj)
        else:
            return self.load(path)
    def load(self, file_path, encoding="utf-8"):
        path = self.cache_path / file_path
        logger.info(f"load path: {path}")

        msgpack_path = path.with_suffix(".msgpack")
        if msgpack_path.exists():
            return self._load_msgpack(msgpack_path, encoding)

        pickle_path = path.with_suffix(".pkl")
        if pickle_path.exists():
            return self._load_pickle(pickle_path, encoding)

        return None
    def _load_msgpack(self, path, encoding):
        with open(path, "rb") as in_file:
            return msgpack.unpack(in_file, encoding=encoding)

    def _load_pickle(self, path, encoding):
        with open(path, "rb") as in_file:
            return pickle.load(in_file, encoding=encoding)
    def dump(self, file_path, obj, encoding="utf-8"):
        path = self.cache_path / file_path
        path.parent.mkdir(parents=True, exist_ok=True)

        try:
            with open(path.with_suffix(".msgpack"), "wb") as out_file:
                msgpack.pack(obj, out_file, encoding=encoding)
        except TypeError:
            # The object is not msgpack-serializable: remove the partial
            # .msgpack file and fall back to pickle.
            os.remove(path.with_suffix(".msgpack"))
            with open(path.with_suffix(".pkl"), "wb") as out_file:
                pickle.dump(obj, out_file, protocol=pickle.HIGHEST_PROTOCOL)
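
A minimal usage sketch (not part of the module; the cache key and payload below are placeholders, and it assumes a POSIX system since the constructor checks for PosixPath). dump() writes a .msgpack file when the object is msgpack-serializable and falls back to a .pkl file otherwise; load() checks for the .msgpack file first, then the .pkl file, and returns None if neither exists.

    from claf.data.data_handler import CachePath, DataHandler

    handler = DataHandler(cache_path=CachePath.ROOT)

    # Writes <cache_path>/token_counter/example.msgpack (or example.pkl on fallback).
    handler.dump("token_counter/example", {"hello": 2, "world": 1})

    # Reads the cached object back; returns None if no cached file exists.
    counter = handler.load("token_counter/example")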