This commit is contained in:
Hailey Somerville 2023-08-14 19:16:35 +10:00
parent 9a16d0d721
commit 9e44cb7881
7 changed files with 1742 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

1150
Cargo.lock generated Normal file

File diff suppressed because it is too large. (Load diff)

13
Cargo.toml Normal file
View file

@ -0,0 +1,13 @@
[package]
name = "bark"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bytemuck = { version = "1.13.1", features = ["derive"] }  # zero-copy Packet <-> bytes conversion
cpal = "0.15.2"           # cross-platform audio input/output streams
derive_more = "0.99.17"   # NOTE(review): not referenced in the code shown — confirm it is needed
num-rational = "0.4.1"    # NOTE(review): not referenced in the code shown — confirm it is needed
structopt = "0.3.26"      # CLI argument parsing (Opt / StreamOpt / ReceiveOpt)

222
src/main.rs Normal file
View file

@ -0,0 +1,222 @@
pub mod receive;
pub mod protocol;
pub mod time;
use std::net::{UdpSocket, Ipv4Addr, SocketAddrV4};
use std::process::ExitCode;
use std::sync::{Mutex, Arc};
use std::time::Duration;
use bytemuck::Zeroable;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{OutputCallbackInfo, StreamConfig, InputCallbackInfo, BuildStreamError, PlayStreamError, BufferSize};
use structopt::StructOpt;
use protocol::{TimestampMicros, Packet, PacketBuffer};
use crate::time::{SampleDuration, Timestamp};
/// Top-level CLI: either capture from the default input device and stream
/// to a multicast group, or receive from a group and play back.
#[derive(StructOpt)]
enum Opt {
    Stream(StreamOpt),
    Receive(ReceiveOpt),
}
/// Options for the `receive` subcommand.
#[derive(StructOpt)]
struct ReceiveOpt {
    /// Multicast group address to join.
    #[structopt(long, short)]
    pub group: Ipv4Addr,
    /// UDP port to listen on.
    #[structopt(long, short)]
    pub port: u16,
    /// Local interface address to bind and join the group on.
    #[structopt(long, short)]
    pub bind: Ipv4Addr,
}
/// Options for the `stream` subcommand.
#[derive(StructOpt)]
struct StreamOpt {
    /// Multicast group address to send to.
    #[structopt(long, short)]
    pub group: Ipv4Addr,
    /// Destination UDP port.
    #[structopt(long, short)]
    pub port: u16,
    /// Optional local socket address to bind; defaults to 0.0.0.0:0
    /// (any interface, ephemeral port) when omitted.
    #[structopt(long, short)]
    pub bind: Option<SocketAddrV4>,
}
/// Errors that can abort either subcommand; carried back to `main`,
/// printed via `Debug`, and turned into a failure exit code.
#[derive(Debug)]
enum RunError {
    /// Failed to bind the UDP socket to the given local address.
    BindSocket(SocketAddrV4, std::io::Error),
    /// `join_multicast_v4` failed on the receive socket.
    JoinMulticast(std::io::Error),
    /// cpal reported no default input/output device.
    NoDeviceAvailable,
    /// cpal failed to build the audio stream.
    BuildStream(BuildStreamError),
    /// cpal failed to start stream playback.
    Stream(PlayStreamError),
    /// Socket I/O error while receiving packets.
    Socket(std::io::Error),
}
/// Entry point: parse the subcommand and dispatch to the stream or
/// receive loop, reporting any `RunError` and exiting non-zero.
///
/// Returns `ExitCode` directly rather than `Result<(), ExitCode>`:
/// the `Termination` impl for `Result` would print a redundant
/// `Error: ExitCode(..)` debug line after our own `eprintln!`.
fn main() -> ExitCode {
    let opt = Opt::from_args();

    let result = match opt {
        Opt::Stream(opt) => run_stream(opt),
        Opt::Receive(opt) => run_receive(opt),
    };

    match result {
        Ok(()) => ExitCode::SUCCESS,
        Err(err) => {
            eprintln!("error: {err:?}");
            ExitCode::FAILURE
        }
    }
}
/// Capture audio from the default input device and send it to the
/// configured multicast group as fixed-size UDP packets.
///
/// Blocks forever (sleep loop) once the input stream is running; only
/// returns early on setup errors.
fn run_stream(opt: StreamOpt) -> Result<(), RunError> {
    let host = cpal::default_host();

    let device = host.default_input_device()
        .ok_or(RunError::NoDeviceAvailable)?;

    let config = StreamConfig {
        channels: protocol::CHANNELS,
        sample_rate: protocol::SAMPLE_RATE,
        // ask for exactly one packet's worth of frames per callback, so a
        // callback normally fills exactly one network packet
        buffer_size: BufferSize::Fixed(protocol::FRAMES_PER_PACKET as u32),
    };

    // bind to any interface on an ephemeral port unless told otherwise
    let bind = opt.bind.unwrap_or(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
    let socket = UdpSocket::bind(bind)
        .map_err(|e| RunError::BindSocket(bind, e))?;

    // the packet currently being accumulated; moved into the audio
    // callback below. pts == 0 doubles as the "no samples buffered yet"
    // sentinel (see end of callback).
    let mut packet = Packet {
        magic: protocol::MAGIC,
        flags: 0,
        seq: 1,
        pts: TimestampMicros(0),
        buffer: PacketBuffer::zeroed(),
    };

    // how much of packet.buffer has been filled so far
    let mut packet_written = SampleDuration::zero();

    let stream = device.build_input_stream(&config,
        move |mut data: &[f32], _: &InputCallbackInfo| {
            // assert data only contains complete frames:
            assert!(data.len() % usize::from(protocol::CHANNELS) == 0);

            // presentation timestamp: now plus a 6-packet lead so receivers
            // have time to buffer. NOTE(review): 6 is a magic latency
            // constant — confirm it matches receiver-side buffering.
            let mut timestamp = Timestamp::now().add(SampleDuration::ONE_PACKET.mul(6));

            if packet.pts.0 == 0 {
                // first samples for this packet: stamp it now
                packet.pts = timestamp.to_micros_lossy();
            }

            while data.len() > 0 {
                // copy as much of `data` as fits into the packet buffer
                let buffer_offset = packet_written.as_buffer_offset();
                let buffer_remaining = packet.buffer.0.len() - buffer_offset;

                let copy_count = std::cmp::min(data.len(), buffer_remaining);
                let buffer_copy_end = buffer_offset + copy_count;

                packet.buffer.0[buffer_offset..buffer_copy_end]
                    .copy_from_slice(&data[0..copy_count]);

                data = &data[copy_count..];
                packet_written = SampleDuration::from_buffer_offset(buffer_copy_end);
                // advance pts by the amount just consumed, so a packet sent
                // below gets the timestamp of its first frame
                timestamp = timestamp.add(SampleDuration::from_buffer_offset(copy_count));

                if packet_written == SampleDuration::ONE_PACKET {
                    // packet is full! send:
                    let dest = SocketAddrV4::new(opt.group, opt.port);
                    socket.send_to(bytemuck::bytes_of(&packet), dest)
                        .expect("UdpSocket::send");

                    // reset rest of packet for next:
                    packet.seq += 1;
                    packet.pts = timestamp.to_micros_lossy();
                    packet_written = SampleDuration::zero();
                }
            }

            // if there is data waiting in the packet buffer at the end of the
            // callback, the pts we just calculated is valid. if the packet is
            // empty, reset the pts to 0. this signals the next callback to set
            // pts to the current time when it fires.
            if packet_written == SampleDuration::zero() {
                packet.pts.0 = 0;
            }
        },
        move |err| {
            eprintln!("stream error! {err:?}");
        },
        None
    ).map_err(RunError::BuildStream)?;

    stream.play().map_err(RunError::Stream)?;

    // keep the main thread alive; the audio callback does all the work
    loop {
        std::thread::sleep(Duration::from_secs(1));
    }
}
/// Join the multicast group, receive packets into a `Receiver`, and play
/// them back on the default output device. Loops forever; only returns
/// on setup or socket error.
fn run_receive(opt: ReceiveOpt) -> Result<(), RunError> {
    let host = cpal::default_host();

    let device = host.default_output_device()
        .ok_or(RunError::NoDeviceAvailable)?;

    let config = StreamConfig {
        channels: protocol::CHANNELS,
        sample_rate: protocol::SAMPLE_RATE,
        buffer_size: BufferSize::Fixed(protocol::FRAMES_PER_PACKET as u32),
    };

    // state shared between the audio callback and the recv loop below
    struct SharedState {
        pub recv: receive::Receiver,
    }

    let state = Arc::new(Mutex::new(SharedState {
        recv: receive::Receiver::new(),
    }));

    // kept alive (underscore binding) for the rest of the function;
    // dropping the stream would stop playback.
    // NOTE(review): unlike run_stream, `.play()` is never called here —
    // confirm the cpal host starts output streams automatically.
    let _stream = device.build_output_stream(&config,
        {
            let state = state.clone();
            move |data: &mut [f32], info: &OutputCallbackInfo| {
                let stream_timestamp = info.timestamp();

                // how far in the future this buffer will actually be heard
                let output_latency = stream_timestamp.playback
                    .duration_since(&stream_timestamp.callback)
                    .unwrap_or_default();

                let output_latency = SampleDuration::from_std_duration_lossy(output_latency);

                // presentation time of the first sample we are about to fill
                let now = Timestamp::now();
                let pts = now.add(output_latency);

                let mut state = state.lock().unwrap();
                state.recv.fill_stream_buffer(data, pts);
            }
        },
        move |err| {
            eprintln!("stream error! {err:?}");
        },
        None
    ).map_err(RunError::BuildStream)?;

    let bind = SocketAddrV4::new(opt.bind, opt.port);
    let socket = UdpSocket::bind(bind)
        .map_err(|e| RunError::BindSocket(bind, e))?;

    socket.join_multicast_v4(&opt.group, &opt.bind)
        .map_err(RunError::JoinMulticast)?;

    loop {
        let mut packet = Packet::zeroed();

        let nread = socket.recv(bytemuck::bytes_of_mut(&mut packet))
            .map_err(RunError::Socket)?;

        // Packet is fixed-size; anything shorter is not one of ours
        if nread < std::mem::size_of::<Packet>() {
            eprintln!("packet wrong size! ignoring");
            continue;
        }

        let mut state = state.lock().unwrap();
        state.recv.push_packet(&packet);
    }
}

48
src/protocol.rs Normal file
View file

@ -0,0 +1,48 @@
use bytemuck::{Pod, Zeroable};
use cpal::{SampleFormat, SampleRate, ChannelCount};
// All audio on the wire is interleaved f32 stereo at 48 kHz.
pub const SAMPLE_FORMAT: SampleFormat = SampleFormat::F32;
pub const SAMPLE_RATE: SampleRate = SampleRate(48000);
pub const CHANNELS: ChannelCount = 2;

/// Frames (one sample per channel) carried by each network packet.
pub const FRAMES_PER_PACKET: usize = 160;

/// Individual f32 samples per packet: channels × frames.
pub const SAMPLES_PER_PACKET: usize = CHANNELS as usize * FRAMES_PER_PACKET;

// 12 packets × 160 frames = 1920 frames ≈ 40 ms of audio at 48 kHz.
// (Original comment said "about ~20ms latency" — that figure doesn't
// match the math above; TODO confirm the intended latency.)
pub const RECEIVER_BUFFERED_PACKETS: usize = 12;

/// Magic number identifying bark packets on the wire.
pub const MAGIC: u32 = 0x00a79ae2;
/// our network Packet struct
/// we don't need to worry about endianness, because according to the rust docs:
///
///     Floats and Ints have the same endianness on all supported platforms.
///     IEEE 754 very precisely specifies the bit layout of floats.
///
/// - https://doc.rust-lang.org/std/primitive.f32.html
///
/// NOTE(review): "same endianness on all supported platforms" means float
/// and int endianness agree with *each other* per platform — confirm all
/// deployment targets are little-endian before relying on raw byte casts.
#[derive(Debug, Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct Packet {
    /// Must equal `MAGIC` for valid bark packets.
    pub magic: u32,
    /// Always written as 0 by the sender in this commit; reserved.
    pub flags: u32,
    /// Packet sequence number (sender starts at 1, increments per packet).
    pub seq: u64,
    /// Presentation timestamp of the packet's first frame, in microseconds.
    pub pts: TimestampMicros,
    /// Interleaved f32 sample payload.
    pub buffer: PacketBuffer,
}
/// Fixed-size interleaved-sample payload of a `Packet`.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct PacketBuffer(pub [f32; SAMPLES_PER_PACKET]);

/// SAFETY: Pod is impl'd for f32, and [T: Pod; N: usize]
/// but for some reason doesn't like N == SAMPLES_PER_PACKET?
/// Soundness: repr(transparent) over a plain f32 array means no padding,
/// no invalid bit patterns, and the type is Copy — Pod's requirements hold.
unsafe impl Pod for PacketBuffer {}

/// SAFETY: Zeroable is impl'd for f32, and [T: Zeroable; N: usize]
/// but for some reason doesn't like N == SAMPLES_PER_PACKET?
/// Soundness: the all-zero bit pattern is a valid value (an array of 0.0).
unsafe impl Zeroable for PacketBuffer {
    fn zeroed() -> Self {
        PacketBuffer([0f32; SAMPLES_PER_PACKET])
    }
}
/// A wall-clock timestamp in microseconds, as carried on the wire
/// (see `time::Timestamp` for the sample-rate-denominated form).
#[derive(Debug, Clone, Copy, Zeroable, Pod)]
#[repr(transparent)]
pub struct TimestampMicros(pub u64);

217
src/receive.rs Normal file
View file

@ -0,0 +1,217 @@
use std::collections::VecDeque;
use crate::protocol::{Packet, RECEIVER_BUFFERED_PACKETS, self};
use crate::time::{Timestamp, SampleDuration};
/// Buffers and reorders incoming packets, and feeds the audio output
/// callback a continuous sample stream (zero-filling gaps).
pub struct Receiver {
    /// Timing/seq anchor captured from the first packet of the current
    /// stream; `None` until the first packet arrives.
    start: Option<StreamStart>,
    /// Slots in contiguous seq order; a slot's packet is `None` until
    /// (unless) that seq actually arrives.
    queue: VecDeque<QueueEntry>,
}
/// One slot in the receive queue: a seq number plus the packet for it,
/// if it has arrived.
struct QueueEntry {
    seq: u64,
    /// Presentation timestamp of this packet's first frame.
    pts: Timestamp,
    /// How much of this packet has already been copied to the output.
    consumed: SampleDuration,
    /// `None` = not (yet) received; played back as silence.
    packet: Option<Packet>,
}
impl QueueEntry {
    /// Returns this slot's full sample buffer, or an all-zero (silent)
    /// buffer if the packet for this seq was never received.
    pub fn as_full_buffer(&self) -> &[f32; protocol::SAMPLES_PER_PACKET] {
        // silence to substitute for missing packets
        static SILENCE: [f32; protocol::SAMPLES_PER_PACKET] =
            [0f32; protocol::SAMPLES_PER_PACKET];

        match &self.packet {
            Some(packet) => &packet.buffer.0,
            None => &SILENCE,
        }
    }
}
/// Anchor for the stream timeline: maps seq numbers to presentation
/// timestamps relative to the first packet received.
struct StreamStart {
    pts: Timestamp,
    seq: u64,
    /// Set once playback has aligned itself with the stream timeline
    /// (see `Receiver::fill_stream_buffer`).
    sync: bool,
}
impl StreamStart {
    /// Anchors the stream timeline at the given packet's seq and pts.
    /// The stream starts unsynced.
    pub fn from_packet(packet: &Packet) -> Self {
        let pts = Timestamp::from_micros_lossy(packet.pts);
        let seq = packet.seq;
        StreamStart { pts, seq, sync: false }
    }

    /// Presentation timestamp for `seq`: the anchor pts plus one packet
    /// duration per packet elapsed since the anchor seq.
    ///
    /// Panics if `seq` precedes the anchor seq.
    pub fn pts_for_seq(&self, seq: u64) -> Timestamp {
        let packets_since_start = seq.checked_sub(self.seq)
            .expect("seq < start seq in pts_for_seq");
        self.pts.add(SampleDuration::ONE_PACKET.mul(packets_since_start))
    }
}
/// What to do with a queue-front packet after inspecting it.
/// NOTE(review): not referenced anywhere in this file — appears unused
/// in this commit; possibly intended for future use.
pub enum PacketDisposition {
    Pop,
    Pass,
}
impl Receiver {
    /// Creates an empty receiver with no stream started.
    pub fn new() -> Self {
        Receiver {
            start: None,
            queue: VecDeque::new(),
        }
    }

    /// Inserts a received packet into the queue, allocating empty slots
    /// for any seq numbers skipped over (tolerating out-of-order and
    /// lost packets), and (re)starting the stream when necessary.
    ///
    /// Drops packets that arrive before the stream anchor or at/behind
    /// the queue front; resets the whole stream if a packet arrives more
    /// than `RECEIVER_BUFFERED_PACKETS` ahead of the queue back.
    pub fn push_packet(&mut self, packet: &Packet) {
        if let Some(start) = self.start.as_mut() {
            if packet.seq < start.seq {
                eprintln!("received packet with seq before start, dropping");
                return;
            }

            if let Some(front) = self.queue.front() {
                // front is already (partially) played; a packet at or
                // before it is too late or a duplicate
                if packet.seq <= front.seq {
                    eprintln!("received packet with seq <= queue front seq, dropping");
                    return;
                }
            }

            if let Some(back) = self.queue.back() {
                // sender jumped too far ahead — treat as a brand new stream
                if back.seq + RECEIVER_BUFFERED_PACKETS as u64 <= packet.seq {
                    eprintln!("received packet with seq too far in future, resetting stream");
                    self.start = Some(StreamStart::from_packet(packet));
                    self.queue.clear();
                }
            }
        } else {
            self.start = Some(StreamStart::from_packet(packet));
        }

        let start = self.start.as_ref().unwrap();

        // INVARIANT: at this point we are guaranteed that, if there are
        // packets in the queue, the seq of the incoming packet is less than
        // back.seq + RECEIVER_BUFFERED_PACKETS

        // expand queue to make space for new packet
        if let Some(back) = self.queue.back() {
            if packet.seq > back.seq {
                // extend queue from back to make space for new packet
                // this also allows for out of order packets
                for seq in (back.seq + 1)..=packet.seq {
                    self.queue.push_back(QueueEntry {
                        seq,
                        pts: start.pts_for_seq(seq),
                        consumed: SampleDuration::zero(),
                        packet: None,
                    })
                }
            }
        } else {
            // queue is empty, insert missing packet slot for the packet we are about to receive
            self.queue.push_back(QueueEntry {
                seq: packet.seq,
                pts: start.pts_for_seq(packet.seq),
                consumed: SampleDuration::zero(),
                packet: None,
            });
        }

        // INVARIANT: at this point queue is non-empty and contains an
        // allocated slot for the packet we just received
        let front_seq = self.queue.front().unwrap().seq;
        let idx_for_packet = (packet.seq - front_seq) as usize;

        let slot = self.queue.get_mut(idx_for_packet).unwrap();
        assert!(slot.seq == packet.seq);
        slot.packet = Some(*packet);
    }

    /// Fills `data` with the samples due at presentation time `pts`.
    ///
    /// On the first call(s) after a stream starts, aligns playback with
    /// the stream timeline: skips or partially consumes packets if we are
    /// late, or emits leading silence if we are early. Missing packets
    /// play back as silence. Fully consumed packets are popped.
    pub fn fill_stream_buffer(&mut self, mut data: &mut [f32], pts: Timestamp) {
        // complete frames only:
        // NOTE(review): hard-codes 2 rather than protocol::CHANNELS —
        // consistent only while CHANNELS == 2; confirm intent.
        assert!(data.len() % 2 == 0);

        // get stream start timing information:
        let Some(start) = self.start.as_mut() else {
            // stream hasn't started, just fill buffer with silence and return
            data.fill(0f32);
            return;
        };

        // sync up to stream if necessary:
        if !start.sync {
            loop {
                let Some(front) = self.queue.front_mut() else {
                    // nothing at front of queue?
                    data.fill(0f32);
                    return;
                };

                eprintln!("something at front of queue!");

                if pts > front.pts {
                    // frame has already begun, we are late
                    let late = pts.duration_since(front.pts);

                    if late >= SampleDuration::ONE_PACKET {
                        // we are late by more than a packet, skip to the next
                        eprintln!("late by more than a packet, pts: {:?}, front pts: {:?}, late: {:?}", pts, front.pts, late);
                        self.queue.pop_front();
                        continue;
                    }

                    // partially consume this packet to sync up
                    front.consumed = late;

                    // we are synced
                    println!("SYNC!");
                    start.sync = true;
                    break;
                }

                // otherwise we are early
                let early = front.pts.duration_since(pts);

                if early >= SampleDuration::ONE_PACKET {
                    // we are early by more than a packet, fill buffer with silence and return
                    eprintln!("early by more than a packet");
                    data.fill(0f32);
                    return;
                }

                // we are early, but not an entire packet timing's early
                // partially output some zeroes
                let zero_count = early.as_buffer_offset();
                data[0..zero_count].fill(0f32);
                data = &mut data[zero_count..];

                // then mark ourselves as synced and fall through to regular processing
                println!("SYNC!");
                start.sync = true;
                break;
            }
        }

        // copy data to out
        while data.len() > 0 {
            let Some(front) = self.queue.front_mut() else {
                eprintln!("nothing at the front of the queue!");
                data.fill(0f32);
                return;
            };

            // missing packets come back as silence from as_full_buffer
            let buffer = front.as_full_buffer();

            let buffer_offset = front.consumed.as_buffer_offset();
            let buffer_remaining = buffer.len() - buffer_offset;

            let copy_count = std::cmp::min(data.len(), buffer_remaining);
            let buffer_copy_end = buffer_offset + copy_count;

            data[0..copy_count].copy_from_slice(&buffer[buffer_offset..buffer_copy_end]);
            data = &mut data[copy_count..];
            front.consumed = SampleDuration::from_buffer_offset(buffer_copy_end);

            // pop packet if fully consumed
            if front.consumed == SampleDuration::ONE_PACKET {
                self.queue.pop_front();
            }
        }
    }
}

91
src/time.rs Normal file
View file

@ -0,0 +1,91 @@
use std::time::SystemTime;
use crate::protocol::{self, TimestampMicros};
/// A timestamp with implicit denominator SAMPLE_RATE
/// (i.e. a count of sample-rate ticks since the unix epoch — see
/// `from_micros_lossy`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(u64);
impl Timestamp {
    /// Current wall-clock time, converted to sample-rate ticks.
    pub fn now() -> Timestamp {
        // SystemTime::now uses CLOCK_REALTIME on Linux, which is exactly what we want
        // https://doc.rust-lang.org/std/time/struct.SystemTime.html#platform-specific-behavior
        let since_epoch = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("SystemTime::now before UNIX_EPOCH!");

        let micros = u64::try_from(since_epoch.as_micros())
            .expect("can't narrow timestamp to u64");

        Timestamp::from_micros_lossy(TimestampMicros(micros))
    }
}
impl Timestamp {
    /// Converts to microseconds-since-epoch. "Lossy" because the integer
    /// division truncates sub-microsecond precision.
    pub fn to_micros_lossy(&self) -> TimestampMicros {
        let ts = u128::from(self.0);
        // widen to u128 so `ts * 1_000_000` cannot overflow
        let micros = (ts * 1_000_000) / u128::from(protocol::SAMPLE_RATE.0);
        let micros = u64::try_from(micros)
            .expect("can't narrow timestamp to u64");
        TimestampMicros(micros)
    }

    /// Converts microseconds-since-epoch to sample-rate ticks. "Lossy"
    /// because the division truncates to whole ticks.
    pub fn from_micros_lossy(micros: TimestampMicros) -> Timestamp {
        let micros = u128::from(micros.0);
        // widen to u128 so the multiply cannot overflow
        let ts = (micros * u128::from(protocol::SAMPLE_RATE.0)) / 1_000_000;
        let ts = u64::try_from(ts)
            .expect("can't narrow timestamp to u64");
        Timestamp(ts)
    }

    /// Adds a duration; panics on u64 overflow.
    pub fn add(&self, duration: SampleDuration) -> Timestamp {
        Timestamp(self.0.checked_add(duration.0).unwrap())
    }

    /// Subtracts a duration; panics if the result would be negative.
    pub fn sub(&self, duration: SampleDuration) -> Timestamp {
        Timestamp(self.0.checked_sub(duration.0).unwrap())
    }

    /// Elapsed time from `other` to `self`; panics if `other` is later
    /// (callers must order their operands — see receive.rs sync logic).
    pub fn duration_since(&self, other: Timestamp) -> SampleDuration {
        SampleDuration(self.0.checked_sub(other.0).unwrap())
    }
}
/// A duration with implicit denominator SAMPLE_RATE
/// (i.e. a count of sample-rate ticks; see `as_buffer_offset` for the
/// relationship to interleaved buffer positions).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SampleDuration(u64);
impl SampleDuration {
    /// Duration of one network packet.
    /// NOTE(review): the inner count here is FRAMES (one tick covers all
    /// channels) — `as_buffer_offset` multiplies by CHANNELS — so the
    /// name `from_sample_count` is slightly misleading; confirm intent.
    pub const ONE_PACKET: SampleDuration = SampleDuration::from_sample_count(protocol::FRAMES_PER_PACKET as u64);

    /// The zero duration.
    pub const fn zero() -> Self {
        SampleDuration(0)
    }

    /// Constructs from a raw tick count (see note on ONE_PACKET above).
    pub const fn from_sample_count(samples: u64) -> Self {
        SampleDuration(samples)
    }

    /// Converts a std Duration to ticks, truncating sub-tick precision.
    pub fn from_std_duration_lossy(duration: std::time::Duration) -> SampleDuration {
        let duration = duration.as_micros() * u128::from(protocol::SAMPLE_RATE.0) / 1_000_000;
        let duration = u64::try_from(duration).expect("can't narrow duration to u64");
        SampleDuration(duration)
    }

    /// Scales the duration; panics on overflow.
    pub fn mul(&self, times: u64) -> Self {
        SampleDuration(self.0.checked_mul(times).unwrap())
    }

    /// Index into an interleaved f32 buffer: ticks × channels.
    pub fn as_buffer_offset(&self) -> usize {
        let offset = self.0 * u64::from(protocol::CHANNELS);
        usize::try_from(offset).unwrap()
    }

    /// Inverse of `as_buffer_offset`; panics if `offset` does not land on
    /// a frame boundary.
    pub fn from_buffer_offset(offset: usize) -> Self {
        let channels = usize::from(protocol::CHANNELS);
        assert!(offset % channels == 0);
        SampleDuration(u64::try_from(offset / channels).unwrap())
    }
}