Skip to content

Commit

Permalink
ntsync: Introduce NTSYNC_IOC_CREATE_MUTEX.
Browse files Browse the repository at this point in the history
This corresponds to the NT syscall NtCreateMutant().

An NT mutex is recursive, with a 32-bit recursion counter. When acquired via
NtWaitForMultipleObjects(), the recursion counter is incremented by one. The OS
records the thread which acquired it.

The OS records the thread which acquired it. However, in order to keep this
driver self-contained, the owning thread ID is managed by user-space, and passed
as a parameter to all relevant ioctls.

The initial owner and recursion count, if any, are specified when the mutex is
created.

Signed-off-by: Elizabeth Figura <[email protected]>
Signed-off-by: Alexandre Frade <[email protected]>
  • Loading branch information
Elizabeth Figura authored and xanmod committed Sep 16, 2024
1 parent 577bd21 commit 2740171
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
77 changes: 74 additions & 3 deletions drivers/misc/ntsync.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

enum ntsync_type {
NTSYNC_TYPE_SEM,
NTSYNC_TYPE_MUTEX,
};

/*
Expand Down Expand Up @@ -55,6 +56,10 @@ struct ntsync_obj {
__u32 count;
__u32 max;
} sem;
struct {
__u32 count;
pid_t owner;
} mutex;
} u;

/*
Expand Down Expand Up @@ -92,6 +97,7 @@ struct ntsync_q_entry {

struct ntsync_q {
struct task_struct *task;
__u32 owner;

/*
* Protected via atomic_try_cmpxchg(). Only the thread that wins the
Expand Down Expand Up @@ -214,13 +220,17 @@ static void ntsync_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj,
((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \
(obj)->dev_locked))

static bool is_signaled(struct ntsync_obj *obj)
static bool is_signaled(struct ntsync_obj *obj, __u32 owner)
{
ntsync_assert_held(obj);

switch (obj->type) {
case NTSYNC_TYPE_SEM:
return !!obj->u.sem.count;
case NTSYNC_TYPE_MUTEX:
if (obj->u.mutex.owner && obj->u.mutex.owner != owner)
return false;
return obj->u.mutex.count < UINT_MAX;
}

WARN(1, "bad object type %#x\n", obj->type);
Expand Down Expand Up @@ -250,7 +260,7 @@ static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
}

for (i = 0; i < count; i++) {
if (!is_signaled(q->entries[i].obj)) {
if (!is_signaled(q->entries[i].obj, q->owner)) {
can_wake = false;
break;
}
Expand All @@ -264,6 +274,10 @@ static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
case NTSYNC_TYPE_SEM:
obj->u.sem.count--;
break;
case NTSYNC_TYPE_MUTEX:
obj->u.mutex.count++;
obj->u.mutex.owner = q->owner;
break;
}
}
wake_up_process(q->task);
Expand Down Expand Up @@ -307,6 +321,30 @@ static void try_wake_any_sem(struct ntsync_obj *sem)
}
}

static void try_wake_any_mutex(struct ntsync_obj *mutex)
{
struct ntsync_q_entry *entry;

ntsync_assert_held(mutex);
lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX);

list_for_each_entry(entry, &mutex->any_waiters, node) {
struct ntsync_q *q = entry->q;
int signaled = -1;

if (mutex->u.mutex.count == UINT_MAX)
break;
if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner)
continue;

if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
mutex->u.mutex.count++;
mutex->u.mutex.owner = q->owner;
wake_up_process(q->task);
}
}
}

/*
* Actually change the semaphore state, returning -EOVERFLOW if it is made
* invalid.
Expand Down Expand Up @@ -455,6 +493,33 @@ static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
return put_user(fd, &user_args->sem);
}

static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp)
{
struct ntsync_mutex_args __user *user_args = argp;
struct ntsync_mutex_args args;
struct ntsync_obj *mutex;
int fd;

if (copy_from_user(&args, argp, sizeof(args)))
return -EFAULT;

if (!args.owner != !args.count)
return -EINVAL;

mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX);
if (!mutex)
return -ENOMEM;
mutex->u.mutex.count = args.count;
mutex->u.mutex.owner = args.owner;
fd = ntsync_obj_get_fd(mutex);
if (fd < 0) {
kfree(mutex);
return fd;
}

return put_user(fd, &user_args->mutex);
}

static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd)
{
struct file *file = fget(fd);
Expand Down Expand Up @@ -524,7 +589,7 @@ static int setup_wait(struct ntsync_device *dev,
struct ntsync_q *q;
__u32 i, j;

if (args->pad[0] || args->pad[1] || args->pad[2] || (args->flags & ~NTSYNC_WAIT_REALTIME))
if (args->pad[0] || args->pad[1] || (args->flags & ~NTSYNC_WAIT_REALTIME))
return -EINVAL;

if (args->count > NTSYNC_MAX_WAIT_COUNT)
Expand All @@ -538,6 +603,7 @@ static int setup_wait(struct ntsync_device *dev,
if (!q)
return -ENOMEM;
q->task = current;
q->owner = args->owner;
atomic_set(&q->signaled, -1);
q->all = all;
q->count = count;
Expand Down Expand Up @@ -580,6 +646,9 @@ static void try_wake_any_obj(struct ntsync_obj *obj)
case NTSYNC_TYPE_SEM:
try_wake_any_sem(obj);
break;
case NTSYNC_TYPE_MUTEX:
try_wake_any_mutex(obj);
break;
}
}

Expand Down Expand Up @@ -769,6 +838,8 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
void __user *argp = (void __user *)parm;

switch (cmd) {
case NTSYNC_IOC_CREATE_MUTEX:
return ntsync_create_mutex(dev, argp);
case NTSYNC_IOC_CREATE_SEM:
return ntsync_create_sem(dev, argp);
case NTSYNC_IOC_WAIT_ALL:
Expand Down
10 changes: 9 additions & 1 deletion include/uapi/linux/ntsync.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ struct ntsync_sem_args {
__u32 max;
};

struct ntsync_mutex_args {
__u32 mutex;
__u32 owner;
__u32 count;
};

#define NTSYNC_WAIT_REALTIME 0x1

struct ntsync_wait_args {
Expand All @@ -24,14 +30,16 @@ struct ntsync_wait_args {
__u32 count;
__u32 index;
__u32 flags;
__u32 pad[3];
__u32 owner;
__u32 pad[2];
};

#define NTSYNC_MAX_WAIT_COUNT 64

#define NTSYNC_IOC_CREATE_SEM _IOWR('N', 0x80, struct ntsync_sem_args)
#define NTSYNC_IOC_WAIT_ANY _IOWR('N', 0x82, struct ntsync_wait_args)
#define NTSYNC_IOC_WAIT_ALL _IOWR('N', 0x83, struct ntsync_wait_args)
#define NTSYNC_IOC_CREATE_MUTEX _IOWR('N', 0x84, struct ntsync_sem_args)

#define NTSYNC_IOC_SEM_POST _IOWR('N', 0x81, __u32)

Expand Down

0 comments on commit 2740171

Please sign in to comment.