Skip to content

Commit

Permalink
GH-111693: Propagate correct asyncio.CancelledError instance out of a…
Browse files Browse the repository at this point in the history
…syncio.Condition.wait() (#111694)

Also fix a race condition in `asyncio.Semaphore.acquire()` when cancelled.
  • Loading branch information
kristjanvalur authored Jan 8, 2024
1 parent c6ca562 commit 5216178
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 25 deletions.
3 changes: 0 additions & 3 deletions Lib/asyncio/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,6 @@ def _make_cancelled_error(self):
exc = exceptions.CancelledError()
else:
exc = exceptions.CancelledError(self._cancel_message)
exc.__context__ = self._cancelled_exc
# Remove the reference since we don't need this anymore.
self._cancelled_exc = None
return exc

def cancel(self, msg=None):
Expand Down
61 changes: 39 additions & 22 deletions Lib/asyncio/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ async def acquire(self):
This method blocks until the lock is unlocked, then sets it to
locked and returns True.
"""
# Implement fair scheduling, where thread always waits
# its turn. Jumping the queue if all are cancelled is an optimization.
if (not self._locked and (self._waiters is None or
all(w.cancelled() for w in self._waiters))):
self._locked = True
Expand All @@ -105,19 +107,22 @@ async def acquire(self):
fut = self._get_loop().create_future()
self._waiters.append(fut)

# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
# Currently the only exception designed be able to occur here.

# Ensure the lock invariant: If lock is not claimed (or about
# to be claimed by us) and there is a Task in waiters,
# ensure that the Task at the head will run.
if not self._locked:
self._wake_up_first()
raise

# assert self._locked is False
self._locked = True
return True

Expand All @@ -139,17 +144,15 @@ def release(self):
raise RuntimeError('Lock is not acquired.')

def _wake_up_first(self):
"""Wake up the first waiter if it isn't done."""
"""Ensure that the first waiter will wake up."""
if not self._waiters:
return
try:
fut = next(iter(self._waiters))
except StopIteration:
return

# .done() necessarily means that a waiter will wake up later on and
# either take the lock, or, if it was cancelled and lock wasn't
# taken already, will hit this again and wake up a new waiter.
# .done() means that the waiter is already set to wake up.
if not fut.done():
fut.set_result(True)

Expand Down Expand Up @@ -269,17 +272,22 @@ async def wait(self):
self._waiters.remove(fut)

finally:
# Must reacquire lock even if wait is cancelled
cancelled = False
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except exceptions.CancelledError:
cancelled = True
except exceptions.CancelledError as e:
err = e

if cancelled:
raise exceptions.CancelledError
if err:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.

async def wait_for(self, predicate):
"""Wait until a predicate becomes true.
Expand Down Expand Up @@ -357,6 +365,7 @@ def __repr__(self):

def locked(self):
"""Returns True if semaphore cannot be acquired immediately."""
# Due to state, or FIFO rules (must allow others to run first).
return self._value == 0 or (
any(not w.cancelled() for w in (self._waiters or ())))

Expand All @@ -370,6 +379,7 @@ async def acquire(self):
True.
"""
if not self.locked():
# Maintain FIFO, wait for others to start even if _value > 0.
self._value -= 1
return True

Expand All @@ -378,22 +388,27 @@ async def acquire(self):
fut = self._get_loop().create_future()
self._waiters.append(fut)

# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
if not fut.cancelled():
# Currently the only exception designed be able to occur here.
if fut.done() and not fut.cancelled():
# Our Future was successfully set to True via _wake_up_next(),
# but we are not about to successfully acquire(). Therefore we
# must undo the bookkeeping already done and attempt to wake
# up someone else.
self._value += 1
self._wake_up_next()
raise

if self._value > 0:
self._wake_up_next()
finally:
# New waiters may have arrived but had to wait due to FIFO.
# Wake up as many as are allowed.
while self._value > 0:
if not self._wake_up_next():
break # There was no-one to wake up.
return True

def release(self):
Expand All @@ -408,13 +423,15 @@ def release(self):
def _wake_up_next(self):
"""Wake up the first waiter that isn't done."""
if not self._waiters:
return
return False

for fut in self._waiters:
if not fut.done():
self._value -= 1
fut.set_result(True)
return
# `fut` is now `done()` and not `cancelled()`.
return True
return False


class BoundedSemaphore(Semaphore):
Expand Down
113 changes: 113 additions & 0 deletions Lib/test/test_asyncio/test_locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,63 @@ async def test_timeout_in_block(self):
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)

async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = asyncio.Condition()

async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised

task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo")
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)

async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-aquiring lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = asyncio.Condition()

async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised

task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo")
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)

class SemaphoreTests(unittest.IsolatedAsyncioTestCase):

Expand Down Expand Up @@ -1044,6 +1101,62 @@ async def c3(result):
await asyncio.gather(*tasks, return_exceptions=True)
self.assertEqual([2, 3], result)

async def test_acquire_fifo_order_4(self):
# Test that a successfule `acquire()` will wake up multiple Tasks
# that were waiting in the Semaphore queue due to FIFO rules.
sem = asyncio.Semaphore(0)
result = []
count = 0

async def c1(result):
# First task immediatlly waits for semaphore. It will be awoken by c2.
self.assertEqual(sem._value, 0)
await sem.acquire()
# We should have woken up all waiting tasks now.
self.assertEqual(sem._value, 0)
# Create a fourth task. It should run after c3, not c2.
nonlocal t4
t4 = asyncio.create_task(c4(result))
result.append(1)
return True

async def c2(result):
# The second task begins by releasing semaphore three times,
# for c1, c2, and c3.
sem.release()
sem.release()
sem.release()
self.assertEqual(sem._value, 2)
# It is locked, because c1 hasn't woken up yet.
self.assertTrue(sem.locked())
await sem.acquire()
result.append(2)
return True

async def c3(result):
await sem.acquire()
self.assertTrue(sem.locked())
result.append(3)
return True

async def c4(result):
result.append(4)
return True

t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
t4 = None

await asyncio.sleep(0)
# Three tasks are in the queue, the first hasn't woken up yet.
self.assertEqual(sem._value, 2)
self.assertEqual(len(sem._waiters), 3)
await asyncio.sleep(0)

tasks = [t1, t2, t3, t4]
await asyncio.gather(*tasks)
self.assertEqual([1, 2, 3, 4], result)

class BarrierTests(unittest.IsolatedAsyncioTestCase):

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:func:`asyncio.Condition.wait()` now re-raises the same :exc:`CancelledError` instance that may have caused it to be interrupted. Fixed race condition in :func:`asyncio.Semaphore.aquire` when interrupted with a :exc:`CancelledError`.

0 comments on commit 5216178

Please sign in to comment.