Add script csv datasets (huggingface#25)
* First commit to try to create agnostic local CSV datasets

* Initial commit to allow generic local CSV datasets

* Add dummy data

* Ignore index value in _generate_examples

* [WIP] - refactoring

* fixing CSV

* small fix to convert

* let's not handle ClassLabel for ArrowBuilder for now + style/quality

Co-authored-by: Thomas Wolf <[email protected]>
jplu and thomwolf authored May 7, 2020
1 parent 377135a commit 791321d
Showing 16 changed files with 641 additions and 1,004 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,3 +1,6 @@
+# Locked files
+*.lock
+
 # Compiled python modules.
 *.pyc
 
@@ -40,4 +43,4 @@ venv.bak/
 .*.swp
 
 # playground
-/playground
\ No newline at end of file
+/playground
50 changes: 50 additions & 0 deletions datasets/dummy_data/csv/test.csv

Large diffs are not rendered by default.

100 changes: 100 additions & 0 deletions datasets/dummy_data/csv/train.csv

Large diffs are not rendered by default.

84 changes: 84 additions & 0 deletions datasets/nlp/csv/csv.py
@@ -0,0 +1,84 @@
# coding=utf-8

from dataclasses import dataclass

import pyarrow as pa
import pyarrow.csv as pac

import nlp


@dataclass
class CsvConfig(nlp.BuilderConfig):
    """BuilderConfig for CSV."""
    skip_rows: int = 0
    header_as_column_names: bool = True
    delimiter: str = ","
    quote_char: str = "\""
    read_options: pac.ReadOptions = None
    parse_options: pac.ParseOptions = None
    convert_options: pac.ConvertOptions = None

    @property
    def pa_read_options(self):
        read_options = self.read_options or pac.ReadOptions()
        read_options.skip_rows = self.skip_rows
        read_options.autogenerate_column_names = not self.header_as_column_names
        return read_options

    @property
    def pa_parse_options(self):
        parse_options = self.parse_options or pac.ParseOptions()
        parse_options.delimiter = self.delimiter
        parse_options.quote_char = self.quote_char
        return parse_options

    @property
    def pa_convert_options(self):
        convert_options = self.convert_options or pac.ConvertOptions()
        return convert_options


class Csv(nlp.ArrowBasedBuilder):
    BUILDER_CONFIGS = [
        CsvConfig(
            name="CSV",
            version=nlp.Version("1.0.0"),
            description="Csv dataloader",
        ),
    ]

    def _info(self):
        return nlp.DatasetInfo()

    def _split_generators(self, dl_manager):
        """ We handle string, list and dicts in datafiles
        """
        if isinstance(self.config.data_files, (str, list, tuple)):
            files = self.config.data_files
            if isinstance(files, str):
                files = [files]
            return [nlp.SplitGenerator(
                name=nlp.Split.TRAIN,
                gen_kwargs={"files": files})]
        splits = []
        for split_name in [nlp.Split.TRAIN, nlp.Split.VALIDATION, nlp.Split.TEST]:
            if split_name in self.config.data_files:
                files = self.config.data_files[split_name]
                if isinstance(files, str):
                    files = [files]
                splits.append(
                    nlp.SplitGenerator(
                        name=split_name,
                        gen_kwargs={"files": files}))
        return splits

    def _generate_tables(self, files):
        for i, file in enumerate(files):
            pa_table = pac.read_csv(
                file,
                read_options=self.config.pa_read_options,
                parse_options=self.config.pa_parse_options,
                convert_options=self.config.pa_convert_options,
            )
            yield i, pa_table
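
For orientation, here is a small standalone sketch (not part of this commit) of what the pa_read_options and pa_parse_options properties above translate to when _generate_tables calls pyarrow.csv.read_csv; the file path is a placeholder, not a file added by this commit.

import pyarrow.csv as pac

# Mirror the CsvConfig defaults: skip_rows=0, header_as_column_names=True,
# delimiter=",", quote_char='"'.
read_options = pac.ReadOptions()
read_options.skip_rows = 0
read_options.autogenerate_column_names = False  # i.e. not header_as_column_names

parse_options = pac.ParseOptions()
parse_options.delimiter = ","
parse_options.quote_char = '"'

# "some_file.csv" is a placeholder path.
pa_table = pac.read_csv(
    "some_file.csv",
    read_options=read_options,
    parse_options=parse_options,
    convert_options=pac.ConvertOptions(),
)
print(pa_table.num_rows, pa_table.schema)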
4 changes: 2 additions & 2 deletions src/nlp/__init__.py
@@ -38,11 +38,11 @@
 from . import datasets
 from .arrow_dataset import Dataset
 from .arrow_reader import ReadInstruction
-from .builder import BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
+from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
 from .features import ClassLabel, Features, Sequence, Tensor, Translation, TranslationVariableLanguages, Value
 from .info import DatasetInfo
 from .lazy_imports_lib import lazy_imports
-from .load import builder, get_builder_cls_from_module, load, load_dataset_module
+from .load import load
 from .splits import NamedSplit, Split, SplitBase, SplitDict, SplitGenerator, SplitInfo, SubSplitInfo, percent
 from .utils import *
 from .utils.tqdm_utils import disable_progress_bar
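
Since `load` is now the only loader re-exported from the package root, end-to-end usage of the new CSV script would presumably look like the sketch below. This is a hypothetical example: the exact signature of nlp.load and how it forwards data_files to the builder config are not shown in this diff and are assumed here.

import nlp

# Assumption: nlp.load accepts the CSV script plus a data_files argument that
# ends up on self.config.data_files, as read by _split_generators above.
train_only = nlp.load("datasets/nlp/csv/csv.py", data_files="datasets/dummy_data/csv/train.csv")

# A dict keyed by split should yield one SplitGenerator per provided split.
train_test = nlp.load(
    "datasets/nlp/csv/csv.py",
    data_files={
        nlp.Split.TRAIN: "datasets/dummy_data/csv/train.csv",
        nlp.Split.TEST: "datasets/dummy_data/csv/test.csv",
    },
)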
62 changes: 49 additions & 13 deletions src/nlp/arrow_writer.py
@@ -37,19 +37,20 @@ def __init__(
         writer_batch_size: Optional[int] = None,
         disable_nullable: bool = True,
     ):
-        if data_type is None and schema is None:
-            raise ValueError("At least one of data_type and schema must be provided.")
         if path is None and stream is None:
             raise ValueError("At least one of path and stream must be provided.")
 
         if data_type is not None:
             self._type: pa.DataType = data_type
             self._schema: pa.Schema = pa.schema(field for field in self._type)
-        else:
+        elif schema is not None:
             self._schema: pa.Schema = schema
             self._type: pa.DataType = pa.struct(field for field in self._schema)
+        else:
+            self._schema = None
+            self._type = None
 
-        if disable_nullable:
+        if disable_nullable and self._schema is not None:
             self._schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in self._type)
             self._type = pa.struct(pa.field(field.name, field.type, nullable=False) for field in self._type)

@@ -59,21 +60,39 @@ def __init__(
         else:
             self.stream = stream
 
-        self.writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
         self.writer_batch_size = writer_batch_size
 
         self._num_examples = 0
         self._num_bytes = 0
         self.current_rows = []
 
+        self._build_writer(schema=self._schema)
+
+    def _build_writer(self, pa_table=None, schema=None):
+        if schema is not None:
+            self._schema: pa.Schema = schema
+            self._type: pa.DataType = pa.struct(field for field in self._schema)
+            self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema)
+        elif pa_table is not None:
+            self._schema: pa.Schema = pa_table.schema
+            self._type: pa.DataType = pa.struct(field for field in self._schema)
+            self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
+        else:
+            self.pa_writer = None
+
+    @property
+    def schema(self):
+        return self._schema if self._schema is not None else []
+
     def write_on_file(self):
         """ Write stored examples
         """
-        pa_array = pa.array(self.current_rows, type=self._type)
-        pa_batch = pa.RecordBatch.from_struct_array(pa_array)
-        self._num_bytes += pa_array.nbytes
-        self.writer.write_batch(pa_batch)
-        self.current_rows = []
+        if self.current_rows:
+            pa_array = pa.array(self.current_rows, type=self._type)
+            pa_batch = pa.RecordBatch.from_struct_array(pa_array)
+            self._num_bytes += pa_array.nbytes
+            self.pa_writer.write_batch(pa_batch)
+            self.current_rows = []
 
     def write(self, example: Dict[str, Any], writer_batch_size: Optional[int] = None):
         """ Add a given Example to the write-pool which is written to file.
@@ -101,11 +120,28 @@ def write_batch(self, batch_examples: Dict[str, List[Any]], writer_batch_size: Optional[int] = None):
         self._num_bytes += sum(batch.nbytes for batch in batches)
         self._num_examples += pa_table.num_rows
         for batch in batches:
-            self.writer.write_batch(batch)
+            self.pa_writer.write_batch(batch)
 
+    def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = None):
+        """ Write a batch of Example to file.
+        Args:
+            example: the Example to add.
+        """
+        if writer_batch_size is None:
+            writer_batch_size = self.writer_batch_size
+        if self.pa_writer is None:
+            self._build_writer(pa_table=pa_table)
+        batches: List[pa.RecordBatch] = pa_table.to_batches(max_chunksize=writer_batch_size)
+        self._num_bytes += sum(batch.nbytes for batch in batches)
+        self._num_examples += pa_table.num_rows
+        for batch in batches:
+            self.pa_writer.write_batch(batch)
+
     def finalize(self, close_stream=True):
-        self.write_on_file()
-        self.writer.close()
+        if self.pa_writer is not None:
+            self.write_on_file()
+            self.pa_writer.close()
         if close_stream:
             self.stream.close()
         logger.info(
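
For reference, a minimal standalone pyarrow sketch (not part of this commit) of the deferred-schema pattern that _build_writer and write_table implement: the RecordBatchStreamWriter is only created once the first table arrives and its schema is known. The sink and data below are made up.

import pyarrow as pa

sink = pa.BufferOutputStream()
pa_writer = None  # like ArrowWriter.pa_writer before any table has been written

for pa_table in [pa.table({"text": ["a", "b"], "label": [0, 1]})]:
    if pa_writer is None:
        # Analogous to _build_writer(pa_table=...): take the schema from the first table.
        pa_writer = pa.RecordBatchStreamWriter(sink, pa_table.schema)
    for batch in pa_table.to_batches(max_chunksize=1000):
        pa_writer.write_batch(batch)

pa_writer.close()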
