Skip to content

Commit

Permalink
Push locking outside of noise state machine (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeff-hiner authored Aug 15, 2022
1 parent 29b99af commit 1466836
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 169 deletions.
1 change: 1 addition & 0 deletions boringtun/src/device/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ fn api_get(writer: &mut BufWriter<&UnixStream>, d: &Device) -> i32 {
}

for (k, p) in d.peers.iter() {
let p = p.lock();
writeln!(writer, "public_key={}", encode_hex(k.as_bytes()));

if let Some(ref key) = p.preshared_key() {
Expand Down
63 changes: 35 additions & 28 deletions boringtun/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use crate::noise::handshake::parse_handshake_anon;
use crate::noise::rate_limiter::RateLimiter;
use crate::noise::{Packet, Tunn, TunnResult};
use allowed_ips::AllowedIps;
use parking_lot::Mutex;
use peer::{AllowedIP, Peer};
use poll::{EventPoll, EventRef, WaitResult};
use tun::{errno, errno_str, TunSocket};
Expand Down Expand Up @@ -128,9 +129,9 @@ pub struct Device {
yield_notice: Option<EventRef>,
exit_notice: Option<EventRef>,

peers: HashMap<x25519_dalek::PublicKey, Arc<Peer>>,
peers_by_ip: AllowedIps<Arc<Peer>>,
peers_by_idx: HashMap<u32, Arc<Peer>>,
peers: HashMap<x25519_dalek::PublicKey, Arc<Mutex<Peer>>>,
peers_by_ip: AllowedIps<Arc<Mutex<Peer>>>,
peers_by_idx: HashMap<u32, Arc<Mutex<Peer>>>,
next_index: u32,

config: DeviceConfig,
Expand Down Expand Up @@ -275,10 +276,13 @@ impl Device {
fn remove_peer(&mut self, pub_key: &x25519_dalek::PublicKey) {
if let Some(peer) = self.peers.remove(pub_key) {
// Found a peer to remove, now purge all references to it:
peer.shutdown_endpoint(); // close open udp socket and free the closure
self.peers_by_idx.remove(&peer.index()); // peers_by_idx
{
let p = peer.lock();
p.shutdown_endpoint(); // close open udp socket and free the closure
self.peers_by_idx.remove(&p.index());
}
self.peers_by_ip
.remove(&|p: &Arc<Peer>| Arc::ptr_eq(&peer, p)); // peers_by_ip
.remove(&|p: &Arc<Mutex<Peer>>| Arc::ptr_eq(&peer, p));

tracing::info!("Peer removed");
}
Expand Down Expand Up @@ -324,7 +328,7 @@ impl Device {

let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, preshared_key);

let peer = Arc::new(peer);
let peer = Arc::new(Mutex::new(peer));
self.peers.insert(pub_key, Arc::clone(&peer));
self.peers_by_idx.insert(next_index, Arc::clone(&peer));

Expand Down Expand Up @@ -408,7 +412,7 @@ impl Device {
}

for peer in self.peers.values() {
peer.shutdown_endpoint();
peer.lock().shutdown_endpoint();
}

// Then open new sockets and bind to the port
Expand Down Expand Up @@ -456,8 +460,7 @@ impl Device {
let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT));

for peer in self.peers.values_mut() {
let peer_mut =
Arc::<Peer>::get_mut(peer).expect("set_key requires other threads to be stopped");
let mut peer_mut = peer.lock();

if peer_mut
.tunnel
Expand All @@ -470,7 +473,7 @@ impl Device {
{
// In case we encounter an error, we will remove that peer
// An error will be a result of bad public key/secret key combination
bad_peers.push(peer);
bad_peers.push(Arc::clone(peer));
}
}

Expand All @@ -497,7 +500,7 @@ impl Device {

// Then on all currently connected sockets
for peer in self.peers.values() {
if let Some(ref sock) = peer.endpoint().conn {
if let Some(ref sock) = peer.lock().endpoint().conn {
sock.set_fwmark(mark)?
}
}
Expand Down Expand Up @@ -550,15 +553,16 @@ impl Device {

// Go over each peer and invoke the timer function
for peer in peer_map.values() {
let endpoint_addr = match peer.endpoint().addr {
let mut p = peer.lock();
let endpoint_addr = match p.endpoint().addr {
Some(addr) => addr,
None => continue,
};

match peer.update_timers(&mut t.dst_buf[..]) {
match p.update_timers(&mut t.dst_buf[..]) {
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
peer.shutdown_endpoint(); // close open udp socket
p.shutdown_endpoint(); // close open udp socket
}
TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e),
TunnResult::WriteToNetwork(packet) => {
Expand Down Expand Up @@ -634,9 +638,11 @@ impl Device {
Some(peer) => peer,
};

let mut p = peer.lock();

// We found a peer, use it to decapsulate the message+
let mut flush = false; // Are there packets to send from the queue?
match peer
match p
.tunnel
.handle_verified_packet(parsed_packet, &mut t.dst_buf[..])
{
Expand All @@ -647,12 +653,12 @@ impl Device {
udp.sendto(packet, addr);
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if peer.is_allowed_ip(addr) {
if p.is_allowed_ip(addr) {
t.iface.write4(packet);
}
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if peer.is_allowed_ip(addr) {
if p.is_allowed_ip(addr) {
t.iface.write6(packet);
}
}
Expand All @@ -661,17 +667,17 @@ impl Device {
if flush {
// Flush pending queue
while let TunnResult::WriteToNetwork(packet) =
peer.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
{
udp.sendto(packet, addr);
}
}

// This packet was OK, that means we want to create a connected socket for this peer
let ip_addr = addr.ip();
peer.set_endpoint(addr);
p.set_endpoint(addr);
if d.config.use_connected_socket {
if let Ok(sock) = peer.connect_endpoint(d.listen_port, d.fwmark) {
if let Ok(sock) = p.connect_endpoint(d.listen_port, d.fwmark) {
d.register_conn_handler(Arc::clone(peer), sock, ip_addr)
.unwrap();
}
Expand All @@ -690,7 +696,7 @@ impl Device {

fn register_conn_handler(
&self,
peer: Arc<Peer>,
peer: Arc<Mutex<Peer>>,
udp: Arc<UDPSocket>,
peer_addr: IpAddr,
) -> Result<(), Error> {
Expand All @@ -705,7 +711,8 @@ impl Device {

while let Ok(src) = udp.read(&mut t.src_buf[..]) {
let mut flush = false;
match peer
let mut p = peer.lock();
match p
.tunnel
.decapsulate(Some(peer_addr), src, &mut t.dst_buf[..])
{
Expand All @@ -716,12 +723,12 @@ impl Device {
udp.write(packet);
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if peer.is_allowed_ip(addr) {
if p.is_allowed_ip(addr) {
iface.write4(packet);
}
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if peer.is_allowed_ip(addr) {
if p.is_allowed_ip(addr) {
iface.write6(packet);
}
}
Expand All @@ -730,7 +737,7 @@ impl Device {
if flush {
// Flush pending queue
while let TunnResult::WriteToNetwork(packet) =
peer.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
{
udp.write(packet);
}
Expand Down Expand Up @@ -785,8 +792,8 @@ impl Device {
None => continue,
};

let peer = match peers.find(dst_addr) {
Some(peer) => peer,
let mut peer = match peers.find(dst_addr) {
Some(peer) => peer.lock(),
None => continue,
};

Expand Down
10 changes: 6 additions & 4 deletions boringtun/src/device/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ pub struct Endpoint {
}

pub struct Peer {
pub(crate) tunnel: Box<Tunn>, // The associated tunnel struct
index: u32, // The index the tunnel uses
/// The associated tunnel struct
pub(crate) tunnel: Tunn,
/// The index the tunnel uses
index: u32,
endpoint: RwLock<Endpoint>,
allowed_ips: AllowedIps<()>,
preshared_key: Option<[u8; 32]>,
Expand Down Expand Up @@ -53,7 +55,7 @@ impl FromStr for AllowedIP {

impl Peer {
pub fn new(
tunnel: Box<Tunn>,
tunnel: Tunn,
index: u32,
endpoint: Option<SocketAddr>,
allowed_ips: &[AllowedIP],
Expand All @@ -71,7 +73,7 @@ impl Peer {
}
}

pub fn update_timers<'a>(&self, dst: &'a mut [u8]) -> TunnResult<'a> {
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
self.tunnel.update_timers(dst)
}

Expand Down
27 changes: 14 additions & 13 deletions boringtun/src/ffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use super::noise::{Tunn, TunnResult};
use base64::{decode, encode};
use hex::encode as encode_hex;
use libc::{raise, SIGSEGV};
use parking_lot::Mutex;
use rand_core::OsRng;
use x25519_dalek::{PublicKey, StaticSecret};

Expand Down Expand Up @@ -168,7 +169,7 @@ pub unsafe extern "C" fn new_tunnel(
preshared_key: *const c_char,
keep_alive: u16,
index: u32,
) -> *mut Tunn {
) -> *mut Mutex<Tunn> {
let c_str = CStr::from_ptr(static_private);
let static_private = match c_str.to_str() {
Err(_) => return ptr::null_mut(),
Expand Down Expand Up @@ -221,7 +222,7 @@ pub unsafe extern "C" fn new_tunnel(
index,
None,
) {
Ok(t) => t,
Ok(t) => Box::new(Mutex::new(t)),
Err(_) => return ptr::null_mut(),
};

Expand All @@ -237,21 +238,21 @@ pub unsafe extern "C" fn new_tunnel(

/// Drops the Tunn object
#[no_mangle]
pub unsafe extern "C" fn tunnel_free(tunnel: *mut Tunn) {
pub unsafe extern "C" fn tunnel_free(tunnel: *mut Mutex<Tunn>) {
Box::from_raw(tunnel);
}

/// Write an IP packet from the tunnel interface.
/// For more details check noise::tunnel_to_network functions.
#[no_mangle]
pub unsafe extern "C" fn wireguard_write(
tunnel: *mut Tunn,
tunnel: *const Mutex<Tunn>,
src: *const u8,
src_size: u32,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let src = slice::from_raw_parts(src, src_size as usize);
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
Expand All @@ -262,13 +263,13 @@ pub unsafe extern "C" fn wireguard_write(
/// For more details check noise::network_to_tunnel functions.
#[no_mangle]
pub unsafe extern "C" fn wireguard_read(
tunnel: *mut Tunn,
tunnel: *const Mutex<Tunn>,
src: *const u8,
src_size: u32,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let src = slice::from_raw_parts(src, src_size as usize);
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
Expand All @@ -279,11 +280,11 @@ pub unsafe extern "C" fn wireguard_read(
/// Recommended interval: 100ms.
#[no_mangle]
pub unsafe extern "C" fn wireguard_tick(
tunnel: *mut Tunn,
tunnel: *const Mutex<Tunn>,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.update_timers(dst))
Expand All @@ -292,11 +293,11 @@ pub unsafe extern "C" fn wireguard_tick(
/// Force the tunnel to initiate a new handshake, dst buffer must be at least 148 byte long.
#[no_mangle]
pub unsafe extern "C" fn wireguard_force_handshake(
tunnel: *mut Tunn,
tunnel: *const Mutex<Tunn>,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let tunnel = tunnel.as_ref().unwrap();
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.format_handshake_initiation(dst, true))
Expand All @@ -307,8 +308,8 @@ pub unsafe extern "C" fn wireguard_force_handshake(
/// Number of data bytes encapsulated
/// Number of data bytes decapsulated
#[no_mangle]
pub unsafe extern "C" fn wireguard_stats(tunnel: *mut Tunn) -> stats {
let tunnel = tunnel.as_ref().unwrap();
pub unsafe extern "C" fn wireguard_stats(tunnel: *const Mutex<Tunn>) -> stats {
let tunnel = tunnel.as_ref().unwrap().lock();
let (time, tx_bytes, rx_bytes, estimated_loss, estimated_rtt) = tunnel.stats();
stats {
time_since_last_handshake: time.map(|t| t.as_secs() as i64).unwrap_or(-1),
Expand Down
8 changes: 5 additions & 3 deletions boringtun/src/jni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use jni::objects::{JByteBuffer, JClass, JString};
use jni::strings::JNIStr;
use jni::sys::{jbyteArray, jint, jlong, jshort, jstring};
use jni::JNIEnv;
use parking_lot::Mutex;

use crate::ffi::new_tunnel;
use crate::ffi::wireguard_read;
Expand Down Expand Up @@ -193,7 +194,7 @@ pub unsafe extern "C" fn encrypt_raw_packet(
};

let output: wireguard_result = wireguard_write(
tunnel as *mut Tunn,
tunnel as *const Mutex<Tunn>,
env.convert_byte_array(src).unwrap().as_mut_ptr(),
src_size as u32,
dst_ptr,
Expand Down Expand Up @@ -228,7 +229,7 @@ pub unsafe extern "C" fn decrypt_to_raw_packet(
};

let output: wireguard_result = wireguard_read(
tunnel as *mut Tunn,
tunnel as *const Mutex<Tunn>,
env.convert_byte_array(src).unwrap().as_mut_ptr(),
src_size as u32,
dst_ptr,
Expand Down Expand Up @@ -261,7 +262,8 @@ pub unsafe extern "C" fn run_periodic_task(
Err(_) => return 0,
};

let output: wireguard_result = wireguard_tick(tunnel as *mut Tunn, dst_ptr, dst_size as u32);
let output: wireguard_result =
wireguard_tick(tunnel as *const Mutex<Tunn>, dst_ptr, dst_size as u32);

*op_ptr = output.op as u8;

Expand Down
Loading

0 comments on commit 1466836

Please sign in to comment.