Skip to content

Commit

Permalink
Switched mid and rid to SmolStr
Browse files Browse the repository at this point in the history
  • Loading branch information
gautamprikshit1 authored and rainliu committed Jun 3, 2023
1 parent 7bf1438 commit f57676f
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 52 deletions.
1 change: 1 addition & 0 deletions webrtc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ bytes = "1"
thiserror = "1.0"
waitgroup = "0.1.2"
regex = "1.7.1"
smol_str = "0.2.0"
url = "2.2"
rustls = { version = "0.19.0", features = ["dangerous_configuration"]}
rcgen = { version = "0.10.0", features = ["pem", "x509-parser"]}
Expand Down
24 changes: 24 additions & 0 deletions webrtc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub use rtcp;
pub use rtp;
pub use sctp;
pub use sdp;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
pub use srtp;
pub use stun;
pub use turn;
Expand Down Expand Up @@ -43,3 +45,25 @@ pub(crate) const SDP_ATTRIBUTE_RID: &str = "rid";
pub(crate) const GENERATED_CERTIFICATE_ORIGIN: &str = "WebRTC";
pub(crate) const SDES_REPAIR_RTP_STREAM_ID_URI: &str =
"urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id";

#[derive(Clone, Debug, Default, PartialEq)]
pub struct SmallStr(SmolStr);

impl Serialize for SmallStr {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.0.as_str())
}
}

impl<'de> Deserialize<'de> for SmallStr {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(SmallStr(SmolStr::new(s)))
}
}
50 changes: 30 additions & 20 deletions webrtc/src/peer_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ use crate::sctp_transport::RTCSctpTransport;
use crate::stats::StatsReport;
use crate::track::track_local::TrackLocal;
use crate::track::track_remote::TrackRemote;
use crate::SmallStr;

use ::ice::candidate::candidate_base::unmarshal_candidate;
use ::ice::candidate::Candidate;
Expand All @@ -77,6 +78,7 @@ use interceptor::{stats, Attributes, Interceptor, RTCPWriter};
use peer_connection_internal::*;
use rand::{thread_rng, Rng};
use rcgen::KeyPair;
use smol_str::SmolStr;
use srtp::stream::Stream;
use std::future::Future;
use std::pin::Pin;
Expand Down Expand Up @@ -454,7 +456,9 @@ impl RTCPeerConnection {
// return true
// }
let mid = t.mid();
let m = mid.as_ref().and_then(|mid| get_by_mid(mid, local_desc));
let m = mid
.as_ref()
.and_then(|mid| get_by_mid(mid.0.as_str(), local_desc));
// Step 5.2
if !t.stopped.load(Ordering::SeqCst) {
if m.is_none() {
Expand Down Expand Up @@ -493,8 +497,9 @@ impl RTCPeerConnection {
RTCSdpType::Offer => {
// Step 5.3.2
if let Some(remote_desc) = &current_remote_description {
if let Some(rm) =
t.mid().and_then(|mid| get_by_mid(&mid, remote_desc))
if let Some(rm) = t
.mid()
.and_then(|mid| get_by_mid(mid.0.as_str(), remote_desc))
{
if get_peer_direction(m) != t.direction()
&& get_peer_direction(rm) != t.direction().reverse()
Expand All @@ -511,18 +516,20 @@ impl RTCPeerConnection {
Some(d) => d,
None => return true,
};
let offered_direction =
match t.mid().and_then(|mid| get_by_mid(&mid, remote_desc)) {
Some(d) => {
let dir = get_peer_direction(d);
if dir == RTCRtpTransceiverDirection::Unspecified {
RTCRtpTransceiverDirection::Inactive
} else {
dir
}
let offered_direction = match t
.mid()
.and_then(|mid| get_by_mid(mid.0.as_str(), remote_desc))
{
Some(d) => {
let dir = get_peer_direction(d);
if dir == RTCRtpTransceiverDirection::Unspecified {
RTCRtpTransceiverDirection::Inactive
} else {
dir
}
None => RTCRtpTransceiverDirection::Inactive,
};
}
None => RTCRtpTransceiverDirection::Inactive,
};

let current_direction = get_peer_direction(m);
// Step 5.3.3
Expand All @@ -544,8 +551,8 @@ impl RTCPeerConnection {
};

if let Some(remote_desc) = &*params.current_remote_description.lock().await {
return get_by_mid(&search_mid, local_desc).is_some()
|| get_by_mid(&search_mid, remote_desc).is_some();
return get_by_mid(search_mid.0.as_str(), local_desc).is_some()
|| get_by_mid(search_mid.0.as_str(), remote_desc).is_some();
}
}
}
Expand Down Expand Up @@ -795,10 +802,13 @@ impl RTCPeerConnection {
}
}

t.set_mid(mid)?;
t.set_mid(crate::SmallStr(SmolStr::from(mid)))?;
} else {
let greater_mid = self.internal.greater_mid.fetch_add(1, Ordering::SeqCst);
t.set_mid(format!("{}", greater_mid + 1))?;
t.set_mid(crate::SmallStr(SmolStr::from(format!(
"{}",
greater_mid + 1
))))?;
}
}

Expand Down Expand Up @@ -1358,7 +1368,7 @@ impl RTCPeerConnection {

if let Some(t) = t {
if t.mid().is_none() {
t.set_mid(mid_value.to_owned())?;
t.set_mid(crate::SmallStr(SmolStr::from(mid_value)))?;
}
} else {
let local_direction =
Expand Down Expand Up @@ -1404,7 +1414,7 @@ impl RTCPeerConnection {
self.internal.add_rtp_transceiver(Arc::clone(&t)).await;

if t.mid().is_none() {
t.set_mid(mid_value.to_owned())?;
t.set_mid(SmallStr(SmolStr::from(mid_value)))?;
}
}
}
Expand Down
23 changes: 12 additions & 11 deletions webrtc/src/peer_connection/peer_connection_internal.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use bytes::Bytes;
use smol_str::SmolStr;
use tokio::time::Instant;

use super::*;
Expand All @@ -9,7 +10,7 @@ use crate::stats::{
StatsReportType,
};
use crate::track::TrackStream;
use crate::{SDES_REPAIR_RTP_STREAM_ID_URI, SDP_ATTRIBUTE_RID};
use crate::{SmallStr, SDES_REPAIR_RTP_STREAM_ID_URI, SDP_ATTRIBUTE_RID};
use arc_swap::ArcSwapOption;
use std::collections::VecDeque;
use std::sync::atomic::AtomicIsize;
Expand Down Expand Up @@ -177,7 +178,7 @@ impl PeerConnectionInternal {
for t in tracks {
if !t.rid().is_empty() {
if let Some(details) =
track_details_for_rid(&track_details, t.rid().to_owned())
track_details_for_rid(&track_details, SmallStr(SmolStr::from(t.rid())))
{
t.set_id(details.id.clone());
t.set_stream_id(details.stream_id.clone());
Expand Down Expand Up @@ -668,7 +669,7 @@ impl PeerConnectionInternal {
// TODO: This is dubious because of rollbacks.
t.sender().set_negotiated();
media_sections.push(MediaSection {
id: t.mid().unwrap(),
id: t.mid().unwrap().0.to_string(),
transceivers: vec![Arc::clone(t)],
..Default::default()
});
Expand Down Expand Up @@ -782,7 +783,7 @@ impl PeerConnectionInternal {
for t in &local_transceivers {
t.sender().set_negotiated();
media_sections.push(MediaSection {
id: t.mid().unwrap(),
id: t.mid().unwrap().0.to_string(),
transceivers: vec![Arc::clone(t)],
..Default::default()
});
Expand Down Expand Up @@ -1000,7 +1001,7 @@ impl PeerConnectionInternal {

let transceivers = self.rtp_transceivers.lock().await;
for t in &*transceivers {
if t.mid().as_ref() != Some(&mid) {
if t.mid().as_ref() != Some(&SmallStr(SmolStr::from(&mid))) {
continue;
}

Expand All @@ -1024,7 +1025,7 @@ impl PeerConnectionInternal {

let track = receiver
.receive_for_rid(
rid,
SmallStr(SmolStr::from(rid)),
params,
TrackStream {
stream_info: Some(stream_info.clone()),
Expand Down Expand Up @@ -1163,7 +1164,7 @@ impl PeerConnectionInternal {
pub(super) async fn has_local_description_changed(&self, desc: &RTCSessionDescription) -> bool {
let rtp_transceivers = self.rtp_transceivers.lock().await;
for t in &*rtp_transceivers {
let m = match t.mid().and_then(|mid| get_by_mid(&mid, desc)) {
let m = match t.mid().and_then(|mid| get_by_mid(mid.0.as_str(), desc)) {
Some(m) => m,
None => return true,
};
Expand Down Expand Up @@ -1200,7 +1201,7 @@ impl PeerConnectionInternal {
// TODO: There's a lot of await points here that could run concurrently with `futures::join_all`.
struct TrackInfo {
ssrc: SSRC,
mid: String,
mid: SmallStr,
track_id: String,
kind: &'static str,
}
Expand Down Expand Up @@ -1325,8 +1326,8 @@ impl PeerConnectionInternal {
struct TrackInfo {
track_id: String,
ssrc: SSRC,
mid: String,
rid: Option<String>,
mid: SmallStr,
rid: Option<SmallStr>,
kind: &'static str,
}
let mut track_infos = vec![];
Expand All @@ -1353,7 +1354,7 @@ impl PeerConnectionInternal {
track_infos.push(TrackInfo {
track_id,
ssrc: sender.ssrc,
mid: mid.clone(),
mid,
rid: None,
kind,
});
Expand Down
18 changes: 10 additions & 8 deletions webrtc/src/peer_connection/sdp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ pub mod sdp_type;
pub mod session_description;

use crate::peer_connection::MEDIA_SECTION_APPLICATION;
use crate::SDP_ATTRIBUTE_RID;
use crate::{SmallStr, SDP_ATTRIBUTE_RID};
use ice::candidate::candidate_base::unmarshal_candidate;
use ice::candidate::Candidate;
use sdp::description::common::{Address, ConnectionInformation};
use sdp::description::media::{MediaDescription, MediaName, RangedPort};
use sdp::description::session::*;
use sdp::extmap::ExtMap;
use sdp::util::ConnectionRole;
use smol_str::SmolStr;
use std::collections::HashMap;
use std::convert::From;
use std::io::BufReader;
Expand All @@ -37,13 +38,13 @@ use url::Url;
/// This isn't keyed by SSRC because it also needs to support rid based sources
#[derive(Default, Debug, Clone)]
pub(crate) struct TrackDetails {
pub(crate) mid: String,
pub(crate) mid: SmallStr,
pub(crate) kind: RTPCodecType,
pub(crate) stream_id: String,
pub(crate) id: String,
pub(crate) ssrcs: Vec<SSRC>,
pub(crate) repair_ssrc: SSRC,
pub(crate) rids: Vec<String>,
pub(crate) rids: Vec<SmallStr>,
}

pub(crate) fn track_details_for_ssrc(
Expand All @@ -55,7 +56,7 @@ pub(crate) fn track_details_for_ssrc(

pub(crate) fn track_details_for_rid(
track_details: &[TrackDetails],
rid: String,
rid: SmallStr,
) -> Option<&TrackDetails> {
track_details.iter().find(|x| x.rids.contains(&rid))
}
Expand Down Expand Up @@ -185,15 +186,16 @@ pub(crate) fn track_details_from_sdp(
}

if track_idx < tracks_in_media_section.len() {
tracks_in_media_section[track_idx].mid = mid_value.to_owned();
tracks_in_media_section[track_idx].mid =
SmallStr(SmolStr::from(mid_value));
tracks_in_media_section[track_idx].kind = codec_type;
tracks_in_media_section[track_idx].stream_id = stream_id.to_owned();
tracks_in_media_section[track_idx].id = track_id.to_owned();
tracks_in_media_section[track_idx].ssrcs = vec![ssrc];
tracks_in_media_section[track_idx].repair_ssrc = repair_ssrc;
} else {
let track_details = TrackDetails {
mid: mid_value.to_owned(),
mid: SmallStr(SmolStr::from(mid_value)),
kind: codec_type,
stream_id: stream_id.to_owned(),
id: track_id.to_owned(),
Expand All @@ -212,15 +214,15 @@ pub(crate) fn track_details_from_sdp(
let rids = get_rids(media);
if !rids.is_empty() && !track_id.is_empty() && !stream_id.is_empty() {
let mut simulcast_track = TrackDetails {
mid: mid_value.to_owned(),
mid: SmallStr(SmolStr::from(mid_value)),
kind: codec_type,
stream_id: stream_id.to_owned(),
id: track_id.to_owned(),
rids: vec![],
..Default::default()
};
for rid in rids.keys() {
simulcast_track.rids.push(rid.to_owned());
simulcast_track.rids.push(SmallStr(SmolStr::from(rid)));
}
if simulcast_track.rids.len() == tracks_in_media_section.len() {
for track in &tracks_in_media_section {
Expand Down
12 changes: 7 additions & 5 deletions webrtc/src/rtp_transceiver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use interceptor::{
stream_info::{RTPHeaderExtension, StreamInfo},
Attributes,
};
use smol_str::SmolStr;

use crate::SmallStr;
use log::trace;
use serde::{Deserialize, Serialize};
use std::fmt;
Expand Down Expand Up @@ -95,7 +97,7 @@ pub struct RTCRtpRtxParameters {
/// <http://draft.ortc.org/#dom-rtcrtpcodingparameters>
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct RTCRtpCodingParameters {
pub rid: String,
pub rid: SmallStr,
pub ssrc: SSRC,
pub payload_type: PayloadType,
pub rtx: RTCRtpRtxParameters,
Expand Down Expand Up @@ -174,7 +176,7 @@ pub type TriggerNegotiationNeededFnOption =

/// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid.
pub struct RTCRtpTransceiver {
mid: OnceCell<String>, //atomic.Value
mid: OnceCell<SmallStr>, //atomic.Value
sender: SyncMutex<Arc<RTCRtpSender>>, //atomic.Value
receiver: SyncMutex<Arc<RTCRtpReceiver>>, //atomic.Value

Expand Down Expand Up @@ -293,14 +295,14 @@ impl RTCRtpTransceiver {
}

/// set_mid sets the RTPTransceiver's mid. If it was already set, will return an error.
pub(crate) fn set_mid(&self, mid: String) -> Result<()> {
pub(crate) fn set_mid(&self, mid: SmallStr) -> Result<()> {
self.mid
.set(mid)
.map_err(|_| Error::ErrRTPTransceiverCannotChangeMid)
}

/// mid gets the Transceiver's mid value. When not already set, this value will be set in CreateOffer or create_answer.
pub fn mid(&self) -> Option<String> {
pub fn mid(&self) -> Option<SmallStr> {
self.mid.get().map(Clone::clone)
}

Expand Down Expand Up @@ -479,7 +481,7 @@ pub(crate) async fn find_by_mid(
local_transceivers: &mut Vec<Arc<RTCRtpTransceiver>>,
) -> Option<Arc<RTCRtpTransceiver>> {
for (i, t) in local_transceivers.iter().enumerate() {
if t.mid().as_deref() == Some(mid) {
if t.mid() == Some(SmallStr(SmolStr::from(mid))) {
return Some(local_transceivers.remove(i));
}
}
Expand Down
Loading

0 comments on commit f57676f

Please sign in to comment.