diff --git a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
index 6555a9d2..95cf7eb3 100644
--- a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
+++ b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
@@ -186,6 +186,19 @@ class RedshiftIntegrationSuite extends IntegrationSuiteBase {
     )
   }
 
+  test("backslashes in queries/subqueries are escaped (regression test for #215)") {
+    val loadedDf = sqlContext.read
+      .format("com.databricks.spark.redshift")
+      .option("url", jdbcUrl)
+      .option("query", s"select replace(teststring, '\\\\', '') as col from $test_table")
+      .option("tempdir", tempDir)
+      .load()
+    checkAnswer(
+      loadedDf.filter("col = 'asdf'"),
+      Seq(Row("asdf"))
+    )
+  }
+
   test("Can load output when 'dbtable' is a subquery wrapped in parentheses") {
     // scalastyle:off
     val query =
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
index 4a52fe7a..cf7ede77 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -174,8 +174,8 @@ private[redshift] case class RedshiftRelation(
     val credsString: String = AWSCredentialsUtils.getRedshiftCredentialsString(params, creds)
     val query = {
       // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
-      // any single quotes that appear in the query itself
-      val escapedTableNameOrSubqury = tableNameOrSubquery.replace("'", "\\'")
+      // any backslashes and single quotes that appear in the query itself
+      val escapedTableNameOrSubqury = tableNameOrSubquery.replace("\\", "\\\\").replace("'", "\\'")
       s"SELECT $columnList FROM $escapedTableNameOrSubqury $whereClause"
     }
     // We need to remove S3 credentials from the unload path URI because they will conflict with
diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
index 47a305d5..37a2d5d6 100644
--- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -182,10 +182,10 @@ class RedshiftSourceSuite
       |UNLOAD \('SELECT "testbyte", "testbool" FROM
       |  \(select testbyte, testbool
       |  from test_table
-      |  where teststring = \\'Unicode\\'\\'s樂趣\\'\) '\)
+      |  where teststring = \\'\\\\\\\\Unicode\\'\\'s樂趣\\'\) '\)
       """.stripMargin.lines.map(_.trim).mkString(" ").trim.r
     val query =
-      """select testbyte, testbool from test_table where teststring = 'Unicode''s樂趣'"""
+      """select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'"""
     // scalastyle:on
     val querySchema =
       StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType)))
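
For reference, a minimal standalone sketch of the escaping order this patch relies on: backslashes must be doubled before single quotes are escaped, so the backslash added for a quote is not itself doubled afterwards. The UnloadEscapingSketch object, the buildUnloadStatement helper, and the sample S3 path and credentials string are hypothetical illustrations, not part of the patched code; only java.lang.String.replace is used.

object UnloadEscapingSketch {
  // Mirrors the patched logic in RedshiftRelation: double backslashes first, then escape quotes.
  def escapeForUnload(tableNameOrSubquery: String): String =
    tableNameOrSubquery.replace("\\", "\\\\").replace("'", "\\'")

  // Hypothetical helper showing why escaping is needed at all:
  // UNLOAD wraps the inner query in single quotes.
  def buildUnloadStatement(query: String, s3Path: String, credsString: String): String =
    s"UNLOAD ('${escapeForUnload(query)}') TO '$s3Path' CREDENTIALS '$credsString'"

  def main(args: Array[String]): Unit = {
    // In a triple-quoted Scala string the two characters \\ are literal backslashes,
    // matching the query used in the regression test above.
    val subquery = """select replace(teststring, '\\', '') as col from test_table"""
    println(buildUnloadStatement(subquery, "s3://bucket/temp-dir/", "aws_iam_role=<role-arn>"))
    // Each literal backslash comes out doubled and each single quote escaped, so the inner
    // query survives being enclosed in the outer single quotes that UNLOAD requires.
  }
}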