Skip to content

Commit

Permalink
ntsync: Introduce alertable waits.
Browse files Browse the repository at this point in the history
NT waits can optionally be made "alertable". This is a special channel for
thread wakeup that is mildly similar to SIGIO. A thread has an internal single
bit of "alerted" state, and if a thread is alerted while an alertable wait, the
wait will return a special value, consume the "alerted" state, and will not
consume any of its objects.

Alerts are implemented using events; the user-space NT emulator is expected to
create an internal ntsync event for each thread and pass that event to wait
functions.

Signed-off-by: Elizabeth Figura <[email protected]>
Signed-off-by: Alexandre Frade <[email protected]>
  • Loading branch information
Elizabeth Figura authored and xanmod committed Sep 16, 2024
1 parent ec13d9b commit 1f2e05a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
70 changes: 61 additions & 9 deletions drivers/misc/ntsync.c
Original file line number Diff line number Diff line change
Expand Up @@ -885,22 +885,29 @@ static int setup_wait(struct ntsync_device *dev,
const struct ntsync_wait_args *args, bool all,
struct ntsync_q **ret_q)
{
int fds[NTSYNC_MAX_WAIT_COUNT + 1];
const __u32 count = args->count;
int fds[NTSYNC_MAX_WAIT_COUNT];
struct ntsync_q *q;
__u32 total_count;
__u32 i, j;

if (args->pad[0] || args->pad[1] || (args->flags & ~NTSYNC_WAIT_REALTIME))
if (args->pad || (args->flags & ~NTSYNC_WAIT_REALTIME))
return -EINVAL;

if (args->count > NTSYNC_MAX_WAIT_COUNT)
return -EINVAL;

total_count = count;
if (args->alert)
total_count++;

if (copy_from_user(fds, u64_to_user_ptr(args->objs),
array_size(count, sizeof(*fds))))
return -EFAULT;
if (args->alert)
fds[count] = args->alert;

q = kmalloc(struct_size(q, entries, count), GFP_KERNEL);
q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL);
if (!q)
return -ENOMEM;
q->task = current;
Expand All @@ -910,7 +917,7 @@ static int setup_wait(struct ntsync_device *dev,
q->ownerdead = false;
q->count = count;

for (i = 0; i < count; i++) {
for (i = 0; i < total_count; i++) {
struct ntsync_q_entry *entry = &q->entries[i];
struct ntsync_obj *obj = get_obj(dev, fds[i]);

Expand Down Expand Up @@ -960,10 +967,10 @@ static void try_wake_any_obj(struct ntsync_obj *obj)
static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
{
struct ntsync_wait_args args;
__u32 i, total_count;
struct ntsync_q *q;
int signaled;
bool all;
__u32 i;
int ret;

if (copy_from_user(&args, argp, sizeof(args)))
Expand All @@ -973,9 +980,13 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
if (ret < 0)
return ret;

total_count = args.count;
if (args.alert)
total_count++;

/* queue ourselves */

for (i = 0; i < args.count; i++) {
for (i = 0; i < total_count; i++) {
struct ntsync_q_entry *entry = &q->entries[i];
struct ntsync_obj *obj = entry->obj;

Expand All @@ -984,9 +995,15 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
ntsync_unlock_obj(dev, obj, all);
}

/* check if we are already signaled */
/*
* Check if we are already signaled.
*
* Note that the API requires that normal objects are checked before
* the alert event. Hence we queue the alert event last, and check
* objects in order.
*/

for (i = 0; i < args.count; i++) {
for (i = 0; i < total_count; i++) {
struct ntsync_obj *obj = q->entries[i].obj;

if (atomic_read(&q->signaled) != -1)
Expand All @@ -1003,7 +1020,7 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)

/* and finally, unqueue */

for (i = 0; i < args.count; i++) {
for (i = 0; i < total_count; i++) {
struct ntsync_q_entry *entry = &q->entries[i];
struct ntsync_obj *obj = entry->obj;

Expand Down Expand Up @@ -1063,13 +1080,36 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
*/
list_add_tail(&entry->node, &obj->all_waiters);
}
if (args.alert) {
struct ntsync_q_entry *entry = &q->entries[args.count];
struct ntsync_obj *obj = entry->obj;

dev_lock_obj(dev, obj);
list_add_tail(&entry->node, &obj->any_waiters);
dev_unlock_obj(dev, obj);
}

/* check if we are already signaled */

try_wake_all(dev, q, NULL);

mutex_unlock(&dev->wait_all_lock);

/*
* Check if the alert event is signaled, making sure to do so only
* after checking if the other objects are signaled.
*/

if (args.alert) {
struct ntsync_obj *obj = q->entries[args.count].obj;

if (atomic_read(&q->signaled) == -1) {
bool all = ntsync_lock_obj(dev, obj);
try_wake_any_obj(obj);
ntsync_unlock_obj(dev, obj, all);
}
}

/* sleep */

ret = ntsync_schedule(q, &args);
Expand All @@ -1095,6 +1135,18 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)

mutex_unlock(&dev->wait_all_lock);

if (args.alert) {
struct ntsync_q_entry *entry = &q->entries[args.count];
struct ntsync_obj *obj = entry->obj;
bool all;

all = ntsync_lock_obj(dev, obj);
list_del(&entry->node);
ntsync_unlock_obj(dev, obj, all);

put_obj(obj);
}

signaled = atomic_read(&q->signaled);
if (signaled != -1) {
struct ntsync_wait_args __user *user_args = argp;
Expand Down
3 changes: 2 additions & 1 deletion include/uapi/linux/ntsync.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ struct ntsync_wait_args {
__u32 index;
__u32 flags;
__u32 owner;
__u32 pad[2];
__u32 alert;
__u32 pad;
};

#define NTSYNC_MAX_WAIT_COUNT 64
Expand Down

0 comments on commit 1f2e05a

Please sign in to comment.