| // Copyright 2024 Google LLC |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| use bt_bap::types::BroadcastId; |
| use bt_bass::client::error::Error as BassClientError; |
| use bt_bass::client::event::Event as BassEvent; |
| use bt_bass::client::BigToBisSync; |
| use bt_bass::types::PaSync; |
| use bt_common::debug_command::CommandRunner; |
| use bt_common::debug_command::CommandSet; |
| use bt_common::gen_commandset; |
| use bt_common::PeerId; |
| use bt_gatt::pii::GetPeerAddr; |
| |
| use futures::stream::FusedStream; |
| use futures::Future; |
| use futures::Stream; |
| use num::Num; |
| use parking_lot::Mutex; |
| use std::collections::HashSet; |
| use std::num::ParseIntError; |
| use std::sync::Arc; |
| |
| use crate::assistant::event::*; |
| use crate::assistant::peer::Peer; |
| use crate::assistant::Error; |
| use crate::*; |
| |
| gen_commandset! { |
| AssistantCmd { |
| Info = ("info", [], [], "Print information from broadcast assistant"), |
| Connect = ("connect", [], ["peer_id"], "Attempt connection to scan delegator"), |
| Disconnect = ("disconnect", [], [], "Disconnect from connected scan delegator"), |
| SendBroadcastCode = ("set-broadcast-code", [], ["broadcast_id", "broadcast_code"], "Attempt to send decryption key for a particular broadcast source to the scan delegator"), |
| AddBroadcastSource = ("add-broadcast-source", [], ["broadcast_source_pid", "pa_sync", "[bis_sync]"], "Attempt to add a particular broadcast source to the scan delegator"), |
| UpdatePaSync = ("update-pa-sync", [], ["broadcast_id", "pa_sync", "[bis_sync]"], "Attempt to update the scan delegator's desired pa sync to a particular broadcast source"), |
| RemoveBroadcastSource = ("remove-broadcast-source", [], ["broadcast_id"], "Attempt to remove a particular broadcast source to the scan delegator"), |
| RemoteScanStarted = ("inform-scan-started", [], [], "Inform the scan delegator that we have started scanning on behalf of it"), |
| RemoteScanStopped = ("inform-scan-stopped", [], [], "Inform the scan delegator that we have stopped scanning on behalf of it"), |
| // TODO(http://b/433285146): Once PA scanning is implemented, remove bottom 3 commands. |
| ForceDiscoverBroadcastSource = ("force-discover-broadcast-source", [], ["broadcast_source_pid", "address", "address_type", "advertising_sid"], "Force the broadcast assistant to become aware of the provided broadcast source"), |
| ForceDiscoverSourceMetadata = ("force-discover-source-metadata", [], ["broadcast_source_pid", "comma_separated_raw_metadata"], "Force the broadcast assistant to become aware of the provided metadata, each BIG's metadata is comma separated"), |
| ForceDiscoverEmptySourceMetadata = ("force-discover-empty-source-metadata", [], ["broadcast_source_pid", "num_big"], "Force the broadcast assistant to become aware of the provided empty metadata, as many as # BIGs specified"), |
| } |
| } |
| |
| pub struct AssistantDebug<T: bt_gatt::GattTypes, R: GetPeerAddr> { |
| assistant: BroadcastAssistant<T>, |
| connected_peer: Mutex<Option<Arc<Peer<T>>>>, |
| started: bool, |
| peer_addr_getter: R, |
| } |
| |
| impl<T: bt_gatt::GattTypes + 'static, R: GetPeerAddr> AssistantDebug<T, R> { |
| pub fn new(central: T::Central, peer_addr_getter: R) -> Self |
| where |
| <T as bt_gatt::GattTypes>::NotificationStream: std::marker::Send, |
| { |
| Self { |
| assistant: BroadcastAssistant::<T>::new(central), |
| connected_peer: Mutex::new(None), |
| started: false, |
| peer_addr_getter, |
| } |
| } |
| |
| pub fn start(&mut self) -> Result<EventStream<T>, Error> { |
| let event_stream = self.assistant.start()?; |
| self.started = true; |
| Ok(event_stream) |
| } |
| |
| pub fn look_for_scan_delegators(&mut self) -> T::ScanResultStream { |
| self.assistant.scan_for_scan_delegators() |
| } |
| |
| pub fn take_connected_peer_event_stream( |
| &mut self, |
| ) -> Result<impl Stream<Item = Result<BassEvent, BassClientError>> + FusedStream, Error> { |
| let mut lock = self.connected_peer.lock(); |
| let Some(peer_arc) = lock.as_mut() else { |
| return Err(Error::Generic(format!("not connected to any scan delegator peer"))); |
| }; |
| let Some(peer) = Arc::get_mut(peer_arc) else { |
| return Err(Error::Generic(format!( |
| "cannot get mutable peer reference, it is shared elsewhere" |
| ))); |
| }; |
| peer.take_event_stream().map_err(|e| Error::Generic(format!("{e:?}"))) |
| } |
| |
| async fn with_peer<F, Fut>(&self, f: F) |
| where |
| F: FnOnce(Arc<Peer<T>>) -> Fut, |
| Fut: Future<Output = Result<(), crate::assistant::peer::Error>>, |
| { |
| let Some(peer) = self.connected_peer.lock().clone() else { |
| eprintln!("not connected to a scan delegator"); |
| return; |
| }; |
| if let Err(e) = f(peer).await { |
| eprintln!("failed to perform oepration: {e:?}"); |
| } |
| } |
| } |
| |
| /// Attempt to parse a string into an integer. If the string begins with 0x, |
| /// treat the rest of the string as a hex value, otherwise treat it as decimal. |
| pub(crate) fn parse_int<N>(input: &str) -> Result<N, ParseIntError> |
| where |
| N: Num<FromStrRadixErr = ParseIntError>, |
| { |
| if input.starts_with("0x") { |
| N::from_str_radix(&input[2..], 16) |
| } else { |
| N::from_str_radix(input, 10) |
| } |
| } |
| |
| fn parse_peer_id(input: &str) -> Result<PeerId, String> { |
| let raw_id = match parse_int(input) { |
| Err(_) => return Err(format!("falied to parse int from {input}")), |
| Ok(i) => i, |
| }; |
| |
| Ok(PeerId(raw_id)) |
| } |
| |
/// Parses a colon-separated Bluetooth device address such as
/// `3c:80:f1:ed:32:2c` into its six bytes, in the order written.
///
/// # Errors
///
/// Returns a human-readable message unless `input` consists of exactly six
/// `:`-separated tokens that are each a hexadecimal byte value. Any malformed
/// token rejects the whole address rather than being skipped.
#[cfg(any(test, feature = "debug"))]
fn parse_bd_addr(input: &str) -> Result<[u8; 6], String> {
    let invalid = || format!("failed to parse bd address from {input}");

    let octets = input
        .split(':')
        .map(|token| {
            // `from_str_radix` on its own would also accept a leading `+`, so
            // insist on plain hex digits first.
            if token.is_empty() || !token.bytes().all(|b| b.is_ascii_hexdigit()) {
                return None;
            }
            u8::from_str_radix(token, 16).ok()
        })
        .collect::<Option<Vec<u8>>>()
        .ok_or_else(invalid)?;

    // Fails when the number of tokens is not exactly six.
    <[u8; 6]>::try_from(octets).map_err(|_| invalid())
}
| |
| fn parse_broadcast_id(input: &str) -> Result<BroadcastId, String> { |
| let raw_id: u32 = match parse_int(input) { |
| Err(_) => return Err(format!("falied to parse int from {input}")), |
| Ok(i) => i, |
| }; |
| raw_id.try_into().map_err(|e| format!("{e:?}")) |
| } |
| |
| fn parse_bis_sync(input: &str) -> BigToBisSync { |
| input.split(',').filter_map(|t| { |
| let parts: Vec<_> = t.split('-').collect(); |
| if parts.len() != 2 { |
| eprintln!("invalid big-bis sync info {t}. should be in <Ith_BIG>-<BIS_INDEX> format, will be ignored"); |
| return None; |
| } |
| let ith_big = parse_int(parts[0]).ok()?; |
| let bis_index = parse_int(parts[1]).ok()?; |
| Some((ith_big, bis_index)) |
| }).collect() |
| } |
| |
| impl<T: bt_gatt::GattTypes + 'static, R: GetPeerAddr> CommandRunner for AssistantDebug<T, R> |
| where |
| <T as bt_gatt::GattTypes>::NotificationStream: std::marker::Send, |
| { |
| type Set = AssistantCmd; |
| |
| fn run( |
| &self, |
| cmd: Self::Set, |
| args: Vec<String>, |
| ) -> impl futures::Future<Output = Result<(), impl std::error::Error>> { |
| async move { |
| match cmd { |
| AssistantCmd::Info => { |
| let known = self.assistant.known_broadcast_sources(); |
| eprintln!("Known Broadcast Sources:"); |
| for (id, s) in known { |
| eprintln!("PeerId ({id}), source: {s:?}"); |
| } |
| } |
| AssistantCmd::Connect => { |
| if self.connected_peer.lock().is_some() { |
| eprintln!( |
| "peer already connected. Call `disconnect` first: {}", |
| AssistantCmd::Disconnect.help_simple() |
| ); |
| return Ok(()); |
| } |
| if args.len() != 1 { |
| eprintln!("usage: {}", AssistantCmd::Connect.help_simple()); |
| return Ok(()); |
| } |
| |
| let Ok(peer_id) = parse_peer_id(&args[0]) else { |
| eprintln!("invalid peer id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| let peer = self.assistant.connect_to_scan_delegator(peer_id).await; |
| match peer { |
| Ok(peer) => { |
| *self.connected_peer.lock() = Some(Arc::new(peer)); |
| } |
| Err(e) => { |
| eprintln!("failed to connect to scan delegator: {e:?}"); |
| } |
| }; |
| } |
| AssistantCmd::Disconnect => { |
| if self.connected_peer.lock().take().is_none() { |
| eprintln!("not connected to a scan delegator"); |
| } |
| } |
| AssistantCmd::SendBroadcastCode => { |
| if args.len() != 2 { |
| eprintln!("usage: {}", AssistantCmd::SendBroadcastCode.help_simple()); |
| return Ok(()); |
| } |
| |
| let Ok(broadcast_id) = parse_broadcast_id(&args[0]) else { |
| eprintln!("invalid broadcast id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| let code = args[1].as_bytes(); |
| if code.len() > 16 { |
| eprintln!( |
| "invalid broadcast code: {}. should be at max length 16", |
| args[1] |
| ); |
| return Ok(()); |
| } |
| let mut passcode_vec = vec![0; 16]; |
| passcode_vec[16 - code.len()..16].copy_from_slice(code); |
| self.with_peer(|peer| async move { |
| peer.send_broadcast_code(broadcast_id, passcode_vec.try_into().unwrap()) |
| .await |
| }) |
| .await; |
| } |
| AssistantCmd::AddBroadcastSource => { |
| if args.len() < 2 { |
| eprintln!("usage: {}", AssistantCmd::AddBroadcastSource.help_simple()); |
| return Ok(()); |
| } |
| |
| let Ok(broadcast_source_pid) = parse_peer_id(&args[0]) else { |
| eprintln!("invalid broadcast id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| let pa_sync = match parse_int::<u8>(&args[1]) { |
| Ok(raw_val) if PaSync::try_from(raw_val).is_ok() => { |
| PaSync::try_from(raw_val).unwrap() |
| } |
| _ => { |
| eprintln!("invalid pa_sync: {}", args[1]); |
| return Ok(()); |
| } |
| }; |
| |
| let bis_sync = |
| if args.len() == 3 { parse_bis_sync(&args[2]) } else { HashSet::new() }; |
| |
| self.with_peer(|peer| async move { |
| peer.add_broadcast_source( |
| broadcast_source_pid, |
| &self.peer_addr_getter, |
| pa_sync, |
| bis_sync, |
| ) |
| .await |
| }) |
| .await; |
| } |
| AssistantCmd::UpdatePaSync => { |
| if args.len() < 2 { |
| eprintln!("usage: {}", AssistantCmd::UpdatePaSync.help_simple()); |
| return Ok(()); |
| } |
| |
| let Ok(broadcast_id) = parse_broadcast_id(&args[0]) else { |
| eprintln!("invalid broadcast id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| let pa_sync = match parse_int::<u8>(&args[1]) { |
| Ok(raw_val) if PaSync::try_from(raw_val).is_ok() => { |
| PaSync::try_from(raw_val).unwrap() |
| } |
| _ => { |
| eprintln!("invalid pa_sync: {}", args[1]); |
| return Ok(()); |
| } |
| }; |
| |
| let bis_sync = |
| if args.len() == 3 { parse_bis_sync(&args[2]) } else { HashSet::new() }; |
| |
| self.with_peer(|peer| async move { |
| peer.update_broadcast_source_sync(broadcast_id, pa_sync, bis_sync).await |
| }) |
| .await; |
| } |
| AssistantCmd::RemoveBroadcastSource => { |
| if args.len() != 1 { |
| eprintln!("usage: {}", AssistantCmd::RemoveBroadcastSource.help_simple()); |
| return Ok(()); |
| } |
| |
| let Ok(broadcast_id) = parse_broadcast_id(&args[0]) else { |
| eprintln!("invalid broadcast id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| self.with_peer(|peer| async move { |
| peer.remove_broadcast_source(broadcast_id).await |
| }) |
| .await; |
| } |
| AssistantCmd::RemoteScanStarted => { |
| self.with_peer(|peer| async move { peer.inform_remote_scan_started().await }) |
| .await; |
| } |
| AssistantCmd::RemoteScanStopped => { |
| self.with_peer(|peer| async move { peer.inform_remote_scan_stopped().await }) |
| .await; |
| } |
| #[cfg(feature = "debug")] |
| AssistantCmd::ForceDiscoverBroadcastSource => { |
| if args.len() != 4 { |
| eprintln!( |
| "usage: {}", |
| AssistantCmd::ForceDiscoverBroadcastSource.help_simple() |
| ); |
| return Ok(()); |
| } |
| |
| let Ok(peer_id) = parse_peer_id(&args[0]) else { |
| eprintln!("invalid peer id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| let Ok(address) = parse_bd_addr(&args[1]) else { |
| eprintln!("invalid address: {}", args[1]); |
| return Ok(()); |
| }; |
| |
| let Ok(raw_addr_type) = parse_int::<u8>(&args[2]) else { |
| eprintln!("invalid address type: {}", args[2]); |
| return Ok(()); |
| }; |
| |
| let Ok(raw_ad_sid) = parse_int::<u8>(&args[3]) else { |
| eprintln!("invalid advertising sid: {}", args[3]); |
| return Ok(()); |
| }; |
| |
| match self.assistant.force_discover_broadcast_source( |
| peer_id, |
| address, |
| raw_addr_type, |
| raw_ad_sid, |
| ) { |
| Ok(source) => { |
| eprintln!("broadcast source after additional info: {source:?}") |
| } |
| Err(e) => { |
| eprintln!("failed to enter in broadcast source information: {e:?}") |
| } |
| } |
| } |
| #[cfg(feature = "debug")] |
| AssistantCmd::ForceDiscoverSourceMetadata => { |
| if args.len() < 2 { |
| eprintln!( |
| "usage: {}", |
| AssistantCmd::ForceDiscoverSourceMetadata.help_simple() |
| ); |
| return Ok(()); |
| } |
| |
| let Ok(peer_id) = parse_peer_id(&args[0]) else { |
| println!("invalid peer id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| let mut raw_metadata = Vec::new(); |
| for i in 1..args.len() { |
| let ith_metadata: Vec<u8> = args[i] |
| .split(',') |
| .map(|t| parse_int(t)) |
| .filter_map(Result::ok) |
| .collect(); |
| raw_metadata.push(ith_metadata); |
| } |
| |
| match self |
| .assistant |
| .force_discover_broadcast_source_metadata(peer_id, raw_metadata) |
| { |
| Ok(source) => eprintln!("broadcast source with metadata: {source:?}"), |
| Err(e) => eprintln!("failed to enter in broadcast source metadata: {e:?}"), |
| } |
| } |
| #[cfg(feature = "debug")] |
| AssistantCmd::ForceDiscoverEmptySourceMetadata => { |
| if args.len() != 2 { |
| eprintln!( |
| "usage: {}", |
| AssistantCmd::ForceDiscoverEmptySourceMetadata.help_simple() |
| ); |
| return Ok(()); |
| } |
| |
| let Ok(peer_id) = parse_peer_id(&args[0]) else { |
| eprintln!("invalid peer id: {}", args[0]); |
| return Ok(()); |
| }; |
| |
| let Ok(num_big) = parse_int::<usize>(&args[1]) else { |
| eprintln!("invalid # of bigs: {}", args[1]); |
| return Ok(()); |
| }; |
| |
| let mut raw_metadata = Vec::new(); |
| for _i in 0..num_big { |
| raw_metadata.push(vec![]); |
| } |
| |
| match self |
| .assistant |
| .force_discover_broadcast_source_metadata(peer_id, raw_metadata) |
| { |
| Ok(source) => eprintln!("broadcast source with metadata: {source:?}"), |
| Err(e) => { |
| eprintln!("failed to enter in empty broadcast source metadata: {e:?}") |
| } |
| } |
| } |
| #[cfg(not(feature = "debug"))] |
| c => eprintln!("unknown command: {c:?}"), |
| } |
| Ok::<(), Error>(()) |
| } |
| } |
| } |
| |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_peer_id() {
        // In hex string.
        assert_eq!(parse_peer_id("0x678abc").expect("should be ok"), PeerId(0x678abc));
        // Decimal equivalent.
        assert_eq!(parse_peer_id("6785724").expect("should be ok"), PeerId(0x678abc));

        // Invalid peer id.
        let _ = parse_peer_id("0123zzz").expect_err("should fail");
    }

    #[test]
    fn test_parse_bd_addr() {
        assert_eq!(
            parse_bd_addr("3c:80:f1:ed:32:2c").expect("should be ok"),
            [0x3c, 0x80, 0xf1, 0xed, 0x32, 0x2c]
        );
        // Address with 5 parts is invalid.
        let _ = parse_bd_addr("3c:80:f1:ed:32").expect_err("should fail");
        // Address with 6 parts but one of them empty is invalid.
        let _ = parse_bd_addr("3c:80:f1::32:2c").expect_err("should fail");
        let _ = parse_bd_addr(":80:f1:ed:32:2c").expect_err("should fail");
        let _ = parse_bd_addr("3c:80:f1:ed:32:").expect_err("should fail");
        // Address not delimited by : is invalid.
        let _ = parse_bd_addr("3c.80.f1.ed.32.2c").expect_err("should fail");
    }

    #[test]
    fn test_parse_broadcast_id() {
        assert_eq!(parse_broadcast_id("0xABCD").expect("should work"), 0xABCD.try_into().unwrap());
        assert_eq!(parse_broadcast_id("123456").expect("should work"), 123456.try_into().unwrap());

        // Invalid string cannot be parsed.
        let _ = parse_broadcast_id("0xABYZ").expect_err("should fail");

        // Broadcast ID is actually a 3 byte long number.
        let _ = parse_broadcast_id("16777216").expect_err("should fail");
    }

    #[test]
    fn test_parse_bis_sync() {
        let bis_sync = parse_bis_sync("0-1,0-2,1-1");
        assert_eq!(bis_sync.len(), 3);
        // The membership checks must be asserted; a bare `contains` call
        // verifies nothing.
        assert!(bis_sync.contains(&(0, 1)));
        assert!(bis_sync.contains(&(0, 2)));
        assert!(bis_sync.contains(&(1, 1)));

        // Will ignore invalid values.
        let bis_sync = parse_bis_sync("0-1,0-2,1:1,1-1-1,");
        assert_eq!(bis_sync.len(), 2);
        assert!(bis_sync.contains(&(0, 1)));
        assert!(bis_sync.contains(&(0, 2)));

        let bis_sync = parse_bis_sync("hellothisistoallynotvalid");
        assert_eq!(bis_sync.len(), 0);
    }
}