// Copyright 2023 Google LLC
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

pub mod error;
pub mod event;

use std::collections::HashMap;
use std::sync::Arc;

use futures::stream::{BoxStream, SelectAll, StreamExt};
use parking_lot::Mutex;
use tracing::warn;

use bt_common::packet_encoding::{Decodable, Encodable};
use bt_gatt::client::{CharacteristicNotification, PeerService, ServiceCharacteristic};
use bt_gatt::types::{Handle, WriteMode};

use crate::client::error::{Error, ServiceError};
use crate::client::event::*;
use crate::types::*;

/// Size, in bytes, of the scratch buffer allocated for reading each Broadcast
/// Receive State characteristic value during discovery.
const READ_CHARACTERISTIC_BUFFER_SIZE: usize = 255;

/// Keeps track of which Broadcast_ID each Source_ID is associated with.
///
/// Source_ID is assigned by the BASS server to a Broadcast Receive State characteristic.
/// If the remote peer with the BASS server autonomously synchronized to a PA
/// or accepted the Add Source operation, the server selects an empty Broadcast
/// Receive State characteristic to update, or deletes one of the existing ones to update.
/// However, because the concept of Source_ID is unique to BASS, we track the Broadcast_ID
/// that a Source_ID is associated with so that it can be used by upper layers.
#[derive(Default)]
pub struct BroadcastSourceIdTracker {
    // Maps each known Source_ID to the Broadcast_ID most recently recorded for it.
    source_to_broadcast: HashMap<SourceId, BroadcastId>,
}

impl BroadcastSourceIdTracker {
    fn new() -> Self {
        Self::default()
    }

    /// Updates the broadcast ID associated with the given source ID and returns the
    /// previously-associated broadcast ID if it exists.
    fn update(&mut self, source_id: SourceId, broadcast_id: BroadcastId) -> Option<BroadcastId> {
        self.source_to_broadcast.insert(source_id, broadcast_id)
    }

    fn source_id(&self, broadcast_id: &BroadcastId) -> Option<SourceId> {
        self.source_to_broadcast
            .iter()
            .find_map(|(sid, bid)| (*bid == *broadcast_id).then_some(*sid))
    }
}

/// Manages connection to the Broadcast Audio Scan Service at the
/// remote Scan Delegator and writes/reads characteristics to/from it.
pub struct BroadcastAudioScanServiceClient<T: bt_gatt::GattTypes> {
    /// GATT peer service used to subscribe to notifications and to write the control point.
    gatt_client: Box<T::PeerService>,
    /// Source_ID to Broadcast_ID associations. A clone of this handle is given to the
    /// event stream returned by `take_event_stream`.
    id_tracker: Arc<Mutex<BroadcastSourceIdTracker>>,
    /// Handle of the Broadcast Audio Scan Control Point characteristic.
    /// According to BASS Section 3, the Broadcast Audio Scan Service has exactly one
    /// Broadcast Audio Scan Control Point characteristic and one or more
    /// Broadcast Receive State characteristics.
    audio_scan_control_point: Handle,
    /// Broadcast Receive State characteristics can be used to determine the BASS status.
    /// Keyed by characteristic handle; the value is `None` when the value read at
    /// discovery failed or could not be decoded.
    receive_states: Arc<Mutex<HashMap<Handle, Option<BroadcastReceiveState>>>>,
    /// Keeps track of the broadcast codes that were sent to the remote BASS server.
    broadcast_codes: HashMap<SourceId, [u8; 16]>,
    // GATT notification streams for BRS characteristic value changes.
    // Becomes `None` once `take_event_stream` has handed the streams out.
    notification_streams: Option<
        SelectAll<BoxStream<'static, Result<CharacteristicNotification, bt_gatt::types::Error>>>,
    >,
}

impl<T: bt_gatt::GattTypes> BroadcastAudioScanServiceClient<T> {
    async fn create(gatt_client: T::PeerService) -> Result<Self, Error>
    where
        <T as bt_gatt::GattTypes>::NotificationStream: std::marker::Send,
    {
        // BASS server should have a single Broadcast Audio Scan Control Point Characteristic.
        let bascp =
            ServiceCharacteristic::<T>::find(&gatt_client, BROADCAST_AUDIO_SCAN_CONTROL_POINT_UUID)
                .await
                .map_err(|e| Error::Gatt(e))?;
        if bascp.len() != 1 {
            let err = if bascp.len() == 0 {
                Error::Service(ServiceError::MissingCharacteristic)
            } else {
                Error::Service(ServiceError::ExtraScanControlPointCharacteristic)
            };
            return Err(err);
        }
        let bascp_handle = *bascp[0].handle();
        let brs_chars = Self::discover_brs_characteristics(&gatt_client).await?;
        let mut id_tracker = BroadcastSourceIdTracker::new();
        for c in brs_chars.values() {
            if let Some(read_value) = c {
                if let BroadcastReceiveState::NonEmpty(state) = read_value {
                    let _ = id_tracker.update(state.source_id(), state.broadcast_id());
                }
            }
        }

        let mut client = Self {
            gatt_client: Box::new(gatt_client),
            id_tracker: Arc::new(Mutex::new(id_tracker)),
            audio_scan_control_point: bascp_handle,
            receive_states: Arc::new(Mutex::new(brs_chars)),
            broadcast_codes: HashMap::new(),
            notification_streams: None,
        };
        client.register_notifications();
        Ok(client)
    }

    // Discover all the Broadcast Receive State characteristics.
    // On success, returns the HashMap of all Broadcast Received State Characteristics.
    async fn discover_brs_characteristics(
        gatt_client: &T::PeerService,
    ) -> Result<HashMap<Handle, Option<BroadcastReceiveState>>, Error> {
        let brs = ServiceCharacteristic::<T>::find(gatt_client, BROADCAST_RECEIVE_STATE_UUID)
            .await
            .map_err(|e| Error::Gatt(e))?;
        if brs.len() == 0 {
            return Err(Error::Service(ServiceError::MissingCharacteristic));
        }
        let mut brs_map = HashMap::new();
        for c in brs {
            // Read the value of the Broadcast Recieve State at the time of discovery for
            // record.
            let mut buf = vec![0; READ_CHARACTERISTIC_BUFFER_SIZE];
            match c.read(&mut buf[..]).await {
                Ok(read_bytes) => match BroadcastReceiveState::decode(&buf[0..read_bytes]) {
                    Ok((decoded, _decoded_bytes)) => {
                        brs_map.insert(*c.handle(), Some(decoded));
                        continue;
                    }
                    Err(e) => warn!(
                        "Failed to decode characteristic ({:?}) to Broadcast Receive State value: {:?}",
                        *c.handle(),
                        e
                    ),
                },
                Err(e) => warn!("Failed to read characteristic ({:?}) value: {:?}", *c.handle(), e),
            }
            brs_map.insert(*c.handle(), None);
        }
        Ok(brs_map)
    }

    /// Subscribes to notifications for every known Broadcast Receive State
    /// characteristic and stores the merged stream in `notification_streams`,
    /// replacing whatever was stored there before.
    fn register_notifications(&mut self)
    where
        <T as bt_gatt::GattTypes>::NotificationStream: std::marker::Send,
    {
        let mut merged = SelectAll::new();
        // The lock guard is a temporary of the loop head, so it is released as soon
        // as the loop finishes.
        for brs_handle in self.receive_states.lock().keys() {
            merged.push(self.gatt_client.subscribe(brs_handle).boxed());
        }
        self.notification_streams = Some(merged);
    }

    /// Returns a stream that the upper layer can poll for `BroadcastAudioScanServiceEvent`s.
    ///
    /// The events are generated from the GATT notifications for Broadcast Receive State
    /// characteristic changes that this client subscribed to. The returned stream shares
    /// the client's Source_ID tracker and receive-state map.
    ///
    /// The notification streams are moved out of the client, so this only yields a
    /// stream once: it returns `Some(stream)` while the notification streams are still
    /// held by the client and `None` on every call after they have been taken.
    pub fn take_event_stream(&mut self) -> Option<BroadcastAudioScanServiceEventStream> {
        let streams = self.notification_streams.take()?;
        Some(BroadcastAudioScanServiceEventStream::new(
            streams,
            self.id_tracker.clone(),
            self.receive_states.clone(),
        ))
    }

    /// Sets the broadcast code for a particular broadcast stream.
    pub async fn set_broadcast_code(
        &mut self,
        broadcast_id: BroadcastId,
        broadcast_code: [u8; 16],
    ) -> Result<(), Error> {
        let source_id = self.id_tracker.lock().source_id(&broadcast_id).ok_or(Error::Generic(
            format!("Cannot find Source ID for specified Broadcast ID {broadcast_id:?}"),
        ))?;

        let op = SetBroadcastCodeOperation::new(source_id, broadcast_code.clone());
        let mut buf = vec![0; op.encoded_len()];
        let _ = op.encode(&mut buf[..]).map_err(|e| Error::Packet(e))?;

        let c = &self.audio_scan_control_point;
        let _ = self
            .gatt_client
            .write_characteristic(c, WriteMode::WithoutResponse, 0, buf.as_slice())
            .await
            .map_err(|e| Error::Gatt(e))?;

        // Save the broadcast code we sent.
        self.broadcast_codes.insert(source_id, broadcast_code);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::task::Poll;

    use futures::executor::block_on;
    use futures::{pin_mut, FutureExt};

    use bt_common::Uuid;
    use bt_gatt::test_utils::*;
    use bt_gatt::types::{
        AttributePermissions, CharacteristicProperties, CharacteristicProperty, Handle,
    };
    use bt_gatt::Characteristic;

    // Handles of the characteristics that `setup_client` registers on the fake peer
    // service: three Broadcast Receive State characteristics, one unrelated
    // characteristic, and the Broadcast Audio Scan Control Point.
    const RECEIVE_STATE_1_HANDLE: Handle = Handle(1);
    const RECEIVE_STATE_2_HANDLE: Handle = Handle(2);
    const RECEIVE_STATE_3_HANDLE: Handle = Handle(3);
    const RANDOME_CHAR_HANDLE: Handle = Handle(4);
    const AUDIO_SCAN_CONTROL_POINT_HANDLE: Handle = Handle(5);

    fn setup_client() -> (BroadcastAudioScanServiceClient<FakeTypes>, FakePeerService) {
        let mut fake_peer_service = FakePeerService::new();
        // Add 3 Broadcast Receive State Characteristics, 1 Broadcast Audio Scan Control Point Characteristic,
        // and 1 random one.
        fake_peer_service.add_characteristic(
            Characteristic {
                handle: RECEIVE_STATE_1_HANDLE,
                uuid: BROADCAST_RECEIVE_STATE_UUID,
                properties: CharacteristicProperties(vec![
                    CharacteristicProperty::Broadcast,
                    CharacteristicProperty::Notify,
                ]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );
        fake_peer_service.add_characteristic(
            Characteristic {
                handle: RECEIVE_STATE_2_HANDLE,
                uuid: BROADCAST_RECEIVE_STATE_UUID,
                properties: CharacteristicProperties(vec![
                    CharacteristicProperty::Broadcast,
                    CharacteristicProperty::Notify,
                ]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );
        fake_peer_service.add_characteristic(
            Characteristic {
                handle: RECEIVE_STATE_3_HANDLE,
                uuid: BROADCAST_RECEIVE_STATE_UUID,
                properties: CharacteristicProperties(vec![
                    CharacteristicProperty::Broadcast,
                    CharacteristicProperty::Notify,
                ]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );
        fake_peer_service.add_characteristic(
            Characteristic {
                handle: RANDOME_CHAR_HANDLE,
                uuid: Uuid::from_u16(0x1234),
                properties: CharacteristicProperties(vec![CharacteristicProperty::Notify]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );
        fake_peer_service.add_characteristic(
            Characteristic {
                handle: AUDIO_SCAN_CONTROL_POINT_HANDLE,
                uuid: BROADCAST_AUDIO_SCAN_CONTROL_POINT_UUID,
                properties: CharacteristicProperties(vec![CharacteristicProperty::Broadcast]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );

        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
        let create_result =
            BroadcastAudioScanServiceClient::<FakeTypes>::create(fake_peer_service.clone());
        pin_mut!(create_result);
        let polled = create_result.poll_unpin(&mut noop_cx);
        let Poll::Ready(Ok(client)) = polled else {
            panic!("Expected BroadcastAudioScanServiceClient to be succesfully created");
        };

        (client, fake_peer_service)
    }

    /// Verifies that creation records the control point handle and every Broadcast
    /// Receive State handle exposed by the fake peer service.
    #[test]
    fn create_client() {
        let (client, _) = setup_client();

        // The single control point characteristic must have been picked up.
        assert_eq!(client.audio_scan_control_point, AUDIO_SCAN_CONTROL_POINT_HANDLE);

        // Every Broadcast Receive State characteristic must have been discovered.
        let discovered = client.receive_states.lock();
        for expected in [RECEIVE_STATE_1_HANDLE, RECEIVE_STATE_2_HANDLE, RECEIVE_STATE_3_HANDLE] {
            assert!(discovered.contains_key(&expected));
        }
    }

    /// Verifies that creation fails when either the control point or every Broadcast
    /// Receive State characteristic is absent.
    #[test]
    fn create_client_fails_missing_characteristics() {
        // Polls client creation once against `peer` and panics unless it finished
        // with an error.
        fn assert_create_fails(peer: &FakePeerService) {
            let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
            let create_fut = BroadcastAudioScanServiceClient::<FakeTypes>::create(peer.clone());
            pin_mut!(create_fut);
            match create_fut.poll_unpin(&mut noop_cx) {
                Poll::Ready(Err(_)) => {}
                _ => panic!("Expected BroadcastAudioScanServiceClient to have failed"),
            }
        }

        // Missing scan control point characteristic.
        let mut only_receive_state = FakePeerService::new();
        only_receive_state.add_characteristic(
            Characteristic {
                handle: Handle(1),
                uuid: BROADCAST_RECEIVE_STATE_UUID,
                properties: CharacteristicProperties(vec![
                    CharacteristicProperty::Broadcast,
                    CharacteristicProperty::Notify,
                ]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );
        assert_create_fails(&only_receive_state);

        // Missing receive state characteristic.
        let mut only_control_point = FakePeerService::new();
        only_control_point.add_characteristic(
            Characteristic {
                handle: Handle(1),
                uuid: BROADCAST_AUDIO_SCAN_CONTROL_POINT_UUID,
                properties: CharacteristicProperties(vec![CharacteristicProperty::Broadcast]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );
        assert_create_fails(&only_control_point);
    }

    /// Verifies that creation fails when the peer exposes more than one Broadcast
    /// Audio Scan Control Point characteristic.
    #[test]
    fn create_client_fails_duplicate_characteristics() {
        // Builds a control point characteristic with the given handle.
        fn control_point_characteristic(handle: Handle) -> Characteristic {
            Characteristic {
                handle,
                uuid: BROADCAST_AUDIO_SCAN_CONTROL_POINT_UUID,
                properties: CharacteristicProperties(vec![CharacteristicProperty::Broadcast]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            }
        }

        // One receive state characteristic followed by two scan control points.
        let mut peer = FakePeerService::new();
        peer.add_characteristic(
            Characteristic {
                handle: Handle(1),
                uuid: BROADCAST_RECEIVE_STATE_UUID,
                properties: CharacteristicProperties(vec![
                    CharacteristicProperty::Broadcast,
                    CharacteristicProperty::Notify,
                ]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![],
        );
        for handle in [Handle(2), Handle(3)] {
            peer.add_characteristic(control_point_characteristic(handle), vec![]);
        }

        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
        let create_fut = BroadcastAudioScanServiceClient::<FakeTypes>::create(peer.clone());
        pin_mut!(create_fut);
        match create_fut.poll_unpin(&mut noop_cx) {
            Poll::Ready(Err(_)) => {}
            _ => panic!("Expected BroadcastAudioScanServiceClient to have failed"),
        }
    }

    /// Verifies that changes to Broadcast Receive State characteristic values surface
    /// as the expected events, in order, on the stream returned by `take_event_stream`.
    #[test]
    fn pump_notifications() {
        let (mut client, mut fake_peer_service) = setup_client();
        let mut event_stream = client.take_event_stream().expect("stream was created");

        // Send notification for updating BRS characteristic to indicate it's synced and requires broadcast code.
        // NOTE(review): presumably re-adding a characteristic with an existing handle makes
        // the fake peer service emit a notification with the new value; confirm in test_utils.
        #[rustfmt::skip]
        fake_peer_service.add_characteristic(
            Characteristic {
                handle: RECEIVE_STATE_2_HANDLE,
                uuid: BROADCAST_RECEIVE_STATE_UUID,
                properties: CharacteristicProperties(vec![
                    CharacteristicProperty::Broadcast,
                    CharacteristicProperty::Notify,
                ]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![
                0x02, AddressType::Public as u8,                      // source id and address type
                0x02, 0x03, 0x04, 0x05, 0x06, 0x07,                   // address
                0x01, 0x02, 0x03, 0x04,                               // ad set id and broadcast id
                PaSyncState::Synced as u8,
                EncryptionStatus::BroadcastCodeRequired.raw_value(),
                0x00,                                                 // no subgroups
            ],
        );

        // Check that synced and broadcast code required events were sent out.
        // The broadcast id bytes 0x02, 0x03, 0x04 above are expected to decode to 0x040302.
        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
        let recv_fut = event_stream.select_next_some();
        let event = block_on(recv_fut).expect("should receive event");
        assert_eq!(event, BroadcastAudioScanServiceEvent::SyncedToPa(BroadcastId::new(0x040302)));

        let recv_fut = event_stream.select_next_some();
        let event = block_on(recv_fut).expect("should receive event");
        assert_eq!(
            event,
            BroadcastAudioScanServiceEvent::BroadcastCodeRequired(BroadcastId::new(0x040302))
        );

        // Send notification for updating BRS characteristic to indicate it requires sync info.
        // This updates the value of the BRS characteristic with handle 3.
        #[rustfmt::skip]
        fake_peer_service.add_characteristic(
            Characteristic {
                handle: RECEIVE_STATE_3_HANDLE,
                uuid: BROADCAST_RECEIVE_STATE_UUID,
                properties: CharacteristicProperties(vec![
                    CharacteristicProperty::Broadcast,
                    CharacteristicProperty::Notify,
                ]),
                permissions: AttributePermissions::default(),
                descriptors: vec![],
            },
            vec![
                0x03, AddressType::Public as u8,             // source id and address type
                0x03, 0x04, 0x05, 0x06, 0x07, 0x08,          // address
                0x01, 0x03, 0x04, 0x05,                      // ad set id and broadcast id
                PaSyncState::SyncInfoRequest as u8,
                EncryptionStatus::NotEncrypted.raw_value(),
                0x00,                                        // no subgroups
            ],
        );

        // Check that sync info required event was sent out.
        // The broadcast id bytes 0x03, 0x04, 0x05 above are expected to decode to 0x050403.
        let recv_fut = event_stream.select_next_some();
        let event = block_on(recv_fut).expect("should receive event");
        assert_eq!(
            event,
            BroadcastAudioScanServiceEvent::SyncInfoRequested(BroadcastId::new(0x050403))
        );

        // Stream should be pending since no more notifications.
        assert!(event_stream.poll_next_unpin(&mut noop_cx).is_pending());
    }

    /// Verifies that setting a broadcast code for a tracked broadcast writes the
    /// expected bytes to the control point and completes successfully.
    #[test]
    fn set_broadcast_code() {
        let (mut client, mut fake_peer_service) = setup_client();

        // Seed the id tracker by hand; in practice a BRS value change notification
        // would have recorded this association.
        client.id_tracker.lock().update(0x01, BroadcastId::new(0x030201));

        // Expected write: 0x04 (presumably the Set Broadcast Code opcode), source id
        // 0x01, then the 16-byte broadcast code.
        let mut expected_write = vec![0x04, 0x01];
        expected_write.extend([1u8; 16]);
        fake_peer_service
            .expect_characteristic_value(&AUDIO_SCAN_CONTROL_POINT_HANDLE, expected_write);

        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
        let set_code_fut = client.set_broadcast_code(BroadcastId::new(0x030201), [1; 16]);
        pin_mut!(set_code_fut);
        match set_code_fut.poll_unpin(&mut noop_cx) {
            Poll::Ready(Ok(_)) => {}
            _ => panic!("Expected to succeed"),
        }
    }

    /// Verifies that setting a broadcast code fails when no Source_ID is tracked for
    /// the requested broadcast.
    #[test]
    fn set_broadcast_code_fails() {
        let (mut client, _) = setup_client();

        let mut noop_cx = futures::task::Context::from_waker(futures::task::noop_waker_ref());
        let set_code_fut = client.set_broadcast_code(BroadcastId::new(0x030201), [1; 16]);
        pin_mut!(set_code_fut);

        // No BRS characteristic value was ever updated, so the broadcast id cannot be
        // resolved to a source id and the call must fail.
        match set_code_fut.poll_unpin(&mut noop_cx) {
            Poll::Ready(Err(_)) => {}
            _ => panic!("Expected to fail"),
        }
    }
}
