Skip to content

Commit

Permalink
Accept unicode arguments at a csv.writer (fixes #2632).
Browse files Browse the repository at this point in the history
The CPython csv.writer accepts unicode strings and encodes them using
the current default encoding. This is not documented, but we can easily
reproduce the behaviour, which is relied on by some users. A simple
test_csv_jy is added for UTF-8 default. We hide sys.setdefaultencoding
again after use since this otherwise causes test_site to fail. The same
fault is corrected, where it was latent in test_unicode_jy.
  • Loading branch information
jeff5 committed Nov 21, 2017
1 parent c31766b commit 5717dd2
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 41 deletions.
96 changes: 96 additions & 0 deletions Lib/test/test_csv_jy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2017 Jython Developers

# Additional csv module unit tests for Jython

import csv
import io
import sys
from tempfile import TemporaryFile
from test import test_support
import unittest

# This test has been adapted from Python 3 test_csv.TestUnicode. In Python 3,
# the csv module supports Unicode directly. In Python 2, it does not, except
# that it is transparent to byte data. Many tools, however, accept UTF-8
# encoded text in a CSV file.
#
class EncodingContext(object):
    r"""Context manager that temporarily replaces the default encoding.

    The default encoding is switched on entry to the ``with`` block and the
    encoding that was in force at that moment is reinstated on exit, whether
    or not the block raised. Use like this:

        with EncodingContext("utf-8"):
            self.assertEqual("'caf\xc3\xa9'", u"'caf\xe9'")

    ``sys.setdefaultencoding`` is normally hidden; when it is missing,
    ``reload(sys)`` is used to expose it again. This class does not hide it
    afterwards: that is left to the caller (see ``test_main``).
    """

    def __init__(self, encoding):
        # Only record what is wanted here. Nothing global is changed until
        # the context is actually entered, so merely constructing (and then
        # discarding) an instance cannot leak a changed default encoding.
        self.encoding = encoding
        self.original_encoding = sys.getdefaultencoding()

    def __enter__(self):
        if not hasattr(sys, "setdefaultencoding"):
            # Re-expose the hidden sys.setdefaultencoding.
            reload(sys)
        # Capture the encoding in force at entry, which is what must be
        # restored, even if it changed since the object was constructed.
        self.original_encoding = sys.getdefaultencoding()
        sys.setdefaultencoding(self.encoding)
        return self

    def __exit__(self, *ignore_exc):
        # Always restore; returning None lets any exception propagate.
        sys.setdefaultencoding(self.original_encoding)

class TestUnicode(unittest.TestCase):
    """Non-ASCII text through the csv module, as UTF-8 bytes and as unicode.

    Adapted from Python 3 test_csv.TestUnicode. Each test moves the same row
    of names through a binary temporary file and checks the result against
    the UTF-8 form of the comma-joined row terminated by CR-LF.
    """

    names = [u"Martin von Löwis",
             u"Marc André Lemburg",
             u"Guido van Rossum",
             u"François Pinard",
             u"稲田直樹"]

    def test_decode_read(self):
        # The user code receives byte data and takes care of the decoding
        with TemporaryFile("w+b") as fileobj:
            line = u",".join(self.names) + u"\r\n"
            fileobj.write(line.encode('utf-8'))
            fileobj.seek(0)
            reader = csv.reader(fileobj)
            # The reader yields rows of byte strings that decode to the data
            table = [[e.decode('utf-8') for e in row] for row in reader]
            self.assertEqual(table, [self.names])

    def test_encode_write(self):
        # The user encodes unicode objects to byte data that csv writes
        with TemporaryFile("w+b") as fileobj:
            writer = csv.writer(fileobj)
            # We present a row of encoded strings to the writer
            writer.writerow([n.encode('utf-8') for n in self.names])
            # We expect the file contents to be the UTF-8 of the csv data
            expected = u",".join(self.names) + u"\r\n"
            fileobj.seek(0)
            self.assertEqual(fileobj.read().decode('utf-8'), expected)

    def test_unicode_write(self):
        # The user supplies unicode data that csv.writer default-encodes
        # (undocumented feature relied upon by client code).
        # See Issue #2632 https://github.com/jythontools/jython/issues/90
        with TemporaryFile("w+b") as fileobj:
            # Only the write needs the default encoding to be UTF-8
            with EncodingContext('utf-8'):
                writer = csv.writer(fileobj)
                # We present a row of unicode strings to the writer
                writer.writerow(self.names)
            # We expect the file contents to be the UTF-8 of the csv data.
            # Decode explicitly as UTF-8: decoding with the default encoding
            # would succeed for any self-consistent codec and so would not
            # check what was actually written.
            expected = u",".join(self.names) + u"\r\n"
            fileobj.seek(0)
            self.assertEqual(fileobj.read().decode('utf-8'), expected)


def test_main():
    """Run the tests in this module, then hide sys.setdefaultencoding again.

    The tests expose ``sys.setdefaultencoding``. If it was not visible before
    the run it is removed afterwards, because leaving it exposed causes
    test_site to fail.
    """
    # We'll be enabling sys.setdefaultencoding so remember to disable
    had_set = hasattr(sys, "setdefaultencoding")
    try:
        test_support.run_unittest(
            TestUnicode,
        )
    finally:
        # Check it is really there before deleting: if the run never got as
        # far as exposing it, an unconditional delattr would raise
        # AttributeError here and mask the real outcome of the tests.
        if not had_set and hasattr(sys, "setdefaultencoding"):
            delattr(sys, "setdefaultencoding")


if __name__ == "__main__":
    test_main()
8 changes: 7 additions & 1 deletion Lib/test/test_unicode_jy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,10 @@ class DefaultDecodingCp850(DefaultDecodingTestCase):


def test_main():
test_support.run_unittest(
# We'll be enabling sys.setdefaultencoding so remember to disable
had_set = hasattr(sys, "setdefaultencoding")
try:
test_support.run_unittest(
UnicodeTestCase,
UnicodeIndexMixTest,
UnicodeFormatTestCase,
Expand All @@ -1353,6 +1356,9 @@ def test_main():
DefaultDecodingUTF8,
DefaultDecodingCp850,
)
finally:
if not had_set:
delattr(sys, "setdefaultencoding")


if __name__ == "__main__":
Expand Down
35 changes: 20 additions & 15 deletions src/org/python/modules/_csv/PyDialect.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) Jython Developers */
/* Copyright (c)2017 Jython Developers */
package org.python.modules._csv;

import org.python.core.ArgParser;
Expand All @@ -9,6 +9,7 @@
import org.python.core.PyObject;
import org.python.core.PyString;
import org.python.core.PyType;
import org.python.core.PyUnicode;
import org.python.core.Untraversable;
import org.python.expose.ExposedDelete;
import org.python.expose.ExposedGet;
Expand Down Expand Up @@ -153,17 +154,21 @@ private static boolean toBool(String name, PyObject src, boolean dflt) {
private static char toChar(String name, PyObject src, char dflt) {
if (src == null) {
return dflt;
}
boolean isStr = Py.isInstance(src, PyString.TYPE);
if (src == Py.None || isStr && src.__len__() == 0) {
} else if (src == Py.None) {
return '\0';
} else if (!isStr || src.__len__() != 1) {
throw Py.TypeError(String.format("\"%s\" must be an 1-character string", name));
}
return src.toString().charAt(0);
} else if (src instanceof PyString) {
String s = (src instanceof PyUnicode) ? ((PyUnicode) src).encode() : src.toString();
if (s.length() == 0) {
return '\0';
} else if (s.length() == 1) {
return s.charAt(0);
}
}
// This is only going to work for BMP strings because of the char return type
throw Py.TypeError(String.format("\"%s\" must be a 1-character string", name));
}

private static int toInt(String name, PyObject src, int dflt) {
private static int toInt(String name, PyObject src, int dflt) {
if (src == null) {
return dflt;
}
Expand All @@ -176,14 +181,14 @@ private static int toInt(String name, PyObject src, int dflt) {
private static String toStr(String name, PyObject src, String dflt) {
if (src == null) {
return dflt;
}
if (src == Py.None) {
} else if (src == Py.None) {
return null;
} else if (src instanceof PyUnicode) {
return ((PyUnicode) src).encode().toString();
} else if (src instanceof PyString) {
return src.toString();
}
if (!(src instanceof PyBaseString)) {
throw Py.TypeError(String.format("\"%s\" must be an string", name));
}
return src.toString();
throw Py.TypeError(String.format("\"%s\" must be a string", name));
}

@ExposedGet(name = "escapechar")
Expand Down
48 changes: 23 additions & 25 deletions src/org/python/modules/_csv/PyWriter.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) Jython Developers */
/* Copyright (c)2017 Jython Developers */
package org.python.modules._csv;

import org.python.core.Py;
Expand All @@ -7,6 +7,7 @@
import org.python.core.PyObject;
import org.python.core.PyString;
import org.python.core.PyType;
import org.python.core.PyUnicode;
import org.python.core.Traverseproc;
import org.python.core.Visitproc;
import org.python.expose.ExposedType;
Expand All @@ -21,11 +22,9 @@
@ExposedType(name = "_csv.writer", doc = PyWriter.writer_doc)
public class PyWriter extends PyObject implements Traverseproc {

public static final String writer_doc =
"CSV writer\n" +
"\n" +
"Writer objects are responsible for generating tabular data\n" +
"in CSV format from sequence input.\n";
public static final String writer_doc = "CSV writer\n\n"//
+ "Writer objects are responsible for generating tabular data\n"
+ "in CSV format from sequence input.\n";

public static final PyType TYPE = PyType.fromClass(PyWriter.class);

Expand Down Expand Up @@ -53,11 +52,10 @@ public PyWriter(PyObject writeline, PyDialect dialect) {
this.dialect = dialect;
}

public static PyString __doc__writerows = Py.newString(
"writerows(sequence of sequences)\n" +
"\n" +
"Construct and write a series of sequences to a csv file. Non-string\n" +
"elements will be converted to string.");
public static PyString __doc__writerows = Py.newString(//
"writerows(sequence of sequences)\n\n"
+ "Construct and write a series of sequences to a csv file. Non-string\n"
+ "elements will be converted to string.");

public void writerows(PyObject seqseq) {
writer_writerows(seqseq);
Expand All @@ -82,12 +80,10 @@ final void writer_writerows(PyObject seqseq) {
}
}

public static PyString __doc__writerow = Py.newString(
"writerow(sequence)\n" +
"\n" +
"Construct and write a CSV record from a sequence of fields. Non-string\n" +
"elements will be converted to string."
);
public static PyString __doc__writerow = Py.newString(//
"writerow(sequence)\n\n"
+ "Construct and write a CSV record from a sequence of fields. Non-string\n"
+ "elements will be converted to string.");

public boolean writerow(PyObject seq) {
return writer_writerow(seq);
Expand Down Expand Up @@ -134,14 +130,17 @@ final boolean writer_writerow(PyObject seq) {
quoted = false;
}

if (field instanceof PyString) {
if (field instanceof PyUnicode) {
// Unicode fields get the default encoding (must yield U16 bytes).
append_ok = join_append(((PyString) field).encode(), len == 1);
} else if (field instanceof PyString) {
// Not unicode, so must be U16 bytes.
append_ok = join_append(field.toString(), len == 1);
} else if (field == Py.None) {
append_ok = join_append("", len == 1);
} else {
PyObject str;
//XXX: in 3.x this check can go away and we can just always use
// __str__
// XXX: in 3.x this check can go away and we can just always use __str__
if (field.getClass() == PyFloat.class) {
str = field.__repr__();
} else {
Expand Down Expand Up @@ -195,9 +194,9 @@ private boolean join_append(String field, boolean quote_empty) {
}

/**
* This method behaves differently depending on the value of copy_phase: if copy_phase
* is false, then the method determines the new record length. If copy_phase is true
* then the new field is appended to the record.
* This method behaves differently depending on the value of copy_phase: if copy_phase is false,
* then the method determines the new record length. If copy_phase is true then the new field is
* appended to the record.
*/
private int join_append_data(String field, boolean quote_empty, boolean copy_phase) {
int i;
Expand Down Expand Up @@ -225,7 +224,7 @@ private int join_append_data(String field, boolean quote_empty, boolean copy_pha
break;
}
if (c == dialect.delimiter || c == dialect.escapechar || c == dialect.quotechar
|| dialect.lineterminator.indexOf(c) > -1) {
|| dialect.lineterminator.indexOf(c) > -1) {
if (dialect.quoting == QuoteStyle.QUOTE_NONE) {
want_escape = true;
} else {
Expand Down Expand Up @@ -282,7 +281,6 @@ private void addChar(char c, boolean copy_phase) {
rec_len++;
}


/* Traverseproc implementation */
@Override
public int traverse(Visitproc visit, Object arg) {
Expand Down

0 comments on commit 5717dd2

Please sign in to comment.