diff --git a/src/lib.rs b/src/lib.rs index 8d9801e..7d43048 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,8 +75,22 @@ use alloc::boxed::Box; /// it also respects any alignment requirements for the wrapped future. Note that the /// wrapped future's alignment must be less than or equal to that of the overall /// `StackFuture` struct. +// NOTE: we use a type alias rather than a default const generic argument, as that would make methods +// like StackFuture::new ambiguous when calling. +pub type StackFuture<'a, T, const STACK_SIZE: usize> = StackFutureImpl<'a, T, STACK_SIZE, true>; + +/// A variant of [`StackFuture`] which allows for futures that do not implement the [`Send`] trait. +/// +/// See the documentation of `StackFuture` for more information. +pub type LocalStackFuture<'a, T, const STACK_SIZE: usize> = StackFutureImpl<'a, T, STACK_SIZE, false>; + +/// A variant of [`StackFuture`] which supports either [`Send`] ofr non-`Send` futures, depending +/// on the value of the `SEND` const generic argument. +/// +/// In most cases, you will want to use `StackFuture` or [`LocalStackFuture`] directly. +/// See the documentation for [`StackFuture`] for more details. #[repr(C)] // Ensures the data first does not have any padding before it in the struct -pub struct StackFuture<'a, T, const STACK_SIZE: usize> { +pub struct StackFutureImpl<'a, T, const STACK_SIZE: usize, const SEND: bool> { /// An array of bytes that is used to store the wrapped future. data: [MaybeUninit; STACK_SIZE], /// Since the type of `StackFuture` does not know the underlying future that it is wrapping, @@ -84,17 +98,30 @@ pub struct StackFuture<'a, T, const STACK_SIZE: usize> { /// generated and filled in by `StackFuture::from`. /// /// This field stores a pointer to the poll function wrapper. - poll_fn: fn(this: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll, + /// + /// SAFETY: + /// * the argument `this` must be the same instance of this type that `poll_fn` was obtained from. + poll_fn: unsafe fn(this: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll, /// Stores a pointer to the drop function wrapper /// /// See the documentation on `poll_fn` for more details. - drop_fn: fn(this: &mut Self), + /// + /// SAFETY: + /// * the argument `this` must be the same instance of this type that `drop_fn` was obtained from. + /// * must only be called from the Drop impl of this type. + drop_fn: unsafe fn(this: &mut Self), /// StackFuture can be used similarly to a `dyn Future`. We keep a PhantomData /// here so the type system knows this. - _phantom: PhantomData + Send + 'a>, + _phantom: PhantomData + 'a>, } -impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { +// SAFETY: +// We ensure by the API exposed for this type that the contained future will always be Send +// as long as the `SEND` const generic arg is true. +unsafe impl<'a, T, const STACK_SIZE: usize> Send for StackFutureImpl<'a, T, STACK_SIZE, true> +{ } + +impl<'a, T, const STACK_SIZE: usize> StackFutureImpl<'a, T, { STACK_SIZE }, true> { /// Creates a `StackFuture` from an existing future /// /// See the documentation on [`StackFuture`] for examples of how to use this. @@ -135,13 +162,88 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { /// ``` pub fn from(future: F) -> Self where - F: Future + Send + 'a, // the bounds here should match those in the _phantom field + F: Future + Send + 'a, + { + Self::from_inner(future) + } + + /// Attempts to create a `StackFuture` from an existing future + /// + /// If the `StackFuture` is not large enough to hold `future`, this function returns an + /// `Err` with the argument `future` returned to you. + /// + /// Panics + /// + /// If we cannot satisfy the alignment requirements for `F`, this function will panic. + pub fn try_from(future: F) -> Result> + where + F: Future + Send + 'a, + { + Self::try_from_inner(future) + } + + /// Creates a StackFuture from the given future, boxing if necessary + /// + /// This version will succeed even if the future is larger than `STACK_SIZE`. If the future + /// is too large, `from_or_box` will allocate a `Box` on the heap and store the resulting + /// boxed future in the `StackFuture`. + /// + /// The same thing also happens if the wrapped future's alignment is larger than StackFuture's + /// alignment. + /// + /// This function requires the "alloc" crate feature. + #[cfg(feature = "alloc")] + pub fn from_or_box(future: F) -> Self + where + F: Future + Send + 'a, + { + Self::from_or_box_inner(future) + } +} + +impl<'a, T, const STACK_SIZE: usize> StackFutureImpl<'a, T, STACK_SIZE, false> { + /// Creates a `StackFuture` from an existing future. + /// + /// See the documentation of [`StackFuture::from`] for more details. + pub fn from(future: F) -> Self + where + F: Future + 'a, // the bounds here should match those in the _phantom field + { + Self::from_inner(future) + } + + /// Attempts to create a `StackFuture` from an existing future. + /// + /// See the documentation of [`StackFuture::try_from`] for more details. + pub fn try_from(future: F) -> Result> + where + F: Future + 'a, // the bounds here should match those in the _phantom field + { + Self::try_from_inner(future) + } + + /// Creates a StackFuture from the given future, boxing if necessary + /// + /// See the documentation of [`StackFuture::from_or_box`] for more details. + #[cfg(feature = "alloc")] + pub fn from_or_box(future: F) -> Self + where + F: Future + 'a, // the bounds here should match those in the _phantom field + { + Self::from_or_box_inner(future) + } +} + +impl<'a, T, const STACK_SIZE: usize, const SEND: bool> StackFutureImpl<'a, T, STACK_SIZE, SEND> { + fn from_inner(future: F) -> Self + where + F: Future + 'a, // the bounds here should match those in the _phantom field { // Ideally we would provide this as: // // impl<'a, F, const STACK_SIZE: usize> From for StackFuture<'a, F::Output, { STACK_SIZE }> // where - // F: Future + Send + 'a + // F: Future + 'a // // However, libcore provides a blanket `impl From for T`, and since `StackFuture: Future`, // both impls end up being applicable to do `From for StackFuture`. @@ -150,25 +252,26 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { #[allow(clippy::let_unit_value)] let _ = AssertFits::::ASSERT; - Self::try_from(future).unwrap() + Self::try_from_inner(future).unwrap() } - /// Attempts to create a `StackFuture` from an existing future - /// - /// If the `StackFuture` is not large enough to hold `future`, this function returns an - /// `Err` with the argument `future` returned to you. - /// - /// Panics - /// - /// If we cannot satisfy the alignment requirements for `F`, this function will panic. - pub fn try_from(future: F) -> Result> + fn try_from_inner(future: F) -> Result> where - F: Future + Send + 'a, // the bounds here should match those in the _phantom field + F: Future + 'a, // the bounds here should match those in the _phantom field { if Self::has_space_for_val(&future) && Self::has_alignment_for_val(&future) { - let mut result = StackFuture { + let mut result = Self { data: [MaybeUninit::uninit(); STACK_SIZE], + // SAFETY: + // `poll_inner` and `drop_inner` both require `F` to match the future type + // used to construct `self` here. The invariants on the `poll_fn` and `drop_fn` + // fields require that they are only called using the original instance they were + // obtained from, which ensures that `F` will still match at the point + // they are called. poll_fn: Self::poll_inner::, + // SAFETY: + // the invariants on `drop_fn` transitively uphold the requirement that `drop_inner` + // is only called from the Drop impl of this type. drop_fn: Self::drop_inner::, _phantom: PhantomData, }; @@ -192,34 +295,35 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { } } - /// Creates a StackFuture from the given future, boxing if necessary - /// - /// This version will succeed even if the future is larger than `STACK_SIZE`. If the future - /// is too large, `from_or_box` will allocate a `Box` on the heap and store the resulting - /// boxed future in the `StackFuture`. - /// - /// The same thing also happens if the wrapped future's alignment is larger than StackFuture's - /// alignment. - /// - /// This function requires the "alloc" crate feature. #[cfg(feature = "alloc")] - pub fn from_or_box(future: F) -> Self + fn from_or_box_inner(future: F) -> Self where - F: Future + Send + 'a, // the bounds here should match those in the _phantom field + F: Future + 'a, // the bounds here should match those in the _phantom field { - Self::try_from(future).unwrap_or_else(|err| Self::from(Box::pin(err.into_inner()))) + Self::try_from_inner(future).unwrap_or_else(|err| Self::from_inner(Box::pin(err.into_inner()))) } /// A wrapper around the inner future's poll function, which we store in the poll_fn field /// of this struct. - fn poll_inner(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.as_pin_mut_ref::().poll(cx) + /// + /// SAFETY: + /// * the generic argument `F` must be the exact same type originally used to construct + /// this instance via `try_from_inner`. + unsafe fn poll_inner(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + unsafe { self.as_pin_mut_ref::() }.poll(cx) } /// A wrapper around the inner future's drop function, which we store in the drop_fn field /// of this struct. - fn drop_inner(&mut self) { - // SAFETY: *this.as_mut_ptr() was previously written as type F + /// + /// SAFETY: + /// * the generic argument `F` must be the exact same type originally used to construct + /// this instance via `try_from_inner`. + /// * must only be called from the drop impl of this type. + unsafe fn drop_inner(&mut self) { + // SAFETY: + // * this.as_mut_ptr() was previously written as type F + // * caller ensures this will only be called from the drop impl of this type. unsafe { ptr::drop_in_place(self.as_mut_ptr::()) } } @@ -236,8 +340,12 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { } /// Returns a pinned mutable reference to a type F stored in self.data - fn as_pin_mut_ref(self: Pin<&mut Self>) -> Pin<&mut F> { - // SAFETY: `StackFuture` is only created by `StackFuture::from`, which + /// + /// SAFETY: + /// * the generic argument `F` must be the exact same type originally used to construct + /// this instance via `try_from_inner`. + unsafe fn as_pin_mut_ref(self: Pin<&mut Self>) -> Pin<&mut F> { + // SAFETY: `StackFuture` is only created by `StackFuture::try_from_inner`, which // writes an `F` to `self.as_mut_ptr(), so it's okay to cast the `*mut F` // to an `&mut F` with the same lifetime as `self`. // @@ -277,26 +385,27 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { } } -impl<'a, T, const STACK_SIZE: usize> Future for StackFuture<'a, T, { STACK_SIZE }> { +impl<'a, T, const STACK_SIZE: usize, const SEND: bool> Future for StackFutureImpl<'a, T, { STACK_SIZE }, SEND> { type Output = T; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // SAFETY: This is doing pin projection. We unpin self so we can - // access self.poll_fn, and then re-pin self to pass it into poll_in. - // The part of the struct that needs to be pinned is data, since it - // contains a potentially self-referential future object, but since we - // do not touch that while self is unpinned and we do not move self - // while unpinned we are okay. - unsafe { - let this = self.get_unchecked_mut(); - (this.poll_fn)(Pin::new_unchecked(this), cx) - } + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // SAFETY: + // `self.poll_fn`` is private and never reassigned, so this must be the same instance + // that `poll_fn` was originally obtained from. + unsafe { (self.poll_fn)(self, cx) } } } -impl<'a, T, const STACK_SIZE: usize> Drop for StackFuture<'a, T, { STACK_SIZE }> { +impl<'a, T, const STACK_SIZE: usize, const SEND: bool> Drop for StackFutureImpl<'a, T, { STACK_SIZE }, SEND> { fn drop(&mut self) { - (self.drop_fn)(self); + // SAFETY: + // * `self.drop_fn`` is private and never reassigned, so this must be the same instance + // that `drop_fn` was originally obtained from. + // * we are calling this from the drop impl of this type, which is the only valid place + // to call `drop_fn`. + unsafe { + (self.drop_fn)(self); + } } }