diff --git a/drivers/android/binder/process.rs b/drivers/android/binder/process.rs index 5f64c59bc3fe..90823aa3e119 100644 --- a/drivers/android/binder/process.rs +++ b/drivers/android/binder/process.rs @@ -1634,7 +1634,7 @@ pub(crate) fn mmap( pub(crate) fn poll( this: ArcBorrow<'_, Process>, file: &File, - table: &mut PollTable, + table: PollTable<'_>, ) -> Result { let thread = this.get_current_thread()?; let (from_proc, mut mask) = thread.poll(file, table); diff --git a/drivers/android/binder/rust_binder.rs b/drivers/android/binder/rust_binder.rs index 2a8aa474fea7..7f59679c9602 100644 --- a/drivers/android/binder/rust_binder.rs +++ b/drivers/android/binder/rust_binder.rs @@ -472,7 +472,7 @@ unsafe impl Sync for AssertSync {} // SAFETY: The caller ensures that the file is valid. let fileref = unsafe { File::from_raw_file(file) }; // SAFETY: The caller ensures that the `PollTable` is valid. - match Process::poll(f, fileref, unsafe { PollTable::from_ptr(wait) }) { + match Process::poll(f, fileref, unsafe { PollTable::from_raw(wait) }) { Ok(v) => v, Err(_) => bindings::POLLERR, } diff --git a/drivers/android/binder/thread.rs b/drivers/android/binder/thread.rs index 8b48eb3cb129..de9138d3bf74 100644 --- a/drivers/android/binder/thread.rs +++ b/drivers/android/binder/thread.rs @@ -1614,7 +1614,7 @@ pub(crate) fn write_read(self: &Arc, data: UserSlice, wait: bool) -> Resul ret } - pub(crate) fn poll(&self, file: &File, table: &mut PollTable) -> (bool, u32) { + pub(crate) fn poll(&self, file: &File, table: PollTable<'_>) -> (bool, u32) { table.register_wait(file, &self.work_condvar); let mut inner = self.inner.lock(); (inner.should_use_process_work_queue(), inner.poll()) diff --git a/rust/helpers/helpers.c b/rust/helpers/helpers.c index bf2bea15e227..41cde887a9ba 100644 --- a/rust/helpers/helpers.c +++ b/rust/helpers/helpers.c @@ -21,6 +21,7 @@ #include "mman.c" #include "mutex.c" #include "page.c" +#include "poll.c" #include "rbtree.c" #include "refcount.c" #include "security.c" diff --git a/rust/helpers/poll.c b/rust/helpers/poll.c new file mode 100644 index 000000000000..7e5b1751c2d5 --- /dev/null +++ b/rust/helpers/poll.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include + +void rust_helper_poll_wait(struct file *filp, wait_queue_head_t *wait_address, + poll_table *p) +{ + poll_wait(filp, wait_address, p); +} diff --git a/rust/kernel/sync/poll.rs b/rust/kernel/sync/poll.rs index d5f17153b424..729d7e093d85 100644 --- a/rust/kernel/sync/poll.rs +++ b/rust/kernel/sync/poll.rs @@ -9,9 +9,8 @@ fs::File, prelude::*, sync::{CondVar, LockClassKey}, - types::Opaque, }; -use core::ops::Deref; +use core::{marker::PhantomData, ops::Deref}; /// Creates a [`PollCondVar`] initialiser with the given name and a newly-created lock class. #[macro_export] @@ -23,58 +22,40 @@ macro_rules! new_poll_condvar { }; } -/// Wraps the kernel's `struct poll_table`. +/// Wraps the kernel's `poll_table`. /// /// # Invariants /// -/// This struct contains a valid `struct poll_table`. -/// -/// For a `struct poll_table` to be valid, its `_qproc` function must follow the safety -/// requirements of `_qproc` functions: -/// -/// * The `_qproc` function is given permission to enqueue a waiter to the provided `poll_table` -/// during the call. Once the waiter is removed and an rcu grace period has passed, it must no -/// longer access the `wait_queue_head`. +/// The pointer must be null or reference a valid `poll_table`. #[repr(transparent)] -pub struct PollTable(Opaque); +pub struct PollTable<'a> { + table: *mut bindings::poll_table, + _lifetime: PhantomData<&'a bindings::poll_table>, +} -impl PollTable { - /// Creates a reference to a [`PollTable`] from a valid pointer. +impl<'a> PollTable<'a> { + /// Creates a [`PollTable`] from a valid pointer. /// /// # Safety /// - /// The caller must ensure that for the duration of 'a, the pointer will point at a valid poll - /// table (as defined in the type invariants). - /// - /// The caller must also ensure that the `poll_table` is only accessed via the returned - /// reference for the duration of 'a. - pub unsafe fn from_ptr<'a>(ptr: *mut bindings::poll_table) -> &'a mut PollTable { - // SAFETY: The safety requirements guarantee the validity of the dereference, while the - // `PollTable` type being transparent makes the cast ok. - unsafe { &mut *ptr.cast() } - } - - fn get_qproc(&self) -> bindings::poll_queue_proc { - let ptr = self.0.get(); - // SAFETY: The `ptr` is valid because it originates from a reference, and the `_qproc` - // field is not modified concurrently with this call since we have an immutable reference. - unsafe { (*ptr)._qproc } + /// The pointer must be null or reference a valid `poll_table` for the duration of `'a`. + pub unsafe fn from_raw(table: *mut bindings::poll_table) -> Self { + // INVARIANTS: The safety requirements are the same as the struct invariants. + PollTable { + table, + _lifetime: PhantomData, + } } /// Register this [`PollTable`] with the provided [`PollCondVar`], so that it can be notified /// using the condition variable. - pub fn register_wait(&mut self, file: &File, cv: &PollCondVar) { - if let Some(qproc) = self.get_qproc() { - // SAFETY: The pointers to `file` and `self` need to be valid for the duration of this - // call to `qproc`, which they are because they are references. - // - // The `cv.wait_queue_head` pointer must be valid until an rcu grace period after the - // waiter is removed. The `PollCondVar` is pinned, so before `cv.wait_queue_head` can - // be destroyed, the destructor must run. That destructor first removes all waiters, - // and then waits for an rcu grace period. Therefore, `cv.wait_queue_head` is valid for - // long enough. - unsafe { qproc(file.as_ptr() as _, cv.wait_queue_head.get(), self.0.get()) }; - } + pub fn register_wait(&self, file: &File, cv: &PollCondVar) { + // SAFETY: The pointers `self.table` and `file` are valid for the duration of this call. + // The `cv.wait_queue_head` pointer must be valid until an rcu grace period after the + // waiter is removed. The `PollCondVar` is pinned, so before `cv.wait_queue_head` can be + // destroyed, the destructor must run. That destructor first removes all waiters, and then + // waits for an rcu grace period. Therefore, `cv.wait_queue_head` is valid for long enough. + unsafe { bindings::poll_wait(file.as_ptr(), cv.wait_queue_head.get(), self.table) } } }