blob: 8fae56fef8741a4fde46801f417110bf97ac3368 [file] [log] [blame]
// Copyright 2023 Google LLC
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::task::Poll;
use bt_common::packet_encoding::Decodable;
use bt_gatt::client::CharacteristicNotification;
use bt_gatt::types::{Error as BtGattError, Handle};
use futures::stream::{BoxStream, FusedStream, SelectAll};
use futures::{Stream, StreamExt};
use parking_lot::Mutex;
use crate::client::error::Error;
use crate::client::error::ServiceError;
use crate::client::BroadcastSourceIdTracker;
use crate::types::*;
/// Events derived from a Broadcast Audio Scan Service (BASS) server's Broadcast Receive
/// State notifications.
#[derive(Clone, Debug, PartialEq)]
pub enum BroadcastAudioScanServiceEvent {
    /// Broadcast Audio Scan Service (BASS) server requested SyncInfo through the PAST procedure.
    SyncInfoRequested(BroadcastId),
    /// BASS server failed to synchronize to the PA, or is not synchronized to the PA.
    NotSyncedToPa(BroadcastId),
    /// BASS server successfully synchronized to the PA.
    SyncedToPa(BroadcastId),
    /// BASS server failed to synchronize to the PA because SyncInfo wasn't received (no PAST).
    SyncedFailedNoPast(BroadcastId),
    /// BASS server requires a Broadcast Code because the BIS is encrypted.
    BroadcastCodeRequired(BroadcastId),
    /// BASS server failed to decrypt the BIS using the previously provided code; the second
    /// field carries that rejected code.
    InvalidBroadcastCode(BroadcastId, [u8; 16]),
    /// Received a packet from the BASS server that this library could not decode.
    UnknownPacket,
}
impl BroadcastAudioScanServiceEvent {
    /// Translates a non-empty Broadcast Receive State into the events it implies.
    ///
    /// The returned list always starts with exactly one event describing the PA sync
    /// state. It is followed by at most one event describing the BIG encryption status:
    /// either a Broadcast Code is required, or the previously supplied code was rejected.
    pub(crate) fn from_broadcast_receive_state(
        state: &ReceiveState,
    ) -> Vec<BroadcastAudioScanServiceEvent> {
        let sync = state.pa_sync_state();
        let id = state.broadcast_id();

        let pa_event = match sync {
            PaSyncState::SyncInfoRequest => Self::SyncInfoRequested(id),
            PaSyncState::Synced => Self::SyncedToPa(id),
            PaSyncState::FailedToSync | PaSyncState::NotSynced => Self::NotSyncedToPa(id),
            PaSyncState::NoPast => Self::SyncedFailedNoPast(id),
        };

        let encryption_event = match state.big_encryption() {
            EncryptionStatus::BroadcastCodeRequired => Some(Self::BroadcastCodeRequired(id)),
            EncryptionStatus::BadCode(code) => {
                Some(Self::InvalidBroadcastCode(id, code.clone()))
            }
            // Other encryption states do not produce an event.
            _ => None,
        };

        std::iter::once(pa_event).chain(encryption_event).collect()
    }
}
/// Stream of [`BroadcastAudioScanServiceEvent`]s decoded from GATT notifications of a BASS
/// server's Broadcast Receive State (BRS) characteristics.
///
/// Polling the stream also updates the shared broadcast source ID tracker and the cached
/// Broadcast Receive State of each characteristic. After an error has been yielded, the
/// stream is terminated and further polls return `None`.
pub struct BroadcastAudioScanServiceEventStream {
    /// Polled to receive BASS notifications.
    notification_streams:
        SelectAll<BoxStream<'static, Result<CharacteristicNotification, BtGattError>>>,
    /// Events and errors produced from notifications but not yet yielded, in arrival order.
    event_queue: VecDeque<Result<BroadcastAudioScanServiceEvent, Error>>,
    /// Set once an error has been yielded; afterwards the stream only yields `None`.
    terminated: bool,
    // States to be updated.
    /// Broadcast ID to source ID mapping, updated from every non-empty receive state.
    id_tracker: Arc<Mutex<BroadcastSourceIdTracker>>,
    /// Latest receive state per characteristic handle; an entry is overwritten with
    /// `Some(..)` whenever a notification for that handle decodes successfully.
    receive_states: Arc<Mutex<HashMap<Handle, Option<BroadcastReceiveState>>>>,
}
impl BroadcastAudioScanServiceEventStream {
    /// Creates an event stream over the given BRS notification streams.
    ///
    /// `id_tracker` and `receive_states` are shared handles that the stream updates as
    /// notifications are decoded. The stream starts with no queued events and is not
    /// terminated.
    pub(crate) fn new(
        notification_streams: SelectAll<
            BoxStream<'static, Result<CharacteristicNotification, BtGattError>>,
        >,
        id_tracker: Arc<Mutex<BroadcastSourceIdTracker>>,
        receive_states: Arc<Mutex<HashMap<Handle, Option<BroadcastReceiveState>>>>,
    ) -> Self {
        let event_queue = VecDeque::default();
        Self { notification_streams, id_tracker, receive_states, event_queue, terminated: false }
    }
}
impl FusedStream for BroadcastAudioScanServiceEventStream {
    /// Returns `true` once the stream has yielded an error and will only yield `None`.
    fn is_terminated(&self) -> bool {
        let Self { terminated, .. } = self;
        *terminated
    }
}
impl Stream for BroadcastAudioScanServiceEventStream {
type Item = Result<BroadcastAudioScanServiceEvent, Error>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.terminated {
return Poll::Ready(None);
}
loop {
match self.notification_streams.poll_next_unpin(cx) {
Poll::Pending => {}
Poll::Ready(None) => {
self.event_queue.push_back(Err(Error::Service(
super::error::ServiceError::NotificationChannelClosed(format!(
"GATT notification stream for BRS characteristics closed"
)),
)));
}
Poll::Ready(Some(received)) => {
match received {
Err(error) => match error {
BtGattError::PeerNotRecognized(_)
| BtGattError::ScanFailed(_)
| BtGattError::Other(_) => {
self.event_queue.push_back(Err(Error::Service(
ServiceError::NotificationChannelClosed(format!(
"unexpected error encountered from GATT notification"
)),
)));
}
BtGattError::PeerDisconnected(id) => {
self.event_queue.push_back(Err(Error::Service(
ServiceError::NotificationChannelClosed(format!(
"peer ({id}) disconnected"
)),
)));
}
_ => {} // TODO(b/308483171): decide what to do for non-critical errors.
},
Ok(notification) => {
let Ok((brs, _)) =
BroadcastReceiveState::decode(notification.value.as_slice())
else {
self.event_queue
.push_back(Ok(BroadcastAudioScanServiceEvent::UnknownPacket));
break;
};
match &brs {
BroadcastReceiveState::Empty => {}
BroadcastReceiveState::NonEmpty(state) => {
let events = BroadcastAudioScanServiceEvent::from_broadcast_receive_state(state);
events
.into_iter()
.for_each(|e| self.event_queue.push_back(Ok(e)));
// Update broadcast ID to source ID mapping.
let _ = self
.id_tracker
.lock()
.update(state.source_id(), state.broadcast_id());
}
};
{
// Update the Broadcast Receive States.
let mut lock = self.receive_states.lock();
let char = lock.get_mut(&notification.handle).unwrap();
char.replace(brs);
}
}
}
}
};
break;
}
let popped = self.event_queue.pop_front();
match popped {
None => Poll::Pending,
Some(item) => match item {
Ok(event) => Poll::Ready(Some(Ok(event))),
Err(e) => {
// If an error was received, we terminate the event stream, but send an error to indicate why it was terminated.
self.terminated = true;
Poll::Ready(Some(Err(e)))
}
},
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::channel::mpsc::unbounded;

    /// Polls `stream` once and asserts that it yields exactly `expected`.
    #[track_caller]
    fn assert_next_event(
        stream: &mut BroadcastAudioScanServiceEventStream,
        cx: &mut std::task::Context<'_>,
        expected: BroadcastAudioScanServiceEvent,
    ) {
        match stream.poll_next_unpin(cx) {
            Poll::Ready(Some(Ok(event))) => assert_eq!(event, expected),
            _ => panic!("should have received event"),
        }
    }

    #[test]
    fn poll_broadcast_audio_scan_service_event_stream() {
        let mut streams = SelectAll::new();
        let (sender1, receiver1) = unbounded();
        let (sender2, receiver2) = unbounded();
        streams.push(receiver1.boxed());
        streams.push(receiver2.boxed());
        let id_tracker = Arc::new(Mutex::new(BroadcastSourceIdTracker::new()));
        let receive_states =
            Arc::new(Mutex::new(HashMap::from([(Handle(0x1), None), (Handle(0x2), None)])));
        let mut event_streams =
            BroadcastAudioScanServiceEventStream::new(streams, id_tracker, receive_states);
        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());

        // Send notifications to underlying streams.
        let bad_code_status =
            EncryptionStatus::BadCode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
        #[rustfmt::skip]
        sender1
            .unbounded_send(Ok(CharacteristicNotification {
                handle: Handle(0x1),
                value: vec![
                    0x01, AddressType::Public as u8, // source id and address type
                    0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // address
                    0x01, 0x01, 0x02, 0x03, // ad set id and broadcast id
                    PaSyncState::FailedToSync as u8,
                    bad_code_status.raw_value(),
                    1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16, // bad code
                    0x00, // no subgroups
                ],
                maybe_truncated: false,
            }))
            .expect("should send");
        #[rustfmt::skip]
        sender2
            .unbounded_send(Ok(CharacteristicNotification {
                handle: Handle(0x2),
                value: vec![
                    0x02, AddressType::Public as u8, // source id and address type
                    0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // address
                    0x01, 0x02, 0x03, 0x04, // ad set id and broadcast id
                    PaSyncState::NoPast as u8,
                    EncryptionStatus::NotEncrypted.raw_value(),
                    0x00, // no subgroups
                ],
                maybe_truncated: false,
            }))
            .expect("should send");

        // Events should have been generated from notifications.
        assert_next_event(
            &mut event_streams,
            &mut noop_cx,
            BroadcastAudioScanServiceEvent::NotSyncedToPa(BroadcastId::new(0x030201)),
        );
        assert_next_event(
            &mut event_streams,
            &mut noop_cx,
            BroadcastAudioScanServiceEvent::InvalidBroadcastCode(
                BroadcastId::new(0x030201),
                [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            ),
        );
        assert_next_event(
            &mut event_streams,
            &mut noop_cx,
            BroadcastAudioScanServiceEvent::SyncedFailedNoPast(BroadcastId::new(0x040302)),
        );
        // Should be pending because no more events generated from notifications.
        assert!(event_streams.poll_next_unpin(&mut noop_cx).is_pending());

        // Send notifications to underlying streams.
        #[rustfmt::skip]
        sender2
            .unbounded_send(Ok(CharacteristicNotification {
                handle: Handle(0x2),
                value: vec![
                    0x02, AddressType::Public as u8, // source id and address type
                    0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // address
                    0x01, 0x02, 0x03, 0x04, // ad set id and broadcast id
                    PaSyncState::Synced as u8,
                    EncryptionStatus::NotEncrypted.raw_value(),
                    0x00, // no subgroups
                ],
                maybe_truncated: false,
            }))
            .expect("should send");

        // Event should have been generated from notification.
        assert_next_event(
            &mut event_streams,
            &mut noop_cx,
            BroadcastAudioScanServiceEvent::SyncedToPa(BroadcastId::new(0x040302)),
        );
        // Should be pending because no more events generated from notifications.
        assert!(event_streams.poll_next_unpin(&mut noop_cx).is_pending());
    }
}