Skip to content

Commit

Permalink
Inline custom mapping function in _byte_pair_merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Lőrinc authored and hauntsaninja committed Feb 9, 2024
1 parent b4c687e commit 6defed5
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ use std::thread;
use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::pyclass;
use pyo3::PyResult;
use pyo3::types::{PyBytes, PyList, PyTuple};
use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

fn _byte_pair_merge<T>(
piece: &[u8],
fn _byte_pair_merge(
ranks: &HashMap<Vec<u8>, Rank>,
f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
piece: &[u8],
) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the byte pair starting at position start.
// The rank of the last item in the vector is not a valid value.
Expand Down Expand Up @@ -93,25 +93,24 @@ fn _byte_pair_merge<T>(
break;
}
}
let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
for i in 0..parts.len() - 1 {
out.push(f(parts[i].0..parts[i + 1].0));
}
out

parts
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
if piece.len() == 1 {
return vec![ranks[piece]];
}
_byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
assert!(piece.len() > 1);
_byte_pair_merge(&ranks, &piece)
.windows(2)
.map(|part| ranks[&piece[part[0].0..part[1].0]])
.collect()
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
if piece.len() == 1 {
return vec![piece];
}
_byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
assert!(piece.len() > 1);
_byte_pair_merge(&ranks, &piece)
.windows(2)
.map(|part| &piece[part[0].0..part[1].0])
.collect()
}

// Various performance notes:
Expand Down

0 comments on commit 6defed5

Please sign in to comment.