From 4287e0ca7dc22ebc6cf9bd70c1edb77fbb4e479e Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 25 Aug 2025 08:49:07 -0700 Subject: [PATCH 01/11] Support C-style enums with more than 256 variants. --- bitcode_derive/src/decode.rs | 11 +++-- bitcode_derive/src/encode.rs | 7 +-- bitcode_derive/src/shared.rs | 91 ++++++++++++++++++++++++++++++---- src/derive/option.rs | 4 +- src/derive/result.rs | 4 +- src/derive/variant.rs | 95 ++++++++++++++++++++++++++++++------ src/fast.rs | 9 +++- 7 files changed, 183 insertions(+), 38 deletions(-) diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index 14a69fb..c4c3d0a 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -1,6 +1,6 @@ use crate::attribute::BitcodeAttrs; use crate::private; -use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; +use crate::shared::{remove_lifetimes, replace_lifetimes, VariantIndex}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{ @@ -111,6 +111,7 @@ impl crate::shared::Item for Item { self, crate_name: &Path, variant_count: usize, + variant_index: VariantIndex, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { @@ -126,7 +127,7 @@ impl crate::shared::Item for Item { .then(|| { let private = private(crate_name); let c_style = inners.is_empty(); - quote! { variants: #private::VariantDecoder<#de, #variant_count, #c_style>, } + quote! { variants: #private::VariantDecoder<#de, #variant_index, #variant_count, #c_style>, } }) .unwrap_or_default(); quote! { @@ -165,7 +166,7 @@ impl crate::shared::Item for Item { if inner.is_empty() { quote! {} } else { - let i = variant_index(i); + let i = variant_index.instance_to_tokens(i); let length = decode_variants .then(|| { quote! { @@ -209,7 +210,7 @@ impl crate::shared::Item for Item { .map(|i| { let inner = inner(i); let pattern = pattern(i); - let i = variant_index(i); + let i = variant_index.instance_to_tokens(i); quote! { #i => { #inner @@ -221,7 +222,7 @@ impl crate::shared::Item for Item { quote! { match self.variants.decode() { #variants - // Safety: VariantDecoder::decode outputs numbers less than N. + // Safety: VariantDecoder<_, N, _>::decode outputs numbers less than N. _ => unsafe { ::core::hint::unreachable_unchecked() } } } diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index 9680229..b4532bb 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -1,6 +1,6 @@ use crate::attribute::BitcodeAttrs; use crate::private; -use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; +use crate::shared::{remove_lifetimes, replace_lifetimes, VariantIndex}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{parse_quote, Generics, Path, Type}; @@ -114,6 +114,7 @@ impl crate::shared::Item for Item { self, crate_name: &Path, variant_count: usize, + variant_index: VariantIndex, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { @@ -124,7 +125,7 @@ impl crate::shared::Item for Item { let variants = encode_variants .then(|| { let private = private(crate_name); - quote! { variants: #private::VariantEncoder<#variant_count>, } + quote! { variants: #private::VariantEncoder<#variant_index, #variant_count>, } }) .unwrap_or_default(); let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); @@ -149,7 +150,7 @@ impl crate::shared::Item for Item { let variants: TokenStream = (0..variant_count) .map(|i| { let pattern = pattern(i); - let i = variant_index(i); + let i = variant_index.instance_to_tokens(i); quote! { #pattern => #i, } diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index 3d4f67e..a3b92cd 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -9,9 +9,72 @@ use syn::{ Result, Type, WherePredicate, }; -type VariantIndex = u8; -pub fn variant_index(i: usize) -> VariantIndex { - i.try_into().unwrap() +#[derive(Copy, Clone, Debug)] +pub enum VariantIndex { + U8, + U16, + U32, +} + +impl VariantIndex { + pub fn new(variant_count: usize, ident: &Ident) -> Result { + for candidate in [Self::U8, Self::U16, Self::U32] { + if variant_count <= candidate.max_variants() { + return Ok(candidate); + } + } + err( + &ident, + &format!( + "enums with more than {} variants are not supported", + Self::U32.max_variants() + ), + ) + } + + fn max_variants(self) -> usize { + (match self { + Self::U8 => u8::MAX as usize, + Self::U16 => u16::MAX as usize, + Self::U32 => u32::MAX as usize, + }) + 1 + } + + /// If returns `false`, only C-style enums are supported. + pub fn supports_fields(self) -> bool { + match self { + Self::U8 => true, + _ => false, + } + } + + pub fn instance_to_tokens(self, index: usize) -> TokenStream { + match self { + Self::U8 => { + let n: u8 = index.try_into().unwrap(); + quote! {#n} + } + Self::U16 => { + let n: u16 = index.try_into().unwrap(); + quote! {#n} + } + Self::U32 => { + let n: u32 = index.try_into().unwrap(); + quote! {#n} + } + } + } +} + +impl ToTokens for VariantIndex { + fn to_tokens(&self, tokens: &mut TokenStream) { + use quote::TokenStreamExt; + tokens.append(match self { + Self::U8 => Ident::new("u8", Span::call_site()), + Self::U16 => Ident::new("u16", Span::call_site()), + Self::U32 => Ident::new("u32", Span::call_site()), + }); + } } pub trait Item: Copy + Sized { @@ -36,6 +99,7 @@ pub trait Item: Copy + Sized { self, crate_name: &Path, variant_count: usize, + variant_index: VariantIndex, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream; @@ -132,12 +196,20 @@ pub trait Derive { }) } Data::Enum(data_enum) => { - let max_variants = VariantIndex::MAX as usize + 1; - if data_enum.variants.len() > max_variants { - return err( - &ident, - &format!("enums with more than {max_variants} variants are not supported"), - ); + let variant_index = VariantIndex::new(data_enum.variants.len(), &ident)?; + + if !variant_index.supports_fields() { + for variant in &data_enum.variants { + if !variant.fields.is_empty() { + return err( + &ident, + &format!( + "enums with more than {} variants must not have any variants with fields", + VariantIndex::U8.max_variants() + ), + ); + } + } } // Used for adding `bounds` and skipping fields. Would be used by `#[bitcode(with_serde)]`. @@ -154,6 +226,7 @@ pub trait Derive { item.enum_impl( &attrs.crate_name, data_enum.variants.len(), + variant_index, |i| { let variant = &data_enum.variants[i]; let variant_name = &variant.ident; diff --git a/src/derive/option.rs b/src/derive/option.rs index b192bae..6064875 100644 --- a/src/derive/option.rs +++ b/src/derive/option.rs @@ -7,7 +7,7 @@ use core::mem::MaybeUninit; use core::num::NonZeroUsize; pub struct OptionEncoder { - variants: VariantEncoder<2>, + variants: VariantEncoder, some: T::Encoder, } @@ -86,7 +86,7 @@ impl Buffer for OptionEncoder { } pub struct OptionDecoder<'a, T: Decode<'a>> { - variants: VariantDecoder<'a, 2, false>, + variants: VariantDecoder<'a, u8, 2, false>, some: T::Decoder, } diff --git a/src/derive/result.rs b/src/derive/result.rs index 9ec6971..7364c82 100644 --- a/src/derive/result.rs +++ b/src/derive/result.rs @@ -7,7 +7,7 @@ use core::mem::MaybeUninit; use core::num::NonZeroUsize; pub struct ResultEncoder { - variants: VariantEncoder<2>, + variants: VariantEncoder, ok: T::Encoder, err: E::Encoder, } @@ -55,7 +55,7 @@ impl Buffer for ResultEncoder { } pub struct ResultDecoder<'a, T: Decode<'a>, E: Decode<'a>> { - variants: VariantDecoder<'a, 2, false>, + variants: VariantDecoder<'a, u8, 2, false>, ok: T::Decoder, err: E::Decoder, } diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 67463c3..ba7154a 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -1,23 +1,29 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View}; use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; use crate::pack::{pack_bytes_less_than, unpack_bytes_less_than}; +use crate::pack_ints::{pack_ints, unpack_ints, Int}; use alloc::vec::Vec; +use core::any::TypeId; use core::num::NonZeroUsize; #[derive(Default)] -pub struct VariantEncoder(VecImpl); +pub struct VariantEncoder(VecImpl); -impl Encoder for VariantEncoder { +impl Encoder for VariantEncoder { #[inline(always)] - fn encode(&mut self, v: &u8) { + fn encode(&mut self, v: &T) { unsafe { self.0.push_unchecked(*v) }; } } -impl Buffer for VariantEncoder { +impl Buffer for VariantEncoder { fn collect_into(&mut self, out: &mut Vec) { assert!(N >= 2); - pack_bytes_less_than::(self.0.as_slice(), out); + if std::mem::size_of::() > 1 { + pack_ints(self.0.as_mut_slice(), out); + } else { + pack_bytes_less_than::(bytemuck::must_cast_slice::(self.0.as_slice()), out); + }; self.0.clear(); } @@ -26,13 +32,13 @@ impl Buffer for VariantEncoder { } } -pub struct VariantDecoder<'a, const N: usize, const C_STYLE: bool> { - variants: CowSlice<'a, u8>, +pub struct VariantDecoder<'a, T: Int, const N: usize, const C_STYLE: bool> { + variants: CowSlice<'a, T::Une>, histogram: [usize; N], // Not required if C_STYLE. TODO don't reserve space for it. } // [(); N] doesn't implement Default. -impl Default for VariantDecoder<'_, N, C_STYLE> { +impl Default for VariantDecoder<'_, T, N, C_STYLE> { fn default() -> Self { Self { variants: Default::default(), @@ -42,29 +48,44 @@ impl Default for VariantDecoder<'_, N, C_ST } // C style enums don't require length, so we can skip making a histogram for them. -impl<'a, const N: usize> VariantDecoder<'a, N, false> { +impl<'a, T: Int, const N: usize> VariantDecoder<'a, T, N, false> { pub fn length(&self, variant_index: u8) -> usize { self.histogram[variant_index as usize] } } -impl<'a, const N: usize, const C_STYLE: bool> View<'a> for VariantDecoder<'a, N, C_STYLE> { +impl<'a, T: Int, const N: usize, const C_STYLE: bool> View<'a> + for VariantDecoder<'a, T, N, C_STYLE> +{ fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { assert!(N >= 2); - if C_STYLE { - unpack_bytes_less_than::(input, length, &mut self.variants)?; + if TypeId::of::() != TypeId::of::() { + unpack_ints::(input, length, &mut self.variants)?; } else { - self.histogram = unpack_bytes_less_than::(input, length, &mut self.variants)?; + // SAFETY: Checked the type above and [u8; 1] has the + // same memory layout as `u8`. + let out = unsafe { + std::mem::transmute::<&mut CowSlice<'a, T::Une>, &mut CowSlice<'a, u8>>( + &mut self.variants, + ) + }; + if C_STYLE { + unpack_bytes_less_than::(input, length, out)?; + } else { + self.histogram = unpack_bytes_less_than::(input, length, out)?; + } } Ok(()) } } -impl<'a, const N: usize, const C_STYLE: bool> Decoder<'a, u8> for VariantDecoder<'a, N, C_STYLE> { +impl<'a, T: Int, const N: usize, const C_STYLE: bool> Decoder<'a, T> + for VariantDecoder<'a, T, N, C_STYLE> +{ // Guaranteed to output numbers less than N. #[inline(always)] - fn decode(&mut self) -> u8 { - unsafe { self.variants.mut_slice().next_unchecked() } + fn decode(&mut self) -> T { + bytemuck::must_cast(unsafe { self.variants.mut_slice().next_unchecked() }) } } @@ -99,6 +120,48 @@ mod tests { assert!(matches!(decode(&encode(&Enum1::F)), Ok(Enum1::F))); } + #[allow(unused)] + #[test] + fn test_large_c_style_enum() { + #[cfg_attr(not(test), rustfmt::skip)] + #[derive(Encode, Decode)] + enum Enum300 { + V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, + V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, + V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, + V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, + V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, + V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, + V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, + V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, + V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, + V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, + V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, + V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, + V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, + V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, + V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, + V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, + V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, + V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, + V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, + V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, + V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, + V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, + V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, + V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, + V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, + V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, + V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, + V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, + V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, + V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, + } + + assert!(matches!(decode(&encode(&Enum300::V42)), Ok(Enum300::V42))); + assert!(matches!(decode(&encode(&Enum300::V300)), Ok(Enum300::V300))); + } + #[allow(unused)] #[test] fn test_rust_style_enum() { diff --git a/src/fast.rs b/src/fast.rs index 13adb2e..ebbe23e 100644 --- a/src/fast.rs +++ b/src/fast.rs @@ -332,11 +332,18 @@ impl<'a, T: Copy> NextUnchecked<'a, T> for &'a [T] { } /// Maybe owned [`FastSlice`]. Saves its allocation even if borrowing something. -#[derive(Default)] pub struct CowSlice<'borrowed, T> { slice: SliceImpl<'borrowed, T>, // Lifetime is min of 'borrowed and &'me self. vec: Vec, } +impl<'borrowed, T> Default for CowSlice<'borrowed, T> { + fn default() -> Self { + Self { + slice: Default::default(), + vec: Default::default(), + } + } +} impl<'borrowed, T> CowSlice<'borrowed, T> { /// Creates a [`CowSlice`] with an allocation of `vec`. None of `vec`'s elements are kept. pub fn with_allocation(mut vec: Vec) -> Self { From fee3e5a8df062ba536156eaee1c18c243f1f3979 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 25 Aug 2025 08:55:02 -0700 Subject: [PATCH 02/11] no_std. --- src/derive/variant.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/derive/variant.rs b/src/derive/variant.rs index ba7154a..e042948 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -19,7 +19,7 @@ impl Encoder for VariantEncoder { impl Buffer for VariantEncoder { fn collect_into(&mut self, out: &mut Vec) { assert!(N >= 2); - if std::mem::size_of::() > 1 { + if core::mem::size_of::() > 1 { pack_ints(self.0.as_mut_slice(), out); } else { pack_bytes_less_than::(bytemuck::must_cast_slice::(self.0.as_slice()), out); @@ -65,7 +65,7 @@ impl<'a, T: Int, const N: usize, const C_STYLE: bool> View<'a> // SAFETY: Checked the type above and [u8; 1] has the // same memory layout as `u8`. let out = unsafe { - std::mem::transmute::<&mut CowSlice<'a, T::Une>, &mut CowSlice<'a, u8>>( + core::mem::transmute::<&mut CowSlice<'a, T::Une>, &mut CowSlice<'a, u8>>( &mut self.variants, ) }; From bf54b777a27c2a9a248056999db67f31f7a97d0d Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 25 Aug 2025 09:07:39 -0700 Subject: [PATCH 03/11] Fuzz. --- fuzz/fuzz_targets/fuzz.rs | 52 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index a9beda4..6557ef9 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -3,14 +3,14 @@ use libfuzzer_sys::fuzz_target; extern crate bitcode; use arrayvec::{ArrayString, ArrayVec}; use bitcode::{Decode, DecodeOwned, Encode}; +use rust_decimal::Decimal; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::num::NonZeroU32; use std::time::Duration; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use rust_decimal::Decimal; #[inline(never)] fn test_derive(data: &[u8]) { @@ -140,6 +140,39 @@ fuzz_target!(|data: &[u8]| { pub enum Enum16 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P } #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] pub enum Enum17 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + enum Enum300 { + V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, + V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, + V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, + V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, + V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, + V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, + V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, + V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, + V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, + V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, + V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, + V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, + V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, + V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, + V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, + V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, + V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, + V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, + V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, + V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, + V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, + V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, + V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, + V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, + V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, + V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, + V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, + V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, + V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, + V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, + } } use enums::*; @@ -148,10 +181,20 @@ fuzz_target!(|data: &[u8]| { A, B, C(u16), - D { a: u8, b: u8, #[serde(skip)] #[bitcode(skip)] c: u8 }, + D { + a: u8, + b: u8, + #[serde(skip)] + #[bitcode(skip)] + c: u8, + }, E(String), F, - G(#[bitcode(skip)] #[serde(skip)] i16), + G( + #[bitcode(skip)] + #[serde(skip)] + i16, + ), P(BTreeMap), } @@ -219,6 +262,7 @@ fuzz_target!(|data: &[u8]| { Enum15, Enum16, Enum17, + Enum300, Enum, ArrayString<5>, ArrayString<70>, From 3d32a3b8f4cefb34ff092f9bdfa3536326340be5 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 25 Aug 2025 09:09:49 -0700 Subject: [PATCH 04/11] pub in fuzz. --- fuzz/fuzz_targets/fuzz.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index 6557ef9..aed0aad 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -141,7 +141,7 @@ fuzz_target!(|data: &[u8]| { #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] pub enum Enum17 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q } #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] - enum Enum300 { + pub enum Enum300 { V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, From 2047753287dbdac5faf483e84d32c405a0169048 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 25 Aug 2025 09:21:45 -0700 Subject: [PATCH 05/11] Variant bounds check. --- src/derive/variant.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/derive/variant.rs b/src/derive/variant.rs index e042948..76d11d4 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -1,4 +1,5 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::error::err; use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; use crate::pack::{pack_bytes_less_than, unpack_bytes_less_than}; use crate::pack_ints::{pack_ints, unpack_ints, Int}; @@ -54,13 +55,19 @@ impl<'a, T: Int, const N: usize> VariantDecoder<'a, T, N, false> { } } -impl<'a, T: Int, const N: usize, const C_STYLE: bool> View<'a> +impl<'a, T: Int + Into, const N: usize, const C_STYLE: bool> View<'a> for VariantDecoder<'a, T, N, C_STYLE> { fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { assert!(N >= 2); if TypeId::of::() != TypeId::of::() { unpack_ints::(input, length, &mut self.variants)?; + // TOOD: this uses extra memory bandwith to rescan. + for int in unsafe { self.variants.as_slice(length) } { + if T::from_unaligned(*int).into() >= N { + return err("invalid enum variant index"); + } + } } else { // SAFETY: Checked the type above and [u8; 1] has the // same memory layout as `u8`. @@ -79,7 +86,7 @@ impl<'a, T: Int, const N: usize, const C_STYLE: bool> View<'a> } } -impl<'a, T: Int, const N: usize, const C_STYLE: bool> Decoder<'a, T> +impl<'a, T: Int + Into, const N: usize, const C_STYLE: bool> Decoder<'a, T> for VariantDecoder<'a, T, N, C_STYLE> { // Guaranteed to output numbers less than N. From c0f99c0a0c2dd153f129aa07c8fcb71b61f18794 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Thu, 28 Aug 2025 14:59:42 -0700 Subject: [PATCH 06/11] Optimizations. --- src/derive/variant.rs | 115 +++++++++++++++++++++++++----------------- src/serde/ser.rs | 5 ++ 2 files changed, 74 insertions(+), 46 deletions(-) diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 76d11d4..33c4e54 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -20,7 +20,7 @@ impl Encoder for VariantEncoder { impl Buffer for VariantEncoder { fn collect_into(&mut self, out: &mut Vec) { assert!(N >= 2); - if core::mem::size_of::() > 1 { + if TypeId::of::() != TypeId::of::() { pack_ints(self.0.as_mut_slice(), out); } else { pack_bytes_less_than::(bytemuck::must_cast_slice::(self.0.as_slice()), out); @@ -62,20 +62,30 @@ impl<'a, T: Int + Into, const N: usize, const C_STYLE: bool> View<'a> assert!(N >= 2); if TypeId::of::() != TypeId::of::() { unpack_ints::(input, length, &mut self.variants)?; - // TOOD: this uses extra memory bandwith to rescan. - for int in unsafe { self.variants.as_slice(length) } { - if T::from_unaligned(*int).into() >= N { + + /// Checks that `unpacked` ints are less than `N`, hopefully + /// without a branch instruction for every int. + fn check_less_than, const N: usize>( + unpacked: &[T::Une], + ) -> Result<()> { + if 2u64.pow(std::mem::size_of::() as u32 * 8) - 1 > N as u64 + && unpacked + .iter() + .copied() + .map(T::from_unaligned) + .max() + .map(Into::into) + .unwrap_or(0) + >= N + { return err("invalid enum variant index"); } + Ok(()) } + + check_less_than::(unsafe { self.variants.as_slice(length) })?; } else { - // SAFETY: Checked the type above and [u8; 1] has the - // same memory layout as `u8`. - let out = unsafe { - core::mem::transmute::<&mut CowSlice<'a, T::Une>, &mut CowSlice<'a, u8>>( - &mut self.variants, - ) - }; + let out = self.variants.cast_mut::(); if C_STYLE { unpack_bytes_less_than::(input, length, out)?; } else { @@ -130,41 +140,6 @@ mod tests { #[allow(unused)] #[test] fn test_large_c_style_enum() { - #[cfg_attr(not(test), rustfmt::skip)] - #[derive(Encode, Decode)] - enum Enum300 { - V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, - V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, - V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, - V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, - V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, - V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, - V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, - V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, - V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, - V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, - V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, - V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, - V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, - V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, - V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, - V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, - V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, - V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, - V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, - V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, - V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, - V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, - V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, - V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, - V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, - V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, - V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, - V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, - V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, - V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, - } - assert!(matches!(decode(&encode(&Enum300::V42)), Ok(Enum300::V42))); assert!(matches!(decode(&encode(&Enum300::V300)), Ok(Enum300::V300))); } @@ -207,4 +182,52 @@ mod tests { .collect() } crate::bench_encode_decode!(bool_enum_vec: Vec<_>); + + #[cfg_attr(not(test), rustfmt::skip)] + #[derive(Encode, Decode, Debug, PartialEq)] + pub enum Enum300 { + V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, + V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, + V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, + V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, + V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, + V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, + V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, + V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, + V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, + V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, + V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, + V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, + V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, + V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, + V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, + V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, + V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, + V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, + V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, + V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, + V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, + V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, + V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, + V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, + V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, + V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, + V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, + V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, + V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, + V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, + } +} + +#[cfg(test)] +mod test2 { + use crate::derive::variant::tests::Enum300; + + fn bench_data() -> Vec { + crate::random_data(1000) + .into_iter() + .map(|v: u16| unsafe { core::mem::transmute_copy::<_, Enum300>(&(v % 300)) }) + .collect() + } + crate::bench_encode_decode!(enum_300_variants_vec: Vec<_>); } diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 63057a6..331c467 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -240,6 +240,11 @@ impl<'a> EncoderWrapper<'a> { #[inline(always)] fn variant_index_u8(variant_index: u32) -> Result { if variant_index > u8::MAX as u32 { + // Properly optimizing the size of large enums would + // require `serde` to specify the variant count. + // + // Good news: the `derive` version of `bitcode` supports + // arbitrary-sized fieldless enums! err("enums with more than 256 variants are unsupported") } else { Ok(variant_index as u8) From 5260315c262b1c6eced6fc53b718f5a31a89eca3 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Thu, 28 Aug 2025 15:18:41 -0700 Subject: [PATCH 07/11] Remove histogram while decoding C-style enums. --- bitcode_derive/src/decode.rs | 7 ++++++- src/derive/option.rs | 2 +- src/derive/result.rs | 2 +- src/derive/variant.rs | 28 ++++++++++++++-------------- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index c4c3d0a..85207fb 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -127,7 +127,12 @@ impl crate::shared::Item for Item { .then(|| { let private = private(crate_name); let c_style = inners.is_empty(); - quote! { variants: #private::VariantDecoder<#de, #variant_index, #variant_count, #c_style>, } + let histogram = if c_style { + 0 + } else { + variant_count + }; + quote! { variants: #private::VariantDecoder<#de, #variant_index, #variant_count, #histogram>, } }) .unwrap_or_default(); quote! { diff --git a/src/derive/option.rs b/src/derive/option.rs index 6064875..967aec6 100644 --- a/src/derive/option.rs +++ b/src/derive/option.rs @@ -86,7 +86,7 @@ impl Buffer for OptionEncoder { } pub struct OptionDecoder<'a, T: Decode<'a>> { - variants: VariantDecoder<'a, u8, 2, false>, + variants: VariantDecoder<'a, u8, 2, 2>, some: T::Decoder, } diff --git a/src/derive/result.rs b/src/derive/result.rs index 7364c82..fb7dede 100644 --- a/src/derive/result.rs +++ b/src/derive/result.rs @@ -55,7 +55,7 @@ impl Buffer for ResultEncoder { } pub struct ResultDecoder<'a, T: Decode<'a>, E: Decode<'a>> { - variants: VariantDecoder<'a, u8, 2, false>, + variants: VariantDecoder<'a, u8, 2, 2>, ok: T::Decoder, err: E::Decoder, } diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 33c4e54..cafde8a 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -33,13 +33,16 @@ impl Buffer for VariantEncoder { } } -pub struct VariantDecoder<'a, T: Int, const N: usize, const C_STYLE: bool> { +pub struct VariantDecoder<'a, T: Int, const N: usize, const HISTOGRAM: usize> { variants: CowSlice<'a, T::Une>, - histogram: [usize; N], // Not required if C_STYLE. TODO don't reserve space for it. + // `HISTOGRAM` is 0 for C style (fieldless) enums. + histogram: [usize; HISTOGRAM], } // [(); N] doesn't implement Default. -impl Default for VariantDecoder<'_, T, N, C_STYLE> { +impl Default + for VariantDecoder<'_, T, N, HISTOGRAM> +{ fn default() -> Self { Self { variants: Default::default(), @@ -48,15 +51,16 @@ impl Default for VariantDecoder<'_, } } -// C style enums don't require length, so we can skip making a histogram for them. -impl<'a, T: Int, const N: usize> VariantDecoder<'a, T, N, false> { +// C style enums (`HISTOGRAM` = 0) don't require length, so we +// can skip making a histogram for them. +impl<'a, T: Int, const N: usize> VariantDecoder<'a, T, N, N> { pub fn length(&self, variant_index: u8) -> usize { self.histogram[variant_index as usize] } } -impl<'a, T: Int + Into, const N: usize, const C_STYLE: bool> View<'a> - for VariantDecoder<'a, T, N, C_STYLE> +impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> View<'a> + for VariantDecoder<'a, T, N, HISTOGRAM> { fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { assert!(N >= 2); @@ -86,18 +90,14 @@ impl<'a, T: Int + Into, const N: usize, const C_STYLE: bool> View<'a> check_less_than::(unsafe { self.variants.as_slice(length) })?; } else { let out = self.variants.cast_mut::(); - if C_STYLE { - unpack_bytes_less_than::(input, length, out)?; - } else { - self.histogram = unpack_bytes_less_than::(input, length, out)?; - } + self.histogram = unpack_bytes_less_than::(input, length, out)?; } Ok(()) } } -impl<'a, T: Int + Into, const N: usize, const C_STYLE: bool> Decoder<'a, T> - for VariantDecoder<'a, T, N, C_STYLE> +impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> Decoder<'a, T> + for VariantDecoder<'a, T, N, HISTOGRAM> { // Guaranteed to output numbers less than N. #[inline(always)] From 608233b708182e1cd7bdefd8df8247e8743c81de Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Thu, 28 Aug 2025 15:19:25 -0700 Subject: [PATCH 08/11] No std. --- src/derive/variant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/derive/variant.rs b/src/derive/variant.rs index cafde8a..4f6aead 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -72,7 +72,7 @@ impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> View<'a> fn check_less_than, const N: usize>( unpacked: &[T::Une], ) -> Result<()> { - if 2u64.pow(std::mem::size_of::() as u32 * 8) - 1 > N as u64 + if 2u64.pow(core::mem::size_of::() as u32 * 8) - 1 > N as u64 && unpacked .iter() .copied() From 5aad8d9db474a3ee75066816812403fcd6f25a08 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Thu, 28 Aug 2025 15:33:09 -0700 Subject: [PATCH 09/11] Fix new test. --- src/derive/variant.rs | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 4f6aead..630c855 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -137,13 +137,6 @@ mod tests { assert!(matches!(decode(&encode(&Enum1::F)), Ok(Enum1::F))); } - #[allow(unused)] - #[test] - fn test_large_c_style_enum() { - assert!(matches!(decode(&encode(&Enum300::V42)), Ok(Enum300::V42))); - assert!(matches!(decode(&encode(&Enum300::V300)), Ok(Enum300::V300))); - } - #[allow(unused)] #[test] fn test_rust_style_enum() { @@ -182,6 +175,12 @@ mod tests { .collect() } crate::bench_encode_decode!(bool_enum_vec: Vec<_>); +} + +#[cfg(test)] +mod test2 { + use crate::{decode, encode, Decode, Encode}; + use alloc::vec::Vec; #[cfg_attr(not(test), rustfmt::skip)] #[derive(Encode, Decode, Debug, PartialEq)] @@ -217,11 +216,13 @@ mod tests { V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, } -} -#[cfg(test)] -mod test2 { - use crate::derive::variant::tests::Enum300; + #[allow(unused)] + #[test] + fn test_large_c_style_enum() { + assert!(matches!(decode(&encode(&Enum300::V42)), Ok(Enum300::V42))); + assert!(matches!(decode(&encode(&Enum300::V300)), Ok(Enum300::V300))); + } fn bench_data() -> Vec { crate::random_data(1000) From 2c1414474f73ace09975d995f016789155a7faae Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Thu, 28 Aug 2025 15:45:51 -0700 Subject: [PATCH 10/11] Off by one error. --- src/derive/variant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 630c855..8389b0b 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -72,7 +72,7 @@ impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> View<'a> fn check_less_than, const N: usize>( unpacked: &[T::Une], ) -> Result<()> { - if 2u64.pow(core::mem::size_of::() as u32 * 8) - 1 > N as u64 + if 2u64.pow(core::mem::size_of::() as u32 * 8) > N as u64 && unpacked .iter() .copied() From a4603205a08239965198e10b86ac5e55a67338b9 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Thu, 28 Aug 2025 15:57:39 -0700 Subject: [PATCH 11/11] Code cleanup. --- bitcode_derive/src/shared.rs | 13 ++++++++----- src/derive/variant.rs | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index a3b92cd..5c123ae 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -69,11 +69,14 @@ impl VariantIndex { impl ToTokens for VariantIndex { fn to_tokens(&self, tokens: &mut TokenStream) { use quote::TokenStreamExt; - tokens.append(match self { - Self::U8 => Ident::new("u8", Span::call_site()), - Self::U16 => Ident::new("u16", Span::call_site()), - Self::U32 => Ident::new("u32", Span::call_site()), - }); + tokens.append(Ident::new( + match self { + Self::U8 => "u8", + Self::U16 => "u16", + Self::U32 => "u32", + }, + Span::call_site(), + )); } } diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 8389b0b..c2c230d 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -102,7 +102,7 @@ impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> Decoder<' // Guaranteed to output numbers less than N. #[inline(always)] fn decode(&mut self) -> T { - bytemuck::must_cast(unsafe { self.variants.mut_slice().next_unchecked() }) + T::from_unaligned(unsafe { self.variants.mut_slice().next_unchecked() }) } }