Skip to content

ARM-NEON #618

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 152 additions & 42 deletions html5ever/src/tokenizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,11 +706,11 @@ impl<Sink: TokenSink> Tokenizer<Sink> {
states::Data => loop {
let set = small_char_set!('\r' '\0' '&' '<' '\n');

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
let set_result = if !(self.opts.exact_errors
|| self.reconsume.get()
|| self.ignore_lf.get())
&& is_x86_feature_detected!("sse2")
&& Self::is_supported_simd_feature_detected()
{
let front_buffer = input.peek_front_chunk_mut();
let Some(mut front_buffer) = front_buffer else {
Expand All @@ -729,8 +729,8 @@ impl<Sink: TokenSink> Tokenizer<Sink> {
self.pop_except_from(input, set)
} else {
// SAFETY:
// This CPU is guaranteed to support SSE2 due to the is_x86_feature_detected check above
let result = unsafe { self.data_state_sse2_fast_path(&mut front_buffer) };
// This CPU is guaranteed to support SIMD due to the is_supported_simd_feature_detected check above
let result = unsafe { self.data_state_simd_fast_path(&mut front_buffer) };

if front_buffer.is_empty() {
drop(front_buffer);
Expand All @@ -743,7 +743,11 @@ impl<Sink: TokenSink> Tokenizer<Sink> {
self.pop_except_from(input, set)
};

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg(not(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64"
)))]
let set_result = self.pop_except_from(input, set);

let Some(set_result) = set_result else {
Expand Down Expand Up @@ -1885,18 +1889,90 @@ impl<Sink: TokenSink> Tokenizer<Sink> {
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
/// Reports whether the SIMD extension used by the tokenizer fast path is available at runtime.
///
/// The extension that is probed depends on the compilation target: SSE2 on x86/x86_64 and
/// NEON on aarch64. No fast path exists for any other architecture, so `false` is returned there.
fn is_supported_simd_feature_detected() -> bool {
    // Exactly one of the three bindings below survives conditional compilation.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let simd_available = is_x86_feature_detected!("sse2");

    #[cfg(target_arch = "aarch64")]
    let simd_available = std::arch::is_aarch64_feature_detected!("neon");

    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
    let simd_available = false;

    simd_available
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
/// Implements the [data state] with SIMD instructions.
/// Calls the SSE2- or NEON-specific function to scan whole chunks, then processes any remaining bytes.
///
/// The algorithm implemented is the naive SIMD approach described [here].
///
/// Always returns `Some`. If `input` starts with `<`, `&`, `\r` or `\0`, that character is popped
/// and returned (after preprocessing) as `SetResult::FromSet`. Otherwise the leading run of bytes
/// in front of the first such character (or all of `input` if there is none) is removed from
/// `input` and returned as `SetResult::NotFromSet`. The number of `\n` bytes in the consumed run
/// is added to `self.current_line`.
///
/// `input` is expected to be non-empty: for empty input the `i == 0` branch unwraps the result of
/// `pop_front_char` (NOTE(review): presumably `None` for an empty tendril — confirm).
///
/// ### SAFETY:
/// Calling this function on a CPU that does not support SSE2 causes undefined behaviour.
/// Calling this function on a CPU that supports neither SSE2 nor NEON causes undefined behaviour.
///
/// [data state]: https://html.spec.whatwg.org/#data-state
/// [here]: https://lemire.me/blog/2024/06/08/scan-html-faster-with-simd-instructions-chrome-edition/
unsafe fn data_state_sse2_fast_path(&self, input: &mut StrTendril) -> Option<SetResult> {
unsafe fn data_state_simd_fast_path(&self, input: &mut StrTendril) -> Option<SetResult> {
// Scan whole chunks with the architecture-specific routine. `i` is the number of bytes
// consumed so far and `n_newlines` the number of `\n` bytes among them.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
let (mut i, mut n_newlines) = self.data_state_sse2_fast_path(input);

#[cfg(target_arch = "aarch64")]
let (mut i, mut n_newlines) = self.data_state_neon_fast_path(input);

// Process any remaining bytes (less than STRIDE)
// If the SIMD routine already stopped on a character from the set, this loop breaks immediately.
while let Some(c) = input.as_bytes().get(i) {
if matches!(*c, b'<' | b'&' | b'\r' | b'\0') {
break;
}
if *c == b'\n' {
n_newlines += 1;
}

i += 1;
}

// `i == 0` means the very first byte is in the set, so that single character is emitted.
let set_result = if i == 0 {
let first_char = input.pop_front_char().unwrap();
debug_assert!(matches!(first_char, '<' | '&' | '\r' | '\0'));

// FIXME: Passing a bogus input queue is only relevant when c is \n, which can never happen in this case.
// Still, it would be nice to not have to do that.
// The same is true for the unwrap call.
let preprocessed_char = self
.get_preprocessed_char(first_char, &BufferQueue::default())
.unwrap();
SetResult::FromSet(preprocessed_char)
} else {
debug_assert!(
input.len() >= i,
"Trying to remove {:?} bytes from a tendril that is only {:?} bytes long",
i,
input.len()
);
// `i` is either the index of an ASCII stop byte or the end of `input`, so it never
// falls inside a multi-byte UTF-8 sequence.
// NOTE(review): `unsafe_subtendril`/`unsafe_pop_front` presumably skip the bounds and
// boundary checks that the `debug_assert!` above stands in for — confirm against tendril.
let consumed_chunk = input.unsafe_subtendril(0, i as u32);
input.unsafe_pop_front(i as u32);
SetResult::NotFromSet(consumed_chunk)
};

// Keep the tokenizer's line counter in sync with the newlines that were skipped over.
self.current_line.set(self.current_line.get() + n_newlines);

Some(set_result)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
/// Implements the [data state] with SSE2 instructions for x86/x86_64.
/// Returns a pair of the number of bytes processed and the number of newlines found.
///
/// ### SAFETY:
/// Calling this function on a CPU that does not support SSE2 causes undefined behaviour.
///
/// [data state]: https://html.spec.whatwg.org/#data-state
unsafe fn data_state_sse2_fast_path(&self, input: &mut StrTendril) -> (usize, u64) {
#[cfg(target_arch = "x86")]
use std::arch::x86::{
__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
Expand Down Expand Up @@ -1960,44 +2036,78 @@ impl<Sink: TokenSink> Tokenizer<Sink> {
i += STRIDE;
}

// Process any remaining bytes (less than STRIDE)
while let Some(c) = raw_bytes.get(i) {
if matches!(*c, b'<' | b'&' | b'\r' | b'\0') {
break;
}
if *c == b'\n' {
n_newlines += 1;
}
(i, n_newlines)
}

i += 1;
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// Implements the [data state] with NEON SIMD instructions for AArch64.
/// Returns a pair of the number of bytes processed and the number of newlines found.
///
/// Only whole 16-byte chunks are scanned. The returned byte count is either the index of the
/// first `<`, `&`, `\r` or `\0` that was found, or the end of the last whole chunk if none was
/// found; a shorter tail is left for the caller to process. Only newlines in front of the
/// returned index are counted.
///
/// ### SAFETY:
/// Calling this function on a CPU that does not support NEON causes undefined behaviour.
///
/// [data state]: https://html.spec.whatwg.org/#data-state
unsafe fn data_state_neon_fast_path(&self, input: &mut StrTendril) -> (usize, u64) {
use std::arch::aarch64::{vceqq_u8, vdupq_n_u8, vld1q_u8, vmaxvq_u8, vorrq_u8};

let set_result = if i == 0 {
let first_char = input.pop_front_char().unwrap();
debug_assert!(matches!(first_char, '<' | '&' | '\r' | '\0'));
debug_assert!(!input.is_empty());

// FIXME: Passing a bogus input queue is only relevant when c is \n, which can never happen in this case.
// Still, it would be nice to not have to do that.
// The same is true for the unwrap call.
let preprocessed_char = self
.get_preprocessed_char(first_char, &BufferQueue::default())
.unwrap();
SetResult::FromSet(preprocessed_char)
} else {
debug_assert!(
input.len() >= i,
"Trying to remove {:?} bytes from a tendril that is only {:?} bytes long",
i,
input.len()
);
let consumed_chunk = input.unsafe_subtendril(0, i as u32);
input.unsafe_pop_front(i as u32);
SetResult::NotFromSet(consumed_chunk)
};
// One broadcast vector per byte of interest.
// Despite their names, `quote_mask` matches `<` and `escape_mask` matches `&`.
let quote_mask = vdupq_n_u8(b'<');
let escape_mask = vdupq_n_u8(b'&');
let carriage_return_mask = vdupq_n_u8(b'\r');
let zero_mask = vdupq_n_u8(b'\0');
let newline_mask = vdupq_n_u8(b'\n');

self.current_line.set(self.current_line.get() + n_newlines);
let raw_bytes: &[u8] = input.as_bytes();
let start = raw_bytes.as_ptr();

Some(set_result)
const STRIDE: usize = 16;
let mut i = 0;
let mut n_newlines = 0;
while i + STRIDE <= raw_bytes.len() {
// Load a 16 byte chunk from the input
// SAFETY: the loop condition guarantees `i + STRIDE <= raw_bytes.len()`, so all 16 bytes
// read here lie inside `raw_bytes`.
let data = vld1q_u8(start.add(i));

// Compare the chunk against each mask
// (`vceqq_u8` sets a lane to 0xFF where the bytes are equal and to 0 otherwise.)
let quotes = vceqq_u8(data, quote_mask);
let escapes = vceqq_u8(data, escape_mask);
let carriage_returns = vceqq_u8(data, carriage_return_mask);
let zeros = vceqq_u8(data, zero_mask);
let newlines = vceqq_u8(data, newline_mask);

// Combine all test results and reduce them to a single byte.
// `vmaxvq_u8` returns the maximum over all lanes, so `bitmask` is non-zero exactly when at
// least one byte of the chunk is in the set. Unlike an SSE2 movemask it carries no per-byte
// position information, which is why the position is searched for separately below.
let test_result =
vorrq_u8(vorrq_u8(quotes, zeros), vorrq_u8(escapes, carriage_returns));
let bitmask = vmaxvq_u8(test_result);
// Shadows the vector `newline_mask` for the rest of this iteration:
// non-zero exactly when the chunk contains at least one `\n`.
let newline_mask = vmaxvq_u8(newlines);
if bitmask != 0 {
// We have reached one of the characters that cause the state machine to transition
// SAFETY: same bound as the load above, these 16 bytes lie inside `raw_bytes`.
let chunk_bytes = std::slice::from_raw_parts(start.add(i), STRIDE);
// The `unwrap` cannot fail: `bitmask != 0` means at least one byte of this chunk matched.
let position = chunk_bytes
.iter()
.position(|&b| matches!(b, b'<' | b'&' | b'\r' | b'\0'))
.unwrap();

// Only newlines in front of the stop character belong to the consumed prefix.
n_newlines += chunk_bytes[..position]
.iter()
.filter(|&&b| b == b'\n')
.count() as u64;

i += position;
break;
} else {
// No stop character in this chunk: the whole chunk is consumed, so count all of its newlines.
if newline_mask != 0 {
// SAFETY: same bound as the load above, these 16 bytes lie inside `raw_bytes`.
let chunk_bytes = std::slice::from_raw_parts(start.add(i), STRIDE);
n_newlines += chunk_bytes.iter().filter(|&&b| b == b'\n').count() as u64;
}
}

i += STRIDE;
}

(i, n_newlines)
}
}

Expand Down