use crate::task::{waker_ref, ArcWake};
use futures_core::future::{FusedFuture, Future};
use futures_core::task::{Context, Poll, Waker};
use slab::Slab;
use std::cell::UnsafeCell;
use std::fmt;
use std::hash::Hasher;
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Acquire, SeqCst};
use std::sync::{Arc, Mutex, Weak};
/// A cloneable handle to a single underlying future whose output is handed
/// (by move or by clone) to every handle that polls it to completion.
///
/// All clones refer to the same `Inner` allocation through an `Arc`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Shared<Fut: Future> {
// Shared state. `poll` takes this out and only puts it back when it returns
// `Poll::Pending`, so it is `None` once this handle has produced its output.
inner: Option<Arc<Inner<Fut>>>,
// Index of this handle's slot in the notifier's waker slab, or
// `NULL_WAKER_KEY` while no slot has been allocated for this handle.
waker_key: usize,
}
/// State shared by every `Shared` handle cloned from the same origin.
struct Inner<Fut: Future> {
// Holds the wrapped future until it completes, then its output. Access is
// coordinated through `notifier.state` rather than through a lock.
future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
// Polling state plus the wakers of every handle waiting on the future.
// Kept in its own `Arc` so it can be turned into a waker (see `waker_ref`
// in `poll`).
notifier: Arc<Notifier>,
}
/// Bookkeeping used both as the polling "lock" and as the waker handed to the
/// wrapped future.
struct Notifier {
// One of `IDLE`, `POLLING`, `COMPLETE` or `POISONED`.
state: AtomicUsize,
// One slot per registered handle; a slot is `None` after its waker has been
// consumed by a wake-up. The outer `Option` becomes `None` once the future
// has completed and all waiters have been woken.
wakers: Mutex<Option<Slab<Option<Waker>>>>,
}
/// A non-owning reference to the state behind a [`Shared`]; it wraps a
/// `Weak` pointer and can be turned back into a `Shared` with `upgrade`
/// while the shared state is still alive.
pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>);
impl<Fut: Future> Clone for WeakShared<Fut> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
// The wrapped future lives inside the `Arc<Inner<_>>` allocation and is polled
// in place there, so moving a `Shared` handle never moves the future itself.
impl<Fut: Future> Unpin for Shared<Fut> {}
impl<Fut: Future> fmt::Debug for Shared<Fut> {
    /// Formats the handle as a struct named `Shared` showing both of its fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Shared");
        builder.field("inner", &self.inner);
        builder.field("waker_key", &self.waker_key);
        builder.finish()
    }
}
impl<Fut: Future> fmt::Debug for Inner<Fut> {
    /// Prints only the type name; none of the fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Inner");
        builder.finish()
    }
}
impl<Fut: Future> fmt::Debug for WeakShared<Fut> {
    /// Prints only the type name; the weak pointer itself is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WeakShared");
        builder.finish()
    }
}
/// Contents of the shared cell: the still-running future, or the value it
/// resolved to.
enum FutureOrOutput<Fut: Future> {
// The wrapped future; present until it returns `Poll::Ready`.
Future(Fut),
// The resolved value, written by `poll` before the state becomes `COMPLETE`.
Output(Fut::Output),
}
// SAFETY (rationale as far as this file shows): the `UnsafeCell` suppresses the
// auto traits, so they are asserted manually. The future is only ever polled by
// the handle that moved `state` from `IDLE` to `POLLING`, so `Fut: Send` is
// required but `Fut: Sync` is not. The output can be read through shared
// references from several handles (`peek`, `take_or_clone_output`) and moved
// out by whichever handle holds the last `Arc`, hence `Fut::Output: Send + Sync`.
unsafe impl<Fut> Send for Inner<Fut>
where
Fut: Future + Send,
Fut::Output: Send + Sync,
{
}
// SAFETY (rationale as far as this file shows): shared access to the cell is
// mediated by `notifier.state`. The future is mutated only by the handle that
// moved `state` from `IDLE` to `POLLING`, and after `COMPLETE` the cell is only
// read through `&Fut::Output`, which is why the output must be `Sync` (and
// `Send`, since it may be moved out on another thread).
unsafe impl<Fut> Sync for Inner<Fut>
where
Fut: Future + Send,
Fut::Output: Send + Sync,
{
}
/// No handle is currently polling the wrapped future.
const IDLE: usize = 0;
/// Exactly one handle is polling the wrapped future right now.
const POLLING: usize = 1;
/// The wrapped future has finished and its output is stored in the cell.
const COMPLETE: usize = 2;
/// The wrapped future panicked while it was being polled.
const POISONED: usize = 3;

/// Sentinel stored in `Shared::waker_key` while the handle has no slab slot.
///
/// Uses the `usize::MAX` associated constant rather than the legacy
/// `usize::max_value()` function; the value is identical.
const NULL_WAKER_KEY: usize = usize::MAX;
impl<Fut: Future> Shared<Fut> {
    /// Wraps `future` in freshly allocated shared state.
    ///
    /// The state starts out `IDLE` with an empty waker slab, and the returned
    /// handle has no waker slot yet.
    pub(super) fn new(future: Fut) -> Self {
        let future_or_output = UnsafeCell::new(FutureOrOutput::Future(future));
        let notifier = Arc::new(Notifier {
            state: AtomicUsize::new(IDLE),
            wakers: Mutex::new(Some(Slab::new())),
        });
        let shared_state = Arc::new(Inner { future_or_output, notifier });
        Self { inner: Some(shared_state), waker_key: NULL_WAKER_KEY }
    }
}
impl<Fut> Shared<Fut>
where
    Fut: Future,
{
    /// Returns a reference to the output if the wrapped future has already
    /// completed and this handle has not yet yielded its output from `poll`.
    ///
    /// Returns `None` while the future is still pending, or once this handle
    /// has given up its reference to the shared state.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped future panicked while being polled.
    pub fn peek(&self) -> Option<&Fut::Output> {
        let inner = self.inner.as_ref()?;
        match inner.notifier.state.load(SeqCst) {
            // SAFETY: the state is `COMPLETE`, which is what `output` requires.
            COMPLETE => Some(unsafe { inner.output() }),
            POISONED => panic!("inner future panicked during poll"),
            _ => None,
        }
    }

    /// Creates a [`WeakShared`] pointing at the same shared state.
    ///
    /// Returns `None` if this handle no longer holds the shared state.
    pub fn downgrade(&self) -> Option<WeakShared<Fut>> {
        self.inner.as_ref().map(|inner| WeakShared(Arc::downgrade(inner)))
    }

    /// Number of strong (`Arc`) references to the shared state, or `None` if
    /// this handle no longer holds it.
    ///
    /// # Safety
    ///
    /// The method itself is safe. The count is only a snapshot: other threads
    /// may change it between this call and any action taken on the result.
    // The `# Safety` heading above documents a usage caveat on a safe function.
    #[allow(clippy::unnecessary_safety_doc)]
    pub fn strong_count(&self) -> Option<usize> {
        let inner = self.inner.as_ref()?;
        Some(Arc::strong_count(inner))
    }

    /// Number of weak references to the shared state, or `None` if this
    /// handle no longer holds it.
    ///
    /// # Safety
    ///
    /// The method itself is safe. The count is only a snapshot: other threads
    /// may change it between this call and any action taken on the result.
    // The `# Safety` heading above documents a usage caveat on a safe function.
    #[allow(clippy::unnecessary_safety_doc)]
    pub fn weak_count(&self) -> Option<usize> {
        let inner = self.inner.as_ref()?;
        Some(Arc::weak_count(inner))
    }

    /// Feeds the identity of the shared state into `state`.
    ///
    /// Writes a `1` tag byte followed by the allocation's address when the
    /// handle still holds the shared state, and a single `0` tag byte otherwise.
    pub fn ptr_hash<H: Hasher>(&self, state: &mut H) {
        if let Some(inner) = self.inner.as_ref() {
            state.write_u8(1);
            ptr::hash(Arc::as_ptr(inner), state);
        } else {
            state.write_u8(0);
        }
    }

    /// Returns `true` when both handles point at the same shared state.
    ///
    /// Returns `false` if either handle no longer holds its shared state.
    pub fn ptr_eq(&self, rhs: &Self) -> bool {
        match (self.inner.as_ref(), rhs.inner.as_ref()) {
            (Some(ours), Some(theirs)) => Arc::ptr_eq(ours, theirs),
            _ => false,
        }
    }
}
impl<Fut> Inner<Fut>
where
    Fut: Future,
{
    /// Borrows the stored output.
    ///
    /// # Safety
    ///
    /// The caller must have observed `self.notifier.state == COMPLETE`, so
    /// that the cell holds the output and is no longer written to.
    unsafe fn output(&self) -> &Fut::Output {
        let contents = &*self.future_or_output.get();
        if let FutureOrOutput::Output(value) = contents {
            value
        } else {
            unreachable!()
        }
    }
}
impl<Fut> Inner<Fut>
where
    Fut: Future,
    Fut::Output: Clone,
{
    /// Stores the waker from `cx` in the slab slot identified by `waker_key`.
    ///
    /// A slot is allocated (and `waker_key` updated) the first time a handle
    /// registers. Afterwards the slot is only overwritten when it is empty or
    /// holds a waker that would not wake the same task. Does nothing once
    /// the slab has been taken away on completion.
    ///
    /// # Panics
    ///
    /// Panics if the waker mutex is poisoned.
    fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
        let mut guard = self.notifier.wakers.lock().unwrap();
        if let Some(slab) = guard.as_mut() {
            let current = cx.waker();
            if *waker_key == NULL_WAKER_KEY {
                *waker_key = slab.insert(Some(current.clone()));
            } else {
                let slot = &mut slab[*waker_key];
                let already_registered =
                    slot.as_ref().map_or(false, |previous| current.will_wake(previous));
                if !already_registered {
                    *slot = Some(current.clone());
                }
            }
            debug_assert!(*waker_key != NULL_WAKER_KEY);
        }
    }

    /// Yields the output, moving it out when `self` is the last strong
    /// reference and cloning it otherwise.
    ///
    /// # Safety
    ///
    /// The caller must have observed `self.notifier.state == COMPLETE`.
    unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
        let still_shared = match Arc::try_unwrap(self) {
            Ok(sole_owner) => {
                return match sole_owner.future_or_output.into_inner() {
                    FutureOrOutput::Output(value) => value,
                    FutureOrOutput::Future(_) => unreachable!(),
                };
            }
            Err(still_shared) => still_shared,
        };
        still_shared.output().clone()
    }
}
impl<Fut> FusedFuture for Shared<Fut>
where
    Fut: Future,
    Fut::Output: Clone,
{
    /// Reports whether this handle has given up its shared state, which
    /// `poll` does when it returns the output.
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}
impl<Fut> Future for Shared<Fut>
where
Fut: Future,
Fut::Output: Clone,
{
type Output = Fut::Output;
/// Drives the wrapped future on behalf of this handle.
///
/// At most one handle polls the wrapped future at a time: the handle that
/// moves `state` from `IDLE` to `POLLING` does the polling, every other
/// handle just registers its waker and returns `Poll::Pending`. Once the
/// future has completed, each handle receives the output (moved out if it
/// holds the last strong reference, cloned otherwise).
///
/// # Panics
///
/// Panics if this handle is polled again after it returned `Poll::Ready`,
/// if the wrapped future panicked during an earlier poll (`POISONED`), or
/// if the waker mutex is poisoned.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
// Take ownership of the `Arc`; it is put back on every `Pending` return.
let inner = this.inner.take().expect("Shared future polled again after completion");
// Fast path: the wrapped future has already completed.
if inner.notifier.state.load(Acquire) == COMPLETE {
// SAFETY: the state was just observed to be `COMPLETE`.
return unsafe { Poll::Ready(inner.take_or_clone_output()) };
}
// Register before trying to claim the polling role, so the waker is in
// place if another handle turns out to be polling.
inner.record_waker(&mut this.waker_key, cx);
// Try to claim the polling role; on failure `unwrap_or_else` yields the
// state that was actually observed.
match inner
.notifier
.state
.compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
.unwrap_or_else(|x| x)
{
IDLE => {
// Claimed: this handle polls the wrapped future below.
}
POLLING => {
// Another handle is polling; our waker is already registered.
this.inner = Some(inner);
return Poll::Pending;
}
COMPLETE => {
// SAFETY: the state was just observed to be `COMPLETE`.
return unsafe { Poll::Ready(inner.take_or_clone_output()) };
}
POISONED => panic!("inner future panicked during poll"),
_ => unreachable!(),
}
// The wrapped future is polled with the notifier as its waker, so a wake-up
// fans out to every registered handle (see `Notifier::wake_by_ref`).
let waker = waker_ref(&inner.notifier);
let mut cx = Context::from_waker(&waker);
// Drop guard: if the wrapped future's `poll` unwinds, `did_not_panic` is
// still `false` and the state is marked `POISONED`.
struct Reset<'a> {
state: &'a AtomicUsize,
did_not_panic: bool,
}
impl Drop for Reset<'_> {
fn drop(&mut self) {
if !self.did_not_panic {
self.state.store(POISONED, SeqCst);
}
}
}
let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };
let output = {
// SAFETY: this handle holds the `POLLING` state, so no other handle
// touches the cell. The future is pinned in place inside the `Arc`
// allocation; nothing in this file moves it out of the cell before it
// completes.
let future = unsafe {
match &mut *inner.future_or_output.get() {
FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
_ => unreachable!(),
}
};
let poll_result = future.poll(&mut cx);
// The poll returned normally, so disarm the poison guard.
reset.did_not_panic = true;
match poll_result {
Poll::Pending => {
// Release the polling role so a later poll (by any handle) can claim it.
if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
{
drop(reset);
this.inner = Some(inner);
return Poll::Pending;
} else {
// Only the polling handle changes the state while it is `POLLING`.
unreachable!()
}
}
Poll::Ready(output) => output,
}
};
// SAFETY: still in the `POLLING` state, so this handle has exclusive access
// to the cell. The assignment drops the finished future in place.
unsafe {
*inner.future_or_output.get() = FutureOrOutput::Output(output);
}
// Publish the output; from here on other handles may read it.
inner.notifier.state.store(COMPLETE, SeqCst);
// Take the slab out (leaving `None`) and wake every handle that still has a
// waker registered.
let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
let mut wakers = wakers_guard.take().unwrap();
for waker in wakers.drain().flatten() {
waker.wake();
}
// Both values borrow from `inner`, which is moved in the next statement.
drop(reset); drop(wakers_guard);
// SAFETY: the state was set to `COMPLETE` above.
unsafe { Poll::Ready(inner.take_or_clone_output()) }
}
}
impl<Fut> Clone for Shared<Fut>
where
Fut: Future,
{
fn clone(&self) -> Self {
Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
}
}
impl<Fut> Drop for Shared<Fut>
where
    Fut: Future,
{
    /// Frees this handle's waker slot, if it has one and the slab still exists.
    ///
    /// A poisoned waker mutex is ignored rather than propagated as a panic.
    fn drop(&mut self) {
        if self.waker_key == NULL_WAKER_KEY {
            return;
        }
        let inner = match self.inner.as_ref() {
            Some(inner) => inner,
            None => return,
        };
        if let Ok(mut guard) = inner.notifier.wakers.lock() {
            if let Some(slab) = guard.as_mut() {
                slab.remove(self.waker_key);
            }
        }
    }
}
impl ArcWake for Notifier {
fn wake_by_ref(arc_self: &Arc<Self>) {
let wakers = &mut *arc_self.wakers.lock().unwrap();
if let Some(wakers) = wakers.as_mut() {
for (_key, opt_waker) in wakers {
if let Some(waker) = opt_waker.take() {
waker.wake();
}
}
}
}
}
impl<Fut: Future> WeakShared<Fut> {
    /// Attempts to obtain a [`Shared`] handle from this weak reference.
    ///
    /// Returns `None` if the shared state has already been dropped. The new
    /// handle starts without a waker slot.
    pub fn upgrade(&self) -> Option<Shared<Fut>> {
        self.0
            .upgrade()
            .map(|inner| Shared { inner: Some(inner), waker_key: NULL_WAKER_KEY })
    }
}