Files
zterm/src/simd_utf8.rs
T
2025-12-22 00:22:55 +01:00

1477 lines
56 KiB
Rust

//! SIMD-optimized string operations based on Kitty's implementation.
//!
//! This module provides high-performance SIMD-accelerated operations:
//! - UTF-8 decoder (16-byte SSE chunks; a 32-byte AVX2 decoder is not implemented yet)
//! - Byte search functions (find_either_of_two_bytes, find_c0_control)
//! - XOR masking for WebSocket frames
//!
//! The UTF-8 algorithm is based on the blog post:
//! https://woboq.com/blog/utf-8-processing-using-simd.html
//! and Kitty's implementation in simd-string-impl.h
// Allow unsafe operations within unsafe functions without additional blocks.
// This code is ported from C and follows the same patterns as Kitty's implementation.
#![allow(unsafe_op_in_unsafe_fn)]
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
// ============================================================================
// SIMD Feature Detection and Dispatch
// ============================================================================
/// Cached SIMD capability flags for runtime dispatch.
///
/// Each flag records whether the corresponding instruction-set extension was
/// reported as available when [`SimdCapabilities::detect`] ran. On targets
/// other than x86/x86_64 every flag is `false`.
#[derive(Debug, Clone, Copy)]
pub struct SimdCapabilities {
    /// `true` when SSE4.1 was detected.
    pub has_sse41: bool,
    /// `true` when SSSE3 was detected.
    pub has_ssse3: bool,
    /// `true` when AVX2 was detected.
    pub has_avx2: bool,
}

impl SimdCapabilities {
    /// Detect SIMD capabilities at runtime.
    ///
    /// Queries the CPU through `is_x86_feature_detected!` for each extension
    /// the dispatchers in this module care about.
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub fn detect() -> Self {
        Self {
            has_sse41: is_x86_feature_detected!("sse4.1"),
            has_ssse3: is_x86_feature_detected!("ssse3"),
            has_avx2: is_x86_feature_detected!("avx2"),
        }
    }

    /// Detect SIMD capabilities at runtime.
    ///
    /// None of the x86 extensions exist on this target, so every flag is
    /// `false` and callers take their scalar fallbacks.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
    pub fn detect() -> Self {
        Self {
            has_sse41: false,
            has_ssse3: false,
            has_avx2: false,
        }
    }
}
// Global cached capabilities (initialized on first use).
// The `OnceLock` starts empty; `simd_caps` fills it with the result of
// `SimdCapabilities::detect` the first time it is called.
static SIMD_CAPS: std::sync::OnceLock<SimdCapabilities> = std::sync::OnceLock::new();
/// Get cached SIMD capabilities.
///
/// The first call runs [`SimdCapabilities::detect`] and stores the result in
/// `SIMD_CAPS`; every later call returns a reference to that stored value
/// without probing the CPU again.
pub fn simd_caps() -> &'static SimdCapabilities {
SIMD_CAPS.get_or_init(SimdCapabilities::detect)
}
// ============================================================================
// Byte Search Functions (like Kitty's find_either_of_two_bytes)
// ============================================================================
/// Locate the first byte equal to `needle` in `haystack`.
///
/// Returns the index of the first match, or `None` when `needle` does not
/// occur. On x86/x86_64 the search is routed to the AVX2 or SSE kernel
/// according to the cached CPU capabilities; everywhere else (or when neither
/// extension is available) a plain iterator scan is used.
///
/// This is a SIMD-accelerated alternative to `memchr::memchr`.
#[inline]
pub fn find_byte(haystack: &[u8], needle: u8) -> Option<usize> {
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        let caps = simd_caps();
        match (caps.has_avx2, caps.has_sse41) {
            // SAFETY: AVX2 support was confirmed at runtime.
            (true, _) => return unsafe { find_byte_avx2(haystack, needle) },
            // SAFETY: SSE4.1 support was confirmed at runtime.
            (false, true) => return unsafe { find_byte_sse(haystack, needle) },
            (false, false) => {}
        }
    }
    haystack.iter().position(|&b| b == needle)
}
/// SSE kernel behind `find_byte`.
///
/// Scans `haystack` in 16-byte blocks: each block is compared against a
/// broadcast of `needle`, and the per-byte comparison mask tells us the index
/// of the first hit. The final partial block (fewer than 16 bytes) is scanned
/// one byte at a time.
///
/// # Safety
///
/// The CPU must support SSE2 and SSE4.1.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "sse4.1")]
unsafe fn find_byte_sse(haystack: &[u8], needle: u8) -> Option<usize> {
    let pattern = _mm_set1_epi8(needle as i8);
    let mut blocks = haystack.chunks_exact(16);
    for (block_no, block) in blocks.by_ref().enumerate() {
        // Unaligned load: the slice gives no alignment guarantee.
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        let hits = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, pattern)) as u32;
        if hits != 0 {
            // Bit i of `hits` corresponds to byte i of the block.
            return Some(block_no * 16 + hits.trailing_zeros() as usize);
        }
    }
    let tail = blocks.remainder();
    let tail_start = haystack.len() - tail.len();
    tail.iter()
        .position(|&byte| byte == needle)
        .map(|idx| tail_start + idx)
}
/// AVX2 kernel behind `find_byte`.
///
/// Works through `haystack` in three stages: 32-byte blocks with 256-bit
/// compares, then (if at least 16 bytes are left) one 16-byte block with a
/// 128-bit compare, then a byte-by-byte scan of whatever remains.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn find_byte_avx2(haystack: &[u8], needle: u8) -> Option<usize> {
    let total = haystack.len();
    let base = haystack.as_ptr();
    let mut cursor = 0usize;

    let wide_pattern = _mm256_set1_epi8(needle as i8);
    while total - cursor >= 32 {
        let lanes = _mm256_loadu_si256(base.add(cursor) as *const __m256i);
        let hits = _mm256_movemask_epi8(_mm256_cmpeq_epi8(lanes, wide_pattern)) as u32;
        if hits != 0 {
            return Some(cursor + hits.trailing_zeros() as usize);
        }
        cursor += 32;
    }

    // Fewer than 32 bytes remain here, so at most one 16-byte block fits.
    if total - cursor >= 16 {
        let narrow_pattern = _mm_set1_epi8(needle as i8);
        let lanes = _mm_loadu_si128(base.add(cursor) as *const __m128i);
        let hits = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, narrow_pattern)) as u32;
        if hits != 0 {
            return Some(cursor + hits.trailing_zeros() as usize);
        }
        cursor += 16;
    }

    haystack[cursor..]
        .iter()
        .position(|&byte| byte == needle)
        .map(|idx| cursor + idx)
}
/// Locate the first byte in `haystack` that equals `a` or `b`.
///
/// Returns the index of the first match, or `None` when neither byte occurs.
/// On x86/x86_64 the AVX2 kernel is preferred, then the SSE kernel; otherwise
/// the scalar implementation is used.
///
/// This is equivalent to Kitty's `find_either_of_two_bytes` function.
#[inline]
pub fn find_either_of_two_bytes(haystack: &[u8], a: u8, b: u8) -> Option<usize> {
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        let caps = simd_caps();
        match (caps.has_avx2, caps.has_sse41) {
            // SAFETY: AVX2 support was confirmed at runtime.
            (true, _) => return unsafe { find_either_of_two_bytes_avx2(haystack, a, b) },
            // SAFETY: SSE4.1 support was confirmed at runtime.
            (false, true) => return unsafe { find_either_of_two_bytes_sse(haystack, a, b) },
            (false, false) => {}
        }
    }
    find_either_of_two_bytes_scalar(haystack, a, b)
}
/// Portable implementation of `find_either_of_two_bytes`.
///
/// Walks `haystack` front to back and reports the index of the first byte
/// equal to `a` or `b`, or `None` if there is no such byte.
#[inline]
fn find_either_of_two_bytes_scalar(haystack: &[u8], a: u8, b: u8) -> Option<usize> {
    for (idx, &byte) in haystack.iter().enumerate() {
        if byte == a || byte == b {
            return Some(idx);
        }
    }
    None
}
/// SSE kernel behind `find_either_of_two_bytes`.
///
/// Each 16-byte block is compared against broadcasts of `a` and `b`; the two
/// comparison results are OR-ed so a single mask reports a hit for either
/// byte. The last partial block is scanned one byte at a time.
///
/// # Safety
///
/// The CPU must support SSE2 and SSE4.1.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "sse4.1")]
unsafe fn find_either_of_two_bytes_sse(haystack: &[u8], a: u8, b: u8) -> Option<usize> {
    let first = _mm_set1_epi8(a as i8);
    let second = _mm_set1_epi8(b as i8);
    let mut blocks = haystack.chunks_exact(16);
    for (block_no, block) in blocks.by_ref().enumerate() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        let either = _mm_or_si128(_mm_cmpeq_epi8(lanes, first), _mm_cmpeq_epi8(lanes, second));
        let hits = _mm_movemask_epi8(either) as u32;
        if hits != 0 {
            return Some(block_no * 16 + hits.trailing_zeros() as usize);
        }
    }
    let tail = blocks.remainder();
    let tail_start = haystack.len() - tail.len();
    tail.iter()
        .position(|&byte| byte == a || byte == b)
        .map(|idx| tail_start + idx)
}
/// AVX2 kernel behind `find_either_of_two_bytes`.
///
/// Processes 32-byte blocks with 256-bit compares, then at most one 16-byte
/// block with 128-bit compares, then the remaining bytes individually. In
/// every stage a byte matches when it equals `a` or `b`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn find_either_of_two_bytes_avx2(haystack: &[u8], a: u8, b: u8) -> Option<usize> {
    let total = haystack.len();
    let base = haystack.as_ptr();
    let mut cursor = 0usize;

    let wide_a = _mm256_set1_epi8(a as i8);
    let wide_b = _mm256_set1_epi8(b as i8);
    while total - cursor >= 32 {
        let lanes = _mm256_loadu_si256(base.add(cursor) as *const __m256i);
        let either = _mm256_or_si256(
            _mm256_cmpeq_epi8(lanes, wide_a),
            _mm256_cmpeq_epi8(lanes, wide_b),
        );
        let hits = _mm256_movemask_epi8(either) as u32;
        if hits != 0 {
            return Some(cursor + hits.trailing_zeros() as usize);
        }
        cursor += 32;
    }

    // Fewer than 32 bytes remain here, so at most one 16-byte block fits.
    if total - cursor >= 16 {
        let narrow_a = _mm_set1_epi8(a as i8);
        let narrow_b = _mm_set1_epi8(b as i8);
        let lanes = _mm_loadu_si128(base.add(cursor) as *const __m128i);
        let either = _mm_or_si128(_mm_cmpeq_epi8(lanes, narrow_a), _mm_cmpeq_epi8(lanes, narrow_b));
        let hits = _mm_movemask_epi8(either) as u32;
        if hits != 0 {
            return Some(cursor + hits.trailing_zeros() as usize);
        }
        cursor += 16;
    }

    haystack[cursor..]
        .iter()
        .position(|&byte| byte == a || byte == b)
        .map(|idx| cursor + idx)
}
// ============================================================================
// C0 Control Character Detection (like Kitty's IndexC0)
// ============================================================================
/// Locate the first C0 control character (byte < 0x20 or byte == 0x7F).
///
/// Returns the index of the first such byte, or `None` when the haystack
/// contains none. On x86/x86_64 the AVX2 kernel is preferred, then the SSE
/// kernel; otherwise the scalar implementation is used.
///
/// This is equivalent to Kitty's `IndexC0` function in the simdstring package.
#[inline]
pub fn find_c0_control(haystack: &[u8]) -> Option<usize> {
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        let caps = simd_caps();
        match (caps.has_avx2, caps.has_sse41) {
            // SAFETY: AVX2 support was confirmed at runtime.
            (true, _) => return unsafe { find_c0_control_avx2(haystack) },
            // SAFETY: SSE4.1 support was confirmed at runtime.
            (false, true) => return unsafe { find_c0_control_sse(haystack) },
            (false, false) => {}
        }
    }
    find_c0_control_scalar(haystack)
}
/// Portable implementation of `find_c0_control`.
///
/// A byte counts as a C0 control when it lies in `0x00..=0x1F` or is DEL
/// (`0x7F`). Returns the index of the first such byte, if any.
#[inline]
fn find_c0_control_scalar(haystack: &[u8]) -> Option<usize> {
    haystack
        .iter()
        .position(|byte| matches!(*byte, 0x00..=0x1F | 0x7F))
}
/// SSE kernel behind `find_c0_control`.
///
/// SSE only offers *signed* byte comparisons, so the unsigned test
/// `byte < 0x20` is performed after adding 0x80 to every byte (wrapping):
/// that maps 0x00..=0xFF onto -128..=127 while preserving order, and turns the
/// threshold 0x20 into -96. DEL (0x7F) is detected with a plain equality test
/// on the unmodified bytes. The last partial block is scanned byte by byte.
///
/// # Safety
///
/// The CPU must support SSE2 and SSE4.1.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "sse4.1")]
unsafe fn find_c0_control_sse(haystack: &[u8]) -> Option<usize> {
    let sign_flip = _mm_set1_epi8(-128i8); // 0x80
    let space_flipped = _mm_set1_epi8(-96i8); // 0x20 + 0x80, read as a signed byte
    let delete = _mm_set1_epi8(0x7F);

    let mut blocks = haystack.chunks_exact(16);
    for (block_no, block) in blocks.by_ref().enumerate() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        // flipped < space_flipped  <=>  original byte < 0x20 (unsigned)
        let flipped = _mm_add_epi8(lanes, sign_flip);
        let below_space = _mm_cmpgt_epi8(space_flipped, flipped);
        let is_delete = _mm_cmpeq_epi8(lanes, delete);
        let hits = _mm_movemask_epi8(_mm_or_si128(below_space, is_delete)) as u32;
        if hits != 0 {
            return Some(block_no * 16 + hits.trailing_zeros() as usize);
        }
    }

    let tail = blocks.remainder();
    let tail_start = haystack.len() - tail.len();
    tail.iter()
        .position(|&byte| byte < 0x20 || byte == 0x7F)
        .map(|idx| tail_start + idx)
}
/// AVX2 kernel behind `find_c0_control`.
///
/// Uses the same trick as the SSE kernel (add 0x80, then compare signed
/// against -96) to test `byte < 0x20`, plus an equality test for DEL (0x7F).
/// Input is consumed as 32-byte blocks, then at most one 16-byte block, then
/// single bytes.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn find_c0_control_avx2(haystack: &[u8]) -> Option<usize> {
    let total = haystack.len();
    let base = haystack.as_ptr();
    let mut cursor = 0usize;

    let wide_flip = _mm256_set1_epi8(-128i8); // 0x80
    let wide_space = _mm256_set1_epi8(-96i8); // 0x20 + 0x80, read as a signed byte
    let wide_delete = _mm256_set1_epi8(0x7F);
    while total - cursor >= 32 {
        let lanes = _mm256_loadu_si256(base.add(cursor) as *const __m256i);
        let below_space = _mm256_cmpgt_epi8(wide_space, _mm256_add_epi8(lanes, wide_flip));
        let is_delete = _mm256_cmpeq_epi8(lanes, wide_delete);
        let hits = _mm256_movemask_epi8(_mm256_or_si256(below_space, is_delete)) as u32;
        if hits != 0 {
            return Some(cursor + hits.trailing_zeros() as usize);
        }
        cursor += 32;
    }

    // Fewer than 32 bytes remain here, so at most one 16-byte block fits.
    if total - cursor >= 16 {
        let flip = _mm_set1_epi8(-128i8);
        let space = _mm_set1_epi8(-96i8);
        let delete = _mm_set1_epi8(0x7F);
        let lanes = _mm_loadu_si128(base.add(cursor) as *const __m128i);
        let below_space = _mm_cmpgt_epi8(space, _mm_add_epi8(lanes, flip));
        let is_delete = _mm_cmpeq_epi8(lanes, delete);
        let hits = _mm_movemask_epi8(_mm_or_si128(below_space, is_delete)) as u32;
        if hits != 0 {
            return Some(cursor + hits.trailing_zeros() as usize);
        }
        cursor += 16;
    }

    haystack[cursor..]
        .iter()
        .position(|&byte| byte < 0x20 || byte == 0x7F)
        .map(|idx| cursor + idx)
}
// ============================================================================
// XOR Data for WebSocket Masking (like Kitty's xor_data64)
// ============================================================================
/// XOR `data` in place with a repeating 4-byte mask (WebSocket frame masking).
///
/// `start_offset` selects which mask byte is applied to `data[0]`; the mask
/// then repeats cyclically. The return value is the mask position that
/// follows the last processed byte (always in `0..4`), ready to be passed as
/// `start_offset` for the next piece of the same frame.
///
/// Buffers of at least 32 bytes use the AVX2 kernel and buffers of at least
/// 16 bytes use the SSE kernel when the CPU supports them; everything else is
/// handled by the scalar loop.
///
/// This is equivalent to Kitty's `xor_data64` function but optimized for
/// the standard 4-byte WebSocket mask.
#[inline]
pub fn xor_mask(data: &mut [u8], mask: [u8; 4], start_offset: usize) -> usize {
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        let caps = simd_caps();
        let len = data.len();
        if len >= 32 && caps.has_avx2 {
            // SAFETY: AVX2 support was confirmed at runtime.
            return unsafe { xor_mask_avx2(data, mask, start_offset) };
        }
        if len >= 16 && caps.has_sse41 {
            // SAFETY: SSE4.1 (and therefore SSE2) support was confirmed at runtime.
            return unsafe { xor_mask_sse(data, mask, start_offset) };
        }
    }
    xor_mask_scalar(data, mask, start_offset)
}
/// Scalar fallback for xor_mask.
///
/// XORs every byte of `data` with the cyclic 4-byte `mask`, where `data[0]`
/// is paired with `mask[start_offset % 4]`. Returns the mask position that
/// follows the last byte, i.e. `(start_offset + data.len()) % 4`.
///
/// Only `start_offset % 4` matters, so the offset is reduced up front. That
/// keeps all later arithmetic far away from `usize::MAX` and avoids the
/// overflow panic that incrementing an arbitrary caller-supplied offset could
/// trigger in builds with overflow checks.
#[inline]
fn xor_mask_scalar(data: &mut [u8], mask: [u8; 4], start_offset: usize) -> usize {
    let phase = start_offset & 3;
    // `cycle().skip(phase)` yields mask[phase], mask[phase + 1], ... wrapping around.
    for (byte, key) in data.iter_mut().zip(mask.iter().cycle().skip(phase)) {
        *byte ^= *key;
    }
    // `phase` is at most 3 and a slice length never exceeds isize::MAX: no overflow.
    (phase + data.len()) & 3
}
/// SSE implementation of xor_mask.
///
/// XORs `data` in place with the cyclic 4-byte `mask`, pairing `data[0]` with
/// `mask[start_offset % 4]`, and returns `(start_offset + data.len()) % 4`.
///
/// Instead of masking a scalar prefix until the mask lines up, the mask is
/// rotated once so that its first byte is the one that belongs to `data[0]`.
/// The rotated key can then be applied from the very first byte: 16 bytes at
/// a time with a broadcast vector, and byte by byte for the final partial
/// block (which starts at a multiple of 16, hence at key position 0).
///
/// Only `start_offset % 4` is used, so arbitrarily large offsets cannot
/// overflow the internal arithmetic.
///
/// # Safety
///
/// The CPU must support SSE2.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn xor_mask_sse(data: &mut [u8], mask: [u8; 4], start_offset: usize) -> usize {
    let phase = start_offset & 3;
    let len = data.len();
    // key[i] is the mask byte that applies to data[i] for i in 0..4.
    let key = [
        mask[phase],
        mask[(phase + 1) & 3],
        mask[(phase + 2) & 3],
        mask[(phase + 3) & 3],
    ];
    // x86 is little-endian, so broadcasting this 32-bit word lays the key
    // bytes out in memory order, repeated four times.
    let key_vec = _mm_set1_epi32(i32::from_le_bytes(key));

    let mut blocks = data.chunks_exact_mut(16);
    for block in blocks.by_ref() {
        let lanes = block.as_mut_ptr() as *mut __m128i;
        let masked = _mm_xor_si128(_mm_loadu_si128(lanes as *const __m128i), key_vec);
        _mm_storeu_si128(lanes, masked);
    }
    for (byte, k) in blocks.into_remainder().iter_mut().zip(key.iter().cycle()) {
        *byte ^= *k;
    }

    // `phase` is at most 3 and a slice length never exceeds isize::MAX: no overflow.
    (phase + len) & 3
}
/// AVX2 implementation of xor_mask.
///
/// XORs `data` in place with the cyclic 4-byte `mask`, pairing `data[0]` with
/// `mask[start_offset % 4]`, and returns `(start_offset + data.len()) % 4`.
///
/// The mask is rotated once so that its first byte belongs to `data[0]`; the
/// rotated key is then applied in three stages that all start at a multiple
/// of four bytes: 32-byte blocks (256-bit XOR), at most one 16-byte block
/// (128-bit XOR), and finally the remaining bytes one at a time. Both key
/// vectors are built once, outside the loops.
///
/// Only `start_offset % 4` is used, so arbitrarily large offsets cannot
/// overflow the internal arithmetic.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn xor_mask_avx2(data: &mut [u8], mask: [u8; 4], start_offset: usize) -> usize {
    let phase = start_offset & 3;
    let len = data.len();
    // key[i] is the mask byte that applies to data[i] for i in 0..4.
    let key = [
        mask[phase],
        mask[(phase + 1) & 3],
        mask[(phase + 2) & 3],
        mask[(phase + 3) & 3],
    ];
    // x86 is little-endian, so broadcasting this 32-bit word lays the key
    // bytes out in memory order.
    let key_word = i32::from_le_bytes(key);
    let key_256 = _mm256_set1_epi32(key_word);
    let key_128 = _mm_set1_epi32(key_word);

    let mut wide_blocks = data.chunks_exact_mut(32);
    for block in wide_blocks.by_ref() {
        let lanes = block.as_mut_ptr() as *mut __m256i;
        let masked = _mm256_xor_si256(_mm256_loadu_si256(lanes as *const __m256i), key_256);
        _mm256_storeu_si256(lanes, masked);
    }

    // Fewer than 32 bytes remain, so this yields at most one 16-byte block.
    let mut narrow_blocks = wide_blocks.into_remainder().chunks_exact_mut(16);
    for block in narrow_blocks.by_ref() {
        let lanes = block.as_mut_ptr() as *mut __m128i;
        let masked = _mm_xor_si128(_mm_loadu_si128(lanes as *const __m128i), key_128);
        _mm_storeu_si128(lanes, masked);
    }

    for (byte, k) in narrow_blocks.into_remainder().iter_mut().zip(key.iter().cycle()) {
        *byte ^= *k;
    }

    // `phase` is at most 3 and a slice length never exceeds isize::MAX: no overflow.
    (phase + len) & 3
}
// ============================================================================
// UTF-8 Decoder State and Tables
// ============================================================================
/// UTF-8 decoder state for handling sequences that span chunks.
///
/// The derived `Default` (all fields zero) is the "no sequence in progress"
/// state: `cur == 0` equals `UTF8_ACCEPT`.
#[derive(Debug, Default, Clone)]
pub struct Utf8State {
/// Current DFA state, updated by `decode_utf8_byte` for every byte fed in.
pub cur: u8,
/// DFA state before the most recent byte; the scalar decoder uses it to tell
/// whether a rejected byte was a lead byte or interrupted a sequence.
pub prev: u8,
/// Code point accumulated so far; complete once `cur` returns to `UTF8_ACCEPT`.
pub codep: u32,
}
/// DFA state reached when a complete, valid sequence has been decoded.
/// It is also the state in which a new sequence starts.
const UTF8_ACCEPT: u8 = 0;
/// DFA state reached when the bytes seen so far cannot be valid UTF-8.
const UTF8_REJECT: u8 = 12;
/// UTF-8 state transition table (Bjoern Hoehrmann's DFA).
///
/// Layout, as used by `decode_utf8_byte`:
/// - entries `0..256` map an input byte to its character class;
/// - entries `256..364` hold the transitions, looked up at
///   `256 + state + class`. States are stored pre-multiplied (0, 12, 24, ...)
///   so that `state + class` indexes a row of 12 classes directly.
static UTF8_DECODE_TABLE: [u8; 364] = [
// Character class lookup (0-255)
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8,
// State transition table
0,12,24,36,60,96,84,12,12,12,48,72, 12,12,12,12,12,12,12,12,12,12,12,12,
12, 0,12,12,12,12,12, 0,12, 0,12,12, 12,24,12,12,12,12,12,24,12,24,12,12,
12,12,12,12,12,12,12,24,12,12,12,12, 12,24,12,12,12,12,12,12,12,24,12,12,
12,12,12,12,12,12,12,36,12,36,12,12, 12,36,12,12,12,12,12,36,12,36,12,12,
12,36,12,12,12,12,12,12,12,12,12,12,
];
/// Feed one byte into the UTF-8 DFA.
///
/// `state` and `codep` are updated in place. When a sequence starts (`state`
/// is `UTF8_ACCEPT` on entry) the payload bits of the lead byte are extracted
/// with a class-dependent mask; otherwise the low six bits of the byte are
/// appended to the code point accumulated so far. The new state is written
/// back to `state` and also returned.
#[inline(always)]
fn decode_utf8_byte(state: &mut u8, codep: &mut u32, byte: u8) -> u8 {
    let class = UTF8_DECODE_TABLE[usize::from(byte)];
    let bits = u32::from(byte);
    *codep = match *state {
        UTF8_ACCEPT => (0xFFu32 >> class) & bits,
        _ => (*codep << 6) | (bits & 0x3F),
    };
    let next = UTF8_DECODE_TABLE[256 + usize::from(*state) + usize::from(class)];
    *state = next;
    next
}
/// SIMD UTF-8 decoder.
///
/// Decodes bytes into `u32` code points up to the next ESC (0x1B). When the
/// SIMD path is taken, input is processed in 16-byte SSE chunks, using SIMD for:
/// - Fast ESC (0x1B) detection
/// - Pure ASCII fast path
/// - Parallel UTF-8 validation and decoding
///
/// A 32-byte AVX2 decoder is only a TODO in `decode_to_esc`; it is not
/// implemented in this type.
#[derive(Debug, Default)]
pub struct SimdUtf8Decoder {
/// DFA state carried across calls so that a multi-byte sequence split
/// between two inputs can be resumed.
pub state: Utf8State,
}
impl SimdUtf8Decoder {
pub fn new() -> Self {
Self::default()
}
pub fn reset(&mut self) {
self.state = Utf8State::default();
}
/// Decode UTF-8 bytes until ESC is found.
/// Returns (bytes_consumed, found_esc).
///
/// `output` is cleared first; decoded code points are then appended to it as
/// `u32` values. An empty `src` returns `(0, false)` immediately.
///
/// On x86/x86_64 with both SSE4.1 and SSSE3 available the 16-byte SSE decoder
/// is used; in every other case the scalar decoder runs.
#[inline]
pub fn decode_to_esc(&mut self, src: &[u8], output: &mut Vec<u32>) -> (usize, bool) {
    output.clear();
    let input_len = src.len();
    if input_len == 0 {
        return (0, false);
    }
    // One code point per input byte is the upper bound on the output size.
    output.reserve(input_len);

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        // TODO: AVX2 decoder would go here when implemented
        // For now, AVX2 is used for the byte search functions
        let caps = simd_caps();
        let sse_decoder_usable = caps.has_sse41 && caps.has_ssse3;
        if sse_decoder_usable {
            // SAFETY: We checked for required SIMD support
            return unsafe { self.decode_to_esc_simd(src, output) };
        }
    }

    // Fallback to scalar
    self.decode_to_esc_scalar(src, output)
}
/// Scalar fallback decoder.
fn decode_to_esc_scalar(&mut self, src: &[u8], output: &mut Vec<u32>) -> (usize, bool) {
let mut pos = 0;
while pos < src.len() {
let byte = src[pos];
if byte == 0x1B {
if self.state.cur != UTF8_ACCEPT {
output.push(0xFFFD);
self.state = Utf8State::default();
}
return (pos + 1, true);
}
pos += 1;
self.state.prev = self.state.cur;
match decode_utf8_byte(&mut self.state.cur, &mut self.state.codep, byte) {
UTF8_ACCEPT => {
output.push(self.state.codep);
}
UTF8_REJECT => {
output.push(0xFFFD);
let was_accept = self.state.prev == UTF8_ACCEPT;
self.state = Utf8State::default();
if !was_accept {
pos -= 1;
}
}
_ => {}
}
}
(pos, false)
}
/// SIMD decoder - processes 16 bytes at a time.
/// Based on Kitty's simd-string-impl.h
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "ssse3", enable = "sse4.1")]
unsafe fn decode_to_esc_simd(&mut self, src: &[u8], output: &mut Vec<u32>) -> (usize, bool) {
let mut num_consumed: usize = 0;
// Finish any trailing sequence from previous call
if self.state.cur != UTF8_ACCEPT {
num_consumed = self.scalar_decode_to_accept(src, output);
if num_consumed >= src.len() {
return (num_consumed, false);
}
}
// SIMD constants
let esc_vec = _mm_set1_epi8(0x1Bu8 as i8);
let zero = _mm_setzero_si128();
let one = _mm_set1_epi8(1);
let two = _mm_set1_epi8(2);
let three = _mm_set1_epi8(3);
let four = _mm_set1_epi8(4);
let numbered = _mm_set_epi8(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0);
let limit = src.as_ptr().add(src.len());
let mut p = src.as_ptr().add(num_consumed);
let mut sentinel_found = false;
while p < limit && !sentinel_found {
let remaining = limit.offset_from(p) as usize;
let mut chunk_src_sz = remaining.min(16);
// Load chunk (potentially partial)
let mut vec = if chunk_src_sz == 16 {
_mm_loadu_si128(p as *const __m128i)
} else {
// Partial load - zero-extend
let mut buf = [0u8; 16];
std::ptr::copy_nonoverlapping(p, buf.as_mut_ptr(), chunk_src_sz);
_mm_loadu_si128(buf.as_ptr() as *const __m128i)
};
let start_of_current_chunk = p;
p = p.add(chunk_src_sz);
// Check for ESC
let esc_cmp = _mm_cmpeq_epi8(vec, esc_vec);
let num_bytes_to_first_esc = Self::bytes_to_first_match(esc_cmp);
if num_bytes_to_first_esc >= 0 && (num_bytes_to_first_esc as usize) < chunk_src_sz {
sentinel_found = true;
chunk_src_sz = num_bytes_to_first_esc as usize;
num_consumed += chunk_src_sz + 1; // +1 for ESC
if chunk_src_sz == 0 {
continue;
}
} else {
num_consumed += chunk_src_sz;
}
// Zero out bytes past chunk_src_sz
if chunk_src_sz < 16 {
vec = Self::zero_last_n_bytes(vec, 16 - chunk_src_sz);
}
// Check for trailing incomplete sequence
let mut num_trailing_bytes = 0usize;
let mut check_for_trailing = !sentinel_found;
'classification: loop {
// Check if pure ASCII (no high bits set)
let ascii_mask = _mm_movemask_epi8(vec);
if ascii_mask == 0 {
// Pure ASCII - fast output
Self::output_plain_ascii(vec, chunk_src_sz, output);
// Handle trailing bytes
if num_trailing_bytes > 0 && p < limit {
p = p.sub(num_trailing_bytes);
}
break 'classification;
}
// Classify bytes by whether they start 2, 3, or 4 byte sequences
let state_80 = _mm_set1_epi8(0x80u8 as i8);
let vec_signed = _mm_add_epi8(vec, state_80);
// state now has 0x80 on all bytes
let mut state = state_80;
// 2-byte sequence starters (0xC0-0xDF, but 0xC0-0xC1 invalid)
let c2_start = _mm_cmplt_epi8(_mm_set1_epi8((0xC0 - 1 - 0x80) as i8), vec_signed);
state = _mm_blendv_epi8(state, _mm_set1_epi8(0xC2u8 as i8), c2_start);
// 3-byte sequence starters (0xE0-0xEF)
let e3_start = _mm_cmplt_epi8(_mm_set1_epi8((0xE0 - 1 - 0x80) as i8), vec_signed);
state = _mm_blendv_epi8(state, _mm_set1_epi8(0xE3u8 as i8), e3_start);
// 4-byte sequence starters (0xF0-0xFF, but 0xF5+ invalid)
let f4_start = _mm_cmplt_epi8(_mm_set1_epi8((0xF0 - 1 - 0x80) as i8), vec_signed);
state = _mm_blendv_epi8(state, _mm_set1_epi8(0xF4u8 as i8), f4_start);
// mask = upper 5 bits of state (indicates byte type)
let mask = _mm_and_si128(state, _mm_set1_epi8(0xF8u8 as i8));
// count = lower 3 bits of state (sequence length)
let count = _mm_and_si128(state, _mm_set1_epi8(0x07));
// Propagate counts: count[i] = remaining bytes in sequence at position i
// count_subs1[i] = count[i] - 1, saturating
let count_subs1 = _mm_subs_epu8(count, one);
// counts[i] = count[i] + count_subs1[i-1]
let mut counts = _mm_add_epi8(count, _mm_srli_si128(count_subs1, 1));
// counts[i] += counts_subs2[i-2] (for 3 and 4 byte sequences)
counts = _mm_add_epi8(counts, _mm_srli_si128(_mm_subs_epu8(counts, two), 2));
// Check for trailing incomplete sequence
if check_for_trailing {
let last_byte_idx = _mm_set1_epi8((chunk_src_sz - 1) as i8);
let at_last_byte = _mm_cmpeq_epi8(numbered, last_byte_idx);
let counts_at_last = _mm_and_si128(counts, at_last_byte);
let has_trailing = _mm_cmplt_epi8(one, counts_at_last);
if _mm_testz_si128(has_trailing, has_trailing) == 0 {
// We have a trailing incomplete sequence
check_for_trailing = false;
let last_byte = *start_of_current_chunk.add(chunk_src_sz - 1);
if last_byte >= 0xC0 {
num_trailing_bytes = 1;
} else if chunk_src_sz > 1 && *start_of_current_chunk.add(chunk_src_sz - 2) >= 0xE0 {
num_trailing_bytes = 2;
} else if chunk_src_sz > 2 && *start_of_current_chunk.add(chunk_src_sz - 3) >= 0xF0 {
num_trailing_bytes = 3;
}
chunk_src_sz -= num_trailing_bytes;
num_consumed -= num_trailing_bytes;
if chunk_src_sz == 0 {
// Fall back to scalar for trailing bytes
let slice = std::slice::from_raw_parts(
start_of_current_chunk,
num_trailing_bytes
);
self.scalar_decode_all(slice, output);
num_consumed += num_trailing_bytes;
break 'classification;
}
vec = Self::zero_last_n_bytes(vec, 16 - chunk_src_sz);
continue 'classification;
}
}
// Validation: ASCII bytes should have counts[i] == 0
let count_gt_zero = _mm_cmpgt_epi8(counts, zero);
let count_mask = _mm_movemask_epi8(count_gt_zero);
if ascii_mask != count_mask {
// Invalid UTF-8 - fall back to scalar
let slice = std::slice::from_raw_parts(
start_of_current_chunk,
chunk_src_sz + num_trailing_bytes
);
self.scalar_decode_all(slice, output);
num_consumed += num_trailing_bytes;
break 'classification;
}
// Build chunk_is_invalid vector
let mut chunk_invalid = zero;
// Validate 2-byte starters: 0xC0, 0xC1 are invalid
chunk_invalid = _mm_or_si128(chunk_invalid,
_mm_and_si128(c2_start, _mm_cmplt_epi8(vec, _mm_set1_epi8(0xC2u8 as i8))));
// Validate 4-byte starters: 0xF5+ are invalid
chunk_invalid = _mm_or_si128(chunk_invalid,
_mm_and_si128(f4_start, _mm_cmpgt_epi8(vec, _mm_set1_epi8(0xF4u8 as i8))));
// Validate continuation bytes don't have starter bytes
let cont_has_starter = _mm_andnot_si128(
_mm_cmplt_epi8(vec, _mm_set1_epi8(0xC0u8 as i8)),
_mm_cmpgt_epi8(counts, count)
);
chunk_invalid = _mm_or_si128(chunk_invalid, cont_has_starter);
// Validate E0 second bytes (must be >= 0xA0)
let e0_starters = _mm_cmpeq_epi8(vec, _mm_set1_epi8(0xE0u8 as i8));
let e0_followers = _mm_srli_si128(e0_starters, 1);
let e0_invalid = _mm_and_si128(e0_followers,
_mm_cmplt_epi8(_mm_and_si128(e0_followers, vec), _mm_set1_epi8(0xA0u8 as i8)));
chunk_invalid = _mm_or_si128(chunk_invalid, e0_invalid);
// Validate ED second bytes (must be < 0xA0, i.e. <= 0x9F)
let ed_starters = _mm_cmpeq_epi8(vec, _mm_set1_epi8(0xEDu8 as i8));
let ed_followers = _mm_srli_si128(ed_starters, 1);
let ed_invalid = _mm_and_si128(ed_followers,
_mm_cmpgt_epi8(_mm_and_si128(ed_followers, vec), _mm_set1_epi8(0x9Fu8 as i8)));
chunk_invalid = _mm_or_si128(chunk_invalid, ed_invalid);
// Validate F0 second bytes (must be >= 0x90)
let f0_starters = _mm_cmpeq_epi8(vec, _mm_set1_epi8(0xF0u8 as i8));
let f0_followers = _mm_srli_si128(f0_starters, 1);
let f0_invalid = _mm_and_si128(f0_followers,
_mm_cmplt_epi8(_mm_and_si128(f0_followers, vec), _mm_set1_epi8(0x90u8 as i8)));
chunk_invalid = _mm_or_si128(chunk_invalid, f0_invalid);
// Validate F4 second bytes (must be < 0x90, i.e. <= 0x8F)
let f4_starters = _mm_cmpeq_epi8(vec, _mm_set1_epi8(0xF4u8 as i8));
let f4_followers = _mm_srli_si128(f4_starters, 1);
let f4_invalid = _mm_and_si128(f4_followers,
_mm_cmpgt_epi8(_mm_and_si128(f4_followers, vec), _mm_set1_epi8(0x8Fu8 as i8)));
chunk_invalid = _mm_or_si128(chunk_invalid, f4_invalid);
// If invalid, fall back to scalar
if _mm_testz_si128(chunk_invalid, chunk_invalid) == 0 {
let slice = std::slice::from_raw_parts(
start_of_current_chunk,
chunk_src_sz + num_trailing_bytes
);
self.scalar_decode_all(slice, output);
num_consumed += num_trailing_bytes;
break 'classification;
}
// Mask control bits to get payload only
vec = _mm_andnot_si128(mask, vec);
// Build output vectors
let vec_non_ascii = _mm_andnot_si128(_mm_cmpeq_epi8(counts, zero), vec);
// output1: lowest byte of each codepoint
// For count==1 positions: OR with shifted bits from count==2 position
let count1_locs = _mm_cmpeq_epi8(counts, one);
let shifted_6 = _mm_and_si128(
_mm_slli_epi16(_mm_srli_si128(vec_non_ascii, 1), 6),
_mm_set1_epi8(0xC0u8 as i8)
);
let output1 = _mm_blendv_epi8(vec, _mm_or_si128(vec, shifted_6), count1_locs);
// output2: middle byte (for 3 and 4 byte sequences)
let count2_locs = _mm_cmpeq_epi8(counts, two);
let count3_locs = _mm_cmpeq_epi8(counts, three);
let mut output2 = _mm_and_si128(vec, count2_locs);
output2 = _mm_srli_epi32(output2, 2); // bits 5,4,3,2
let shifted_4 = _mm_and_si128(
_mm_set1_epi8(0xF0u8 as i8),
_mm_slli_epi16(_mm_srli_si128(_mm_and_si128(count3_locs, vec_non_ascii), 1), 4)
);
output2 = _mm_or_si128(output2, shifted_4);
output2 = _mm_and_si128(output2, count2_locs);
output2 = _mm_srli_si128(output2, 1);
// output3: highest byte (for 4 byte sequences)
let count4_locs = _mm_cmpeq_epi8(counts, four);
let mut output3 = _mm_and_si128(three, _mm_srli_epi32(vec, 4)); // bits 5,6 from count==3
let shifted_2 = _mm_and_si128(
_mm_set1_epi8(0xFCu8 as i8),
_mm_slli_epi16(_mm_srli_si128(_mm_and_si128(count4_locs, vec_non_ascii), 1), 2)
);
output3 = _mm_or_si128(output3, shifted_2);
output3 = _mm_and_si128(output3, count3_locs);
output3 = _mm_srli_si128(output3, 2);
// Shuffle to remove continuation bytes
// shifts = number of bytes to skip for each position
let mut shifts = count_subs1;
// Propagate shifts: shifts[i] += shifts[i-1] + shifts[i-2] + ...
shifts = _mm_add_epi8(shifts, _mm_srli_si128(shifts, 1));
shifts = _mm_add_epi8(shifts, _mm_srli_si128(shifts, 2));
shifts = _mm_add_epi8(shifts, _mm_srli_si128(shifts, 4));
shifts = _mm_add_epi8(shifts, _mm_srli_si128(shifts, 8));
// Zero shifts for discarded continuation bytes (where counts >= 2)
shifts = _mm_and_si128(shifts, _mm_cmplt_epi8(counts, two));
// Move shifts leftward based on bit patterns
// This is Kitty's move() macro
shifts = Self::move_shifts_by_1(shifts);
shifts = Self::move_shifts_by_2(shifts);
shifts = Self::move_shifts_by_4(shifts);
shifts = Self::move_shifts_by_8(shifts);
// Add byte numbers to create shuffle mask
shifts = _mm_add_epi8(shifts, numbered);
// Shuffle the output vectors
let output1 = _mm_shuffle_epi8(output1, shifts);
let output2 = _mm_shuffle_epi8(output2, shifts);
let output3 = _mm_shuffle_epi8(output3, shifts);
// Count discarded bytes to get codepoint count
let num_discarded = Self::sum_bytes(count_subs1);
let num_codepoints = chunk_src_sz - num_discarded;
// Output unicode codepoints
Self::output_unicode(output1, output2, output3, num_codepoints, output);
// Handle trailing bytes
if num_trailing_bytes > 0 && p < limit {
p = p.sub(num_trailing_bytes);
}
break 'classification;
}
}
(num_consumed, sentinel_found)
}
/// move() macro from Kitty: move shifts leftward based on bit pattern
/// move(shifts, one_byte, 1)
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "ssse3", enable = "sse4.1")]
#[inline]
unsafe fn move_shifts_by_1(shifts: __m128i) -> __m128i {
// blendv_epi8(shifts, shift_left_by_one_byte(shifts),
// shift_left_by_one_byte(shift_left_by_bits16(shifts, 7)))
let selector = _mm_slli_si128(_mm_slli_epi16(shifts, 7), 1);
let shifted = _mm_slli_si128(shifts, 1);
_mm_blendv_epi8(shifts, shifted, selector)
}
/// move(shifts, two_bytes, 2)
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "ssse3", enable = "sse4.1")]
#[inline]
unsafe fn move_shifts_by_2(shifts: __m128i) -> __m128i {
let selector = _mm_slli_si128(_mm_slli_epi16(shifts, 6), 2);
let shifted = _mm_slli_si128(shifts, 2);
_mm_blendv_epi8(shifts, shifted, selector)
}
/// move(shifts, four_bytes, 3)
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "ssse3", enable = "sse4.1")]
#[inline]
unsafe fn move_shifts_by_4(shifts: __m128i) -> __m128i {
let selector = _mm_slli_si128(_mm_slli_epi16(shifts, 5), 4);
let shifted = _mm_slli_si128(shifts, 4);
_mm_blendv_epi8(shifts, shifted, selector)
}
/// Port of Kitty's `move(shifts, eight_bytes, 4)` step.
///
/// For every byte lane `i >= 8` the result holds `shifts[i - 8]` when bit 3 of
/// `shifts[i - 8]` is set, and `shifts[i]` otherwise. Lanes 0..=7 are unchanged.
///
/// NOTE(review): this moves values toward higher indices (`_mm_slli_si128`);
/// Kitty's `shift_left_by_*` may mean the opposite direction — confirm.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "ssse3", enable = "sse4.1")]
#[inline]
unsafe fn move_shifts_by_8(shifts: __m128i) -> __m128i {
    // `<< 4` per 16-bit lane puts bit 3 of every byte into that byte's top bit,
    // the bit `_mm_blendv_epi8` selects on.
    let bit3_as_sign = _mm_slli_epi16::<4>(shifts);
    _mm_blendv_epi8(
        shifts,
        _mm_slli_si128::<8>(shifts),
        _mm_slli_si128::<8>(bit3_as_sign),
    )
}
/// Returns the index of the first byte lane whose high bit is set, or -1 if
/// there is none.
///
/// `cmp_result` is expected to be the output of a byte-wise compare (each lane
/// `0x00` or `0xFF`). The decision is made from the movemask alone, so a
/// vector that is non-zero but has no lane with its high bit set also yields
/// -1 rather than an out-of-range index.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[inline]
unsafe fn bytes_to_first_match(cmp_result: __m128i) -> i32 {
    // One bit per lane (bit i = top bit of byte i); the upper 16 bits are zero.
    let lane_bits = _mm_movemask_epi8(cmp_result);
    if lane_bits == 0 {
        -1
    } else {
        lane_bits.trailing_zeros() as i32
    }
}
/// Clears the `n` highest-indexed byte lanes of `vec` and keeps the rest.
///
/// For example `zero_last_n_bytes(vec, 3)` zeroes lanes 13, 14 and 15.
/// `n == 0` returns `vec` unchanged; any `n >= 16` returns an all-zero vector.
///
/// The mask is the same one obtained by shifting an all-ones vector toward
/// index 0 by `n` bytes (`_mm_srli_si128`, which is what Kitty calls
/// `shift_left_by_bytes`): `0xFF` in lanes `0..16 - n`, `0x00` above.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn zero_last_n_bytes(vec: __m128i, n: usize) -> __m128i {
    // Sixteen 0xFF bytes followed by sixteen 0x00 bytes. A 16-byte window
    // starting at offset `n` therefore keeps the first `16 - n` lanes.
    static KEEP_WINDOW: [u8; 32] = [
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    // Clamp so the unaligned 16-byte load always stays inside the table.
    let start = n.min(16);
    let keep = _mm_loadu_si128(KEEP_WINDOW.as_ptr().add(start) as *const __m128i);
    _mm_and_si128(keep, vec)
}
/// Returns the sum of the sixteen unsigned byte lanes of `vec`.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn sum_bytes(vec: __m128i) -> usize {
    // SAD against zero leaves one partial sum (at most 8 * 255) in the low
    // bits of each 64-bit half.
    let halves = _mm_sad_epu8(vec, _mm_setzero_si128());
    // Bring the upper half down and add it in-register before extracting.
    let total = _mm_add_epi32(halves, _mm_srli_si128::<8>(halves));
    _mm_cvtsi128_si32(total) as usize
}
/// Appends the first `src_sz` byte lanes of `vec` to `output`, each
/// zero-extended to a `u32`.
///
/// Lanes are consumed four at a time. Once all sixteen lanes have been
/// shifted out the vector is zero, so any count beyond 16 appends zeros.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[inline]
unsafe fn output_plain_ascii(vec: __m128i, src_sz: usize, output: &mut Vec<u32>) {
    output.reserve(src_sz);
    let mut lanes = vec;
    let mut widened = [0u32; 4];
    for start in (0..src_sz).step_by(4) {
        // Zero-extend the four lowest bytes into four u32 values.
        _mm_storeu_si128(
            widened.as_mut_ptr().cast::<__m128i>(),
            _mm_cvtepu8_epi32(lanes),
        );
        // The final group may be partial.
        let take = (src_sz - start).min(4);
        output.extend_from_slice(&widened[..take]);
        // Drop the four bytes just emitted.
        lanes = _mm_srli_si128::<4>(lanes);
    }
}
/// Assembles codepoints from three byte planes and appends them to `output`.
///
/// Codepoint `i` is `output1[i] | output2[i] << 8 | output3[i] << 16`, where
/// `[i]` is the i-th byte lane of each vector. The first `num_codepoints`
/// values are appended. Lanes are consumed four at a time, and every plane is
/// zero once its sixteen lanes are used, so counts beyond 16 append zeros.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[inline]
unsafe fn output_unicode(
    output1: __m128i,
    output2: __m128i,
    output3: __m128i,
    num_codepoints: usize,
    output: &mut Vec<u32>,
) {
    output.reserve(num_codepoints);
    let (mut low, mut mid, mut high) = (output1, output2, output3);
    let mut assembled = [0u32; 4];
    for start in (0..num_codepoints).step_by(4) {
        // Widen the four lowest bytes of each plane to u32 and move each plane
        // to its own bit range: bits 0-7, 8-15 and 16-23 respectively.
        let bits_0_7 = _mm_cvtepu8_epi32(low);
        let bits_8_15 = _mm_slli_epi32::<8>(_mm_cvtepu8_epi32(mid));
        let bits_16_23 = _mm_slli_epi32::<16>(_mm_cvtepu8_epi32(high));
        let combined = _mm_or_si128(_mm_or_si128(bits_0_7, bits_8_15), bits_16_23);
        _mm_storeu_si128(assembled.as_mut_ptr().cast::<__m128i>(), combined);
        // The final group may be partial.
        let take = (num_codepoints - start).min(4);
        output.extend_from_slice(&assembled[..take]);
        // Advance every plane past the four lanes just emitted.
        low = _mm_srli_si128::<4>(low);
        mid = _mm_srli_si128::<4>(mid);
        high = _mm_srli_si128::<4>(high);
    }
}
/// Feeds bytes of `src` through the scalar decoder until the decoder state
/// is `UTF8_ACCEPT` or `src` runs out.
///
/// Completed codepoints are pushed to `output`. An ESC byte (0x1B) met while
/// a sequence is pending pushes U+FFFD, resets the state and stops without
/// consuming the ESC. Returns the number of bytes of `src` consumed.
fn scalar_decode_to_accept(&mut self, src: &[u8], output: &mut Vec<u32>) -> usize {
    let mut consumed = 0;
    while self.state.cur != UTF8_ACCEPT {
        let byte = match src.get(consumed) {
            Some(&b) => b,
            None => break,
        };
        if byte == 0x1B {
            // ESC cuts the pending sequence short: emit a replacement
            // character and leave the ESC for the caller.
            output.push(0xFFFD);
            self.state = Utf8State::default();
            return consumed;
        }
        consumed += 1;
        self.state.prev = self.state.cur;
        match decode_utf8_byte(&mut self.state.cur, &mut self.state.codep, byte) {
            UTF8_ACCEPT => output.push(self.state.codep),
            UTF8_REJECT => {
                output.push(0xFFFD);
                // A byte rejected in the middle of a sequence is handed back
                // (un-consumed). Read `prev` before the reset below.
                let rewind = self.state.prev != UTF8_ACCEPT;
                self.state = Utf8State::default();
                if rewind {
                    consumed -= 1;
                }
            }
            _ => {}
        }
    }
    consumed
}
/// Decodes `src` byte by byte with the scalar decoder, stopping at the first
/// ESC (0x1B) or at the end of the slice.
///
/// Completed codepoints are pushed to `output`; malformed input pushes
/// U+FFFD. On ESC, a pending incomplete sequence is reported as U+FFFD and
/// the state is reset; the ESC itself is not consumed. Returns the number of
/// bytes of `src` consumed.
fn scalar_decode_all(&mut self, src: &[u8], output: &mut Vec<u32>) -> usize {
    let mut consumed = 0;
    loop {
        let byte = match src.get(consumed) {
            Some(&b) => b,
            None => break,
        };
        if byte == 0x1B {
            // Only a sequence that was still in progress yields a
            // replacement character.
            if self.state.cur != UTF8_ACCEPT {
                output.push(0xFFFD);
                self.state = Utf8State::default();
            }
            return consumed;
        }
        consumed += 1;
        self.state.prev = self.state.cur;
        match decode_utf8_byte(&mut self.state.cur, &mut self.state.codep, byte) {
            UTF8_ACCEPT => output.push(self.state.codep),
            UTF8_REJECT => {
                output.push(0xFFFD);
                // A byte rejected in the middle of a sequence is handed back so
                // the next iteration re-reads it from the reset state. Read
                // `prev` before the reset below.
                let rewind = self.state.prev != UTF8_ACCEPT;
                self.state = Utf8State::default();
                if rewind {
                    consumed -= 1;
                }
            }
            _ => {}
        }
    }
    consumed
}
}
/// Converts raw `u32` codepoints into `char`s, replacing the previous
/// contents of `chars`.
///
/// Any value that is not a Unicode scalar value (a surrogate in
/// `U+D800..=U+DFFF`, or anything above `U+10FFFF`) is mapped to
/// U+FFFD REPLACEMENT CHARACTER, so the caller has no validity obligation.
#[inline]
pub fn codepoints_to_chars(codepoints: &[u32], chars: &mut Vec<char>) {
    chars.clear();
    // `char::from_u32` does exactly the range/surrogate validation needed, so
    // no `unsafe` is required. `extend` reserves from the slice length.
    chars.extend(
        codepoints
            .iter()
            .map(|&cp| char::from_u32(cp).unwrap_or(char::REPLACEMENT_CHARACTER)),
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a fresh decoder over `input` and returns
    /// `(bytes consumed, ESC found, decoded codepoints)`.
    fn decode(input: &[u8]) -> (usize, bool, Vec<u32>) {
        let mut decoder = SimdUtf8Decoder::new();
        let mut codepoints = Vec::new();
        let (consumed, found_esc) = decoder.decode_to_esc(input, &mut codepoints);
        (consumed, found_esc, codepoints)
    }

    /// Renders codepoints as a `String`, skipping values that are not valid `char`s.
    fn text(codepoints: &[u32]) -> String {
        codepoints
            .iter()
            .filter_map(|&cp| char::from_u32(cp))
            .collect()
    }

    // ========================================================================
    // Tests for the UTF-8 decoder
    // ========================================================================

    #[test]
    fn test_ascii() {
        let (consumed, found_esc, codepoints) = decode(b"Hello, World!");
        assert_eq!(consumed, 13);
        assert!(!found_esc);
        assert_eq!(text(&codepoints), "Hello, World!");
    }

    #[test]
    fn test_with_esc() {
        let (consumed, found_esc, codepoints) = decode(b"Hello\x1b[0m");
        assert_eq!(consumed, 6); // "Hello" + ESC
        assert!(found_esc);
        assert_eq!(text(&codepoints), "Hello");
    }

    #[test]
    fn test_utf8_2byte() {
        let (consumed, found_esc, codepoints) = decode("café".as_bytes());
        assert_eq!(consumed, 5); // c, a, f, é (2 bytes)
        assert!(!found_esc);
        assert_eq!(text(&codepoints), "café");
    }

    #[test]
    fn test_utf8_3byte() {
        let (consumed, found_esc, codepoints) = decode("日本語".as_bytes());
        assert_eq!(consumed, 9); // 3 chars * 3 bytes
        assert!(!found_esc);
        assert_eq!(text(&codepoints), "日本語");
    }

    #[test]
    fn test_utf8_4byte() {
        let (consumed, found_esc, codepoints) = decode("🎉🚀".as_bytes());
        assert_eq!(consumed, 8); // 2 chars * 4 bytes
        assert!(!found_esc);
        assert_eq!(text(&codepoints), "🎉🚀");
    }

    #[test]
    fn test_invalid_utf8() {
        let (consumed, _, codepoints) = decode(b"\xff\xfe");
        assert_eq!(consumed, 2);
        // Invalid bytes must surface as replacement characters.
        assert!(codepoints.contains(&0xFFFD));
    }

    // ========================================================================
    // Tests for find_byte
    // ========================================================================

    #[test]
    fn test_find_byte_first() {
        assert_eq!(find_byte(b"hello world", b'h'), Some(0));
    }

    #[test]
    fn test_find_byte_middle() {
        assert_eq!(find_byte(b"hello world", b'w'), Some(6));
    }

    #[test]
    fn test_find_byte_not_found() {
        assert_eq!(find_byte(b"hello world", b'x'), None);
    }

    #[test]
    fn test_find_byte_long() {
        // More than 32 bytes, so the AVX2 path is exercised.
        let mut haystack = vec![b'a'; 100];
        haystack[75] = b'Z';
        assert_eq!(find_byte(&haystack, b'Z'), Some(75));
    }

    // ========================================================================
    // Tests for find_either_of_two_bytes
    // ========================================================================

    #[test]
    fn test_find_either_of_two_bytes_first() {
        assert_eq!(find_either_of_two_bytes(b"hello world", b'h', b'x'), Some(0));
    }

    #[test]
    fn test_find_either_of_two_bytes_second() {
        assert_eq!(find_either_of_two_bytes(b"hello world", b'x', b'h'), Some(0));
    }

    #[test]
    fn test_find_either_of_two_bytes_middle() {
        // The first 'o' (index 4) comes before the 'w'.
        assert_eq!(find_either_of_two_bytes(b"hello world", b'w', b'o'), Some(4));
    }

    #[test]
    fn test_find_either_of_two_bytes_not_found() {
        assert_eq!(find_either_of_two_bytes(b"hello world", b'x', b'y'), None);
    }

    #[test]
    fn test_find_either_of_two_bytes_esc() {
        assert_eq!(
            find_either_of_two_bytes(b"hello\x1bworld", 0x1B, b'\n'),
            Some(5)
        );
    }

    #[test]
    fn test_find_either_of_two_bytes_long() {
        // More than 32 bytes, so the AVX2 path is exercised.
        let mut haystack = vec![b'a'; 100];
        haystack[50] = b'X';
        assert_eq!(find_either_of_two_bytes(&haystack, b'X', b'Y'), Some(50));
    }

    #[test]
    fn test_find_either_of_two_bytes_empty() {
        assert_eq!(find_either_of_two_bytes(b"", b'a', b'b'), None);
    }

    // ========================================================================
    // Tests for find_c0_control
    // ========================================================================

    #[test]
    fn test_find_c0_control_newline() {
        assert_eq!(find_c0_control(b"hello\nworld"), Some(5));
    }

    #[test]
    fn test_find_c0_control_tab() {
        assert_eq!(find_c0_control(b"hello\tworld"), Some(5));
    }

    #[test]
    fn test_find_c0_control_del() {
        assert_eq!(find_c0_control(b"hello\x7fworld"), Some(5));
    }

    #[test]
    fn test_find_c0_control_bell() {
        assert_eq!(find_c0_control(b"hello\x07world"), Some(5));
    }

    #[test]
    fn test_find_c0_control_esc() {
        assert_eq!(find_c0_control(b"hello\x1bworld"), Some(5));
    }

    #[test]
    fn test_find_c0_control_none() {
        assert_eq!(find_c0_control(b"hello world!"), None);
    }

    #[test]
    fn test_find_c0_control_long() {
        // More than 32 bytes, so the AVX2 path is exercised.
        let mut haystack = vec![b'a'; 100];
        haystack[60] = b'\n';
        assert_eq!(find_c0_control(&haystack), Some(60));
    }

    #[test]
    fn test_find_c0_control_at_start() {
        assert_eq!(find_c0_control(b"\x00hello"), Some(0));
    }

    // ========================================================================
    // Tests for xor_mask
    // ========================================================================

    #[test]
    fn test_xor_mask_basic() {
        let mask = [0x12, 0x34, 0x56, 0x78];
        let mut data = vec![0u8; 8];
        xor_mask(&mut data, mask, 0);
        // XOR against zeros reproduces the mask, repeated.
        assert_eq!(data, vec![0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78]);
    }

    #[test]
    fn test_xor_mask_offset() {
        let mask = [0x12, 0x34, 0x56, 0x78];
        let mut data = vec![0u8; 8];
        xor_mask(&mut data, mask, 1);
        // Starting at offset 1 the pattern begins with the second mask byte.
        assert_eq!(data, vec![0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12]);
    }

    #[test]
    fn test_xor_mask_roundtrip() {
        let mask = [0xAB, 0xCD, 0xEF, 0x01];
        let original = b"Hello, World!".to_vec();
        let mut data = original.clone();
        // The first pass scrambles the payload...
        xor_mask(&mut data, mask, 0);
        assert_ne!(data, original);
        // ...and the second pass restores it.
        xor_mask(&mut data, mask, 0);
        assert_eq!(data, original);
    }

    #[test]
    fn test_xor_mask_long() {
        // More than 32 bytes, so the AVX2 path is exercised.
        let mask = [0x12, 0x34, 0x56, 0x78];
        let mut data = vec![0xFFu8; 100];
        xor_mask(&mut data, mask, 0);
        // Every byte must be 0xFF XOR the mask byte for its position.
        for (index, &byte) in data.iter().enumerate() {
            assert_eq!(byte, 0xFF ^ mask[index % 4]);
        }
    }

    #[test]
    fn test_xor_mask_empty() {
        let mask = [0x12, 0x34, 0x56, 0x78];
        let mut data: Vec<u8> = Vec::new();
        assert_eq!(xor_mask(&mut data, mask, 0), 0);
    }
}