Skip to content

Commit

Permalink
[SPARK-22666][ML][SQL] Spark datasource for image format
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Implement an image schema datasource.

This image datasource support:
  - partition discovery (loading partitioned images)
  - dropImageFailures (the same behavior with `ImageSchema.readImage`)
  - path wildcard matching (the same behavior with `ImageSchema.readImage`)
  - loading recursively from directory (different from `ImageSchema.readImage`, but use such path: `/path/to/dir/**`)

This datasource **NOT** support:
  - specify `numPartitions` (it will be determined by datasource automatically)
  - sampling (you can use `df.sample` later but the sampling operator won't be pushdown to datasource)

## How was this patch tested?
Unit tests.

## Benchmark
I benchmark and compare the cost time between old `ImageSchema.read` API and my image datasource.

**cluster**: 4 nodes, each with 64GB memory, 8 cores CPU
**test dataset**: Flickr8k_Dataset (about 8091 images)

**time cost**:
- My image datasource time (automatically generate 258 partitions):  38.04s
- `ImageSchema.read` time (set 16 partitions): 68.4s
- `ImageSchema.read` time (set 258 partitions):  90.6s

**time cost when increase image number by double (clone Flickr8k_Dataset and loads double number images)**:
- My image datasource time (automatically generate 515 partitions):  95.4s
- `ImageSchema.read` (set 32 partitions): 109s
- `ImageSchema.read` (set 515 partitions):  105s

So we can see that my image datasource implementation (this PR) bring some performance improvement compared against old`ImageSchema.read` API.

Closes apache#22328 from WeichenXu123/image_datasource.

Authored-by: WeichenXu <[email protected]>
Signed-off-by: Xiangrui Meng <[email protected]>
  • Loading branch information
WeichenXu123 authored and mengxr committed Sep 5, 2018
1 parent c66eef8 commit 9254492
Show file tree
Hide file tree
Showing 27 changed files with 323 additions and 4 deletions.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
13 changes: 13 additions & 0 deletions data/mllib/images/origin/license.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
The images in the folder "kittens" are under the creative commons CC0 license, or no rights reserved:
https://creativecommons.org/share-your-work/public-domain/cc0/
The images are taken from:
https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q==
https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA==
https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ==
https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw==

The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from:
https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw==

The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from:
https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png
File renamed without changes
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
not an image
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
org.apache.spark.ml.source.libsvm.LibSVMFileFormat
org.apache.spark.ml.source.image.ImageFileFormat
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source.image

/**
* `image` package implements Spark SQL data source API for loading image data as `DataFrame`.
* The loaded `DataFrame` has one `StructType` column: `image`.
* The schema of the `image` column is:
* - origin: String (represents the file path of the image)
* - height: Int (height of the image)
* - width: Int (width of the image)
* - nChannels: Int (number of the image channels)
* - mode: Int (OpenCV-compatible type)
* - data: BinaryType (Image bytes in OpenCV-compatible order: row-wise BGR in most cases)
*
* To use image data source, you need to set "image" as the format in `DataFrameReader` and
* optionally specify the data source options, for example:
* {{{
* // Scala
* val df = spark.read.format("image")
* .option("dropInvalid", true)
* .load("data/mllib/images/partitioned")
*
* // Java
* Dataset<Row> df = spark.read().format("image")
* .option("dropInvalid", true)
* .load("data/mllib/images/partitioned");
* }}}
*
* Image data source supports the following options:
* - "dropInvalid": Whether to drop the files that are not valid images from the result.
*
* @note This IMAGE data source does not support saving images to files.
*
* @note This class is public for documentation purpose. Please don't use this class directly.
* Rather, use the data source API as illustrated above.
*/
class ImageDataSource private() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source.image

import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeRow}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

private[image] class ImageFileFormat extends FileFormat with DataSourceRegister {

override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = Some(ImageSchema.imageSchema)

override def prepareWrite(
sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
throw new UnsupportedOperationException("Write is not supported for image data source")
}

override def shortName(): String = "image"

override protected def buildReader(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
assert(
requiredSchema.length <= 1,
"Image data source only produces a single data column named \"image\".")

val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

val imageSourceOptions = new ImageOptions(options)

(file: PartitionedFile) => {
val emptyUnsafeRow = new UnsafeRow(0)
if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) {
Iterator(emptyUnsafeRow)
} else {
val origin = file.filePath
val path = new Path(origin)
val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
val stream = fs.open(path)
val bytes = try {
ByteStreams.toByteArray(stream)
} finally {
Closeables.close(stream, true)
}
val resultOpt = ImageSchema.decode(origin, bytes)
val filteredResult = if (imageSourceOptions.dropInvalid) {
resultOpt.toIterator
} else {
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))
}

if (requiredSchema.isEmpty) {
filteredResult.map(_ => emptyUnsafeRow)
} else {
val converter = RowEncoder(requiredSchema)
filteredResult.map(row => converter.toRow(row))
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source.image

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

private[image] class ImageOptions(
@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {

def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))

/**
* Whether to drop invalid images. If true, invalid images will be removed, otherwise
* invalid images will be returned with empty data and all other field filled with `-1`.
*/
val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._

class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext {
// Single column of images named "image"
private lazy val imagePath = "../data/mllib/images"
private lazy val imagePath = "../data/mllib/images/origin"

test("Smoke test: create basic ImageSchema dataframe") {
val origin = "path"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.source.image

import java.nio.file.Paths

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.image.ImageSchema._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, substring_index}

class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {

// Single column of images named "image"
private lazy val imagePath = "../data/mllib/images/partitioned"

test("image datasource count test") {
val df1 = spark.read.format("image").load(imagePath)
assert(df1.count === 9)

val df2 = spark.read.format("image").option("dropInvalid", true).load(imagePath)
assert(df2.count === 8)
}

test("image datasource test: read jpg image") {
val df = spark.read.format("image").load(imagePath + "/cls=kittens/date=2018-02/DP153539.jpg")
assert(df.count() === 1)
}

test("image datasource test: read png image") {
val df = spark.read.format("image").load(imagePath + "/cls=multichannel/date=2018-01/BGRA.png")
assert(df.count() === 1)
}

test("image datasource test: read non image") {
val filePath = imagePath + "/cls=kittens/date=2018-01/not-image.txt"
val df = spark.read.format("image").option("dropInvalid", true)
.load(filePath)
assert(df.count() === 0)

val df2 = spark.read.format("image").option("dropInvalid", false)
.load(filePath)
assert(df2.count() === 1)
val result = df2.head()
assert(result === invalidImageRow(
Paths.get(filePath).toAbsolutePath().normalize().toUri().toString))
}

test("image datasource partition test") {
val result = spark.read.format("image")
.option("dropInvalid", true).load(imagePath)
.select(substring_index(col("image.origin"), "/", -1).as("origin"), col("cls"), col("date"))
.collect()

assert(Set(result: _*) === Set(
Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"),
Row("54893.jpg", "kittens", "2018-02"),
Row("DP153539.jpg", "kittens", "2018-02"),
Row("DP802813.jpg", "kittens", "2018-02"),
Row("BGRA.png", "multichannel", "2018-01"),
Row("BGRA_alpha_60.png", "multichannel", "2018-01"),
Row("chr30.4.184.jpg", "multichannel", "2018-02"),
Row("grayscale.jpg", "multichannel", "2018-02")
))
}

// Images with the different number of channels
test("readImages pixel values test") {
val images = spark.read.format("image").option("dropInvalid", true)
.load(imagePath + "/cls=multichannel/").collect()

val firstBytes20Set = images.map { rrow =>
val row = rrow.getAs[Row]("image")
val filename = Paths.get(getOrigin(row)).getFileName().toString()
val mode = getMode(row)
val bytes20 = getData(row).slice(0, 20).toList
filename -> Tuple2(mode, bytes20) // Cannot remove `Tuple2`, otherwise `->` operator
// will match 2 arguments
}.toSet

assert(firstBytes20Set === expectedFirstBytes20Set)
}

// number of channels and first 20 bytes of OpenCV representation
// - default representation for 3-channel RGB images is BGR row-wise:
// (B00, G00, R00, B10, G10, R10, ...)
// - default representation for 4-channel RGB images is BGRA row-wise:
// (B00, G00, R00, A00, B10, G10, R10, A10, ...)
private val expectedFirstBytes20Set = Set(
"grayscale.jpg" ->
((0, List[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62,
-57, -60, -63, -53, -49, -55, -69))),
"chr30.4.184.jpg" -> ((16,
List[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57,
-71, -58, -56, -73, -64))),
"BGRA.png" -> ((24,
List[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128,
-128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))),
"BGRA_alpha_60.png" -> ((24,
List[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128,
-128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60)))
)
}
2 changes: 1 addition & 1 deletion python/pyspark/ml/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def readImages(self, path, recursive=False, numPartitions=-1,
:return: a :class:`DataFrame` with a single column of "images",
see ImageSchema for details.
>>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True)
>>> df = ImageSchema.readImages('data/mllib/images/origin/kittens', recursive=True)
>>> df.count()
5
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2186,7 +2186,7 @@ def tearDown(self):
class ImageReaderTest(SparkSessionTestCase):

def test_read_images(self):
data_path = 'data/mllib/images/kittens'
data_path = 'data/mllib/images/origin/kittens'
df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
self.assertEqual(df.count(), 4)
first_row = df.take(1)[0][0]
Expand Down Expand Up @@ -2253,7 +2253,7 @@ def tearDownClass(cls):
def test_read_images_multiple_times(self):
# This test case is to check if `ImageSchema.readImages` tries to
# initiate Hive client multiple times. See SPARK-22651.
data_path = 'data/mllib/images/kittens'
data_path = 'data/mllib/images/origin/kittens'
ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)

Expand Down

0 comments on commit 9254492

Please sign in to comment.