
Output dataset in WebDataset format #29

Open

galv opened this issue Jun 23, 2021 · 1 comment
galv commented Jun 23, 2021

The WebDataset format is preferable for data distribution for a few reasons:

  • Easy to use without installing dependencies, because it's just .tar files (see the sketch after this list).
  • Natively supports sharding: each shard is a single .tar file.
  • Battle-tested in NeMo. Several published results were trained on datasets stored in WebDataset format.
  • Not as crazy as TFRecord: no funkiness with serializing your half-precision float data as 8-bit chars.
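For concreteness, here is a minimal sketch of a WebDataset shard written with Python's built-in tarfile module. The key/extension naming convention is how WebDataset groups files into samples (files sharing a basename form one sample); the sample names and payload bytes below are made up:

```python
import io
import tarfile

def write_webdataset_shard(shard_path, samples):
    """Write (key, {extension: bytes}) samples as one WebDataset shard.

    WebDataset groups a sample's files by a shared basename ("key");
    the extension tells the loader which field each file holds.
    """
    with tarfile.open(shard_path, "w") as tar:
        for key, fields in samples:
            for extension, payload in fields.items():
                info = tarfile.TarInfo(name=f"{key}.{extension}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))

# Hypothetical sample: one utterance with audio bytes and its transcript.
write_webdataset_shard(
    "shard-000000.tar",
    [("utt_000000", {"flac": b"...audio bytes...", "txt": b"hello world"})],
)
```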

This is all great, but there's a pickle: Spark has no built-in support for reading or writing this format!

A straightforward way to get around this sort of issue is to "repartition()" or "coalesce()" the dataframe so that each partition is a reasonable size (we probably want 2-4 GiB per tar file). Then we can call foreachPartition to convert each partition to a tar file in Python: https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.DataFrame.foreachPartition.html We could then save each file to remote storage using tf.io.gfile, which doesn't stall the way gcsfuse does. It also isn't clear to me from the documentation whether I can use pandas_udfs with foreachPartition (although the existing pickle serialization format may be fine for this use case).
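A hedged sketch of that plan, assuming a dataframe with hypothetical key, audio, and transcript columns and a hypothetical gs://my-bucket output path. Since foreachPartition doesn't hand us a partition index, the sketch names each shard with a UUID, and it uses tarfile's streaming "w|" mode so we can write straight into a tf.io.gfile handle without buffering the whole shard in memory:

```python
import io
import tarfile
import uuid

import tensorflow as tf  # only used for tf.io.gfile
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical input: one row per utterance.
df = spark.read.parquet("gs://my-bucket/dataset.parquet")

NUM_SHARDS = 512  # pick so each shard lands around 2-4 GiB

def write_partition_as_shard(rows):
    """Turn one Spark partition into one WebDataset .tar shard on GCS."""
    shard_path = f"gs://my-bucket/shards/shard-{uuid.uuid4().hex}.tar"
    with tf.io.gfile.GFile(shard_path, "wb") as f:
        # "w|" is tarfile's non-seeking stream mode, so we can write
        # directly into the remote file object.
        with tarfile.open(fileobj=f, mode="w|") as tar:
            for row in rows:
                for extension, payload in (
                    ("flac", bytes(row.audio)),
                    ("txt", row.transcript.encode("utf-8")),
                ):
                    info = tarfile.TarInfo(name=f"{row.key}.{extension}")
                    info.size = len(payload)
                    tar.addfile(info, io.BytesIO(payload))

df.repartition(NUM_SHARDS).foreachPartition(write_partition_as_shard)
```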

I'm not 100% certain that will work, though. We may run out of memory converting data from Spark's UnsafeRow format to Arrow format for Python. One may wonder why I don't propose writing the foreachPartition function in Java. I would like to, but Java doesn't have good support for the tar file format for some reason. All I can find is this: https://github.com/kamranzafar/jtar Python, meanwhile, has built-in support: https://docs.python.org/3/library/tarfile.html

Maybe the tar file format is simple enough to write a parser and serializer for from scratch in Java or Scala anyway, but I doubt it: the tarfile implementation in Python is about 2,500 lines long.

The other alternative is to create a new "DataSource" for tar files in Spark. Since Spark is commonly used for machine learning, support for the WebDataset format seems like something we might want to package as a publicly distributed Spark plugin and contribute back to the community.


galv commented Jun 23, 2021

Okay, it looks like the Apache Commons Compress library supports the tar file format, so I would prefer to use that and go the Java route. This actually seems pretty reasonable to do. The binary jar is only 632KB in total (!): http://commons.apache.org/proper/commons-compress/javadocs/api-release/index.html
