Skip to content

Commit

Permalink
fix(rust): fix SplitFields ignoring an empty final field (pola-rs#13542)
Browse files Browse the repository at this point in the history
  • Loading branch information
hamishs authored Jan 9, 2024
1 parent 607e6a2 commit b39ef9f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 26 deletions.
8 changes: 6 additions & 2 deletions crates/polars-io/src/csv/splitfields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ mod inner {

#[inline]
fn next(&mut self) -> Option<(&'a [u8], bool)> {
if self.v.is_empty() || self.finished {
if self.finished {
return None;
} else if self.v.is_empty() {
return self.finish(false);
}

let mut needs_escaping = false;
Expand Down Expand Up @@ -214,8 +216,10 @@ mod inner {

#[inline]
fn next(&mut self) -> Option<(&'a [u8], bool)> {
if self.v.is_empty() || self.finished {
if self.finished {
return None;
} else if self.v.is_empty() {
return self.finish(false);
}

let mut needs_escaping = false;
Expand Down
10 changes: 2 additions & 8 deletions crates/polars-io/src/csv/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,10 @@ pub fn infer_file_schema_inner(
}
final_headers
} else {
let mut column_names: Vec<String> = byterecord
byterecord
.enumerate()
.map(|(i, _s)| format!("column_{}", i + 1))
.collect();
// needed because SplitLines does not return the \n char, so SplitFields does not catch
// the latest value if ending with a separator.
if header_line.ends_with(&[separator]) {
column_names.push(format!("column_{}", column_names.len() + 1))
}
column_names
.collect::<Vec<String>>()
}
} else if has_header && !bytes.is_empty() && recursion_count == 0 {
// there was no new line char. So we copy the whole buf and add one
Expand Down
44 changes: 44 additions & 0 deletions crates/polars/tests/it/io/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,50 @@ fn test_empty_string_cols() -> PolarsResult<()> {
Ok(())
}

// Exercises CSV header parsing when header fields are empty: in the middle,
// at the end of the line, and twice at the end. Also checks that a data row
// with more fields than the header is reported as an error.
#[test]
fn test_empty_col_names() -> PolarsResult<()> {
    // Baseline: every header field carries a name.
    let parsed = CsvReader::new(Cursor::new("a,b,c\n1,2,3")).finish()?;
    let want = df![
        "a" => [1i64],
        "b" => [2i64],
        "c" => [3i64]
    ]?;
    assert!(parsed.equals(&want));

    // An empty header field in the middle becomes a column named "".
    let parsed = CsvReader::new(Cursor::new("a,,c\n1,2,3")).finish()?;
    let want = df![
        "a" => [1i64],
        "" => [2i64],
        "c" => [3i64]
    ]?;
    assert!(parsed.equals(&want));

    // A header line ending in the separator has an empty final field,
    // which must still be counted as a column.
    let parsed = CsvReader::new(Cursor::new("a,b,\n1,2,3")).finish()?;
    let want = df![
        "a" => [1i64],
        "b" => [2i64],
        "" => [3i64]
    ]?;
    assert!(parsed.equals(&want));

    // Two trailing empty header fields: four columns are expected even though
    // the single data row only supplies three values.
    let parsed = CsvReader::new(Cursor::new("a,b,,\n1,2,3")).finish()?;
    assert_eq!(parsed.shape(), (1, 4));

    // The data row has more fields than the header: reading must fail.
    let outcome = CsvReader::new(Cursor::new("a,b\n1,2,3")).finish();
    assert!(outcome.is_err());
    Ok(())
}

#[test]
fn test_trailing_empty_string_cols() -> PolarsResult<()> {
let csv = "colx\nabc\nxyz\n\"\"";
Expand Down
27 changes: 11 additions & 16 deletions py-polars/tests/unit/io/test_spreadsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,22 +314,18 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
sheet_name="test4",
schema_overrides={"cardinality": pl.UInt16},
).drop_nulls()
assert df1.schema == {
"cardinality": pl.UInt16,
"rows_by_key": pl.Float64,
"iter_groups": pl.Float64,
}
assert df1.schema["cardinality"] == pl.UInt16
assert df1.schema["rows_by_key"] == pl.Float64
assert df1.schema["iter_groups"] == pl.Float64

df2 = pl.read_excel(
path_xlsx,
sheet_name="test4",
read_csv_options={"dtypes": {"cardinality": pl.UInt16}},
).drop_nulls()
assert df2.schema == {
"cardinality": pl.UInt16,
"rows_by_key": pl.Float64,
"iter_groups": pl.Float64,
}
assert df2.schema["cardinality"] == pl.UInt16
assert df2.schema["rows_by_key"] == pl.Float64
assert df2.schema["iter_groups"] == pl.Float64

df3 = pl.read_excel(
path_xlsx,
Expand All @@ -342,11 +338,9 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
},
},
).drop_nulls()
assert df3.schema == {
"cardinality": pl.UInt16,
"rows_by_key": pl.Float32,
"iter_groups": pl.Float32,
}
assert df3.schema["cardinality"] == pl.UInt16
assert df3.schema["rows_by_key"] == pl.Float32
assert df3.schema["iter_groups"] == pl.Float32

for workbook_path in (path_xlsx, path_xlsb, path_ods):
df4 = pl.read_excel(
Expand Down Expand Up @@ -392,7 +386,8 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
sheet_name=["test4", "test4"],
schema_overrides=overrides,
)
assert df["test4"].schema == overrides
for col, dtype in overrides.items():
assert df["test4"].schema[col] == dtype


def test_unsupported_engine() -> None:
Expand Down

0 comments on commit b39ef9f

Please sign in to comment.