60
60
#include "aes_icm_ext.h"
61
61
#endif
62
62
63
+ #include <stddef.h>
64
+ #include <string.h>
63
65
#include <limits.h>
64
66
#ifdef HAVE_NETINET_IN_H
65
67
#include <netinet/in.h>
66
68
#elif defined(HAVE_WINSOCK2_H )
67
69
#include <winsock2.h>
68
70
#endif
69
71
72
+ #if defined(__SSE2__ )
73
+ #include <emmintrin.h>
74
+ #if defined(_MSC_VER )
75
+ #include <intrin.h>
76
+ #endif
77
+ #endif
78
+
70
79
/* the debug module for srtp */
71
80
srtp_debug_module_t mod_srtp = {
72
81
0 , /* debugging is off by default */
@@ -79,6 +88,17 @@ srtp_debug_module_t mod_srtp = {
79
88
#define uint32s_in_rtcp_header 2
80
89
#define octets_in_rtp_extn_hdr 4
81
90
91
+ #ifndef SRTP_NO_STREAM_LIST
92
+ static inline uint32_t srtp_stream_list_size (srtp_stream_list_t list );
93
+ static srtp_err_status_t srtp_stream_list_reserve (srtp_stream_list_t list ,
94
+ uint32_t new_capacity );
95
+ static uint32_t srtp_stream_list_find (srtp_stream_list_t list ,
96
+ uint32_t ssrc );
97
+ static inline srtp_stream_t srtp_stream_list_get_at (srtp_stream_list_t list ,
98
+ uint32_t pos );
99
+ static void srtp_stream_list_remove_at (srtp_stream_list_t list , uint32_t pos );
100
+ #endif // SRTP_NO_STREAM_LIST
101
+
82
102
static srtp_err_status_t srtp_validate_rtp_header (void * rtp_hdr ,
83
103
int * pkt_octet_len )
84
104
{
@@ -3030,18 +3050,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc)
3030
3050
{
3031
3051
srtp_stream_ctx_t * stream ;
3032
3052
srtp_err_status_t status ;
3053
+ #if !defined(SRTP_NO_STREAM_LIST )
3054
+ uint32_t pos ;
3055
+ #endif
3033
3056
3034
3057
/* sanity check arguments */
3035
- if (session == NULL )
3058
+ if (session == NULL ) {
3036
3059
return srtp_err_status_bad_param ;
3060
+ }
3037
3061
3038
3062
/* find and remove stream from the list */
3063
+ #if !defined(SRTP_NO_STREAM_LIST )
3064
+ pos = srtp_stream_list_find (session -> stream_list , ssrc );
3065
+ if (pos >= srtp_stream_list_size (session -> stream_list ))
3066
+ return srtp_err_status_no_ctx ;
3067
+
3068
+ stream = srtp_stream_list_get_at (session -> stream_list , pos );
3069
+ srtp_stream_list_remove_at (session -> stream_list , pos );
3070
+ #else
3039
3071
stream = srtp_stream_list_get (session -> stream_list , ssrc );
3040
3072
if (stream == NULL ) {
3041
3073
return srtp_err_status_no_ctx ;
3042
3074
}
3043
3075
3044
3076
srtp_stream_list_remove (session -> stream_list , stream );
3077
+ #endif
3045
3078
3046
3079
/* deallocate the stream */
3047
3080
status = srtp_stream_dealloc (stream , session -> stream_template );
@@ -4840,11 +4873,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session,
4840
4873
4841
4874
#ifndef SRTP_NO_STREAM_LIST
4842
4875
4843
- /* in the default implementation, we have an intrusive doubly-linked list */
4844
4876
typedef struct srtp_stream_list_ctx_t_ {
4845
- /* a stub stream that just holds pointers to the beginning and end of the
4846
- * list */
4847
- srtp_stream_ctx_t data ;
4877
+ uint32_t * ssrcs ;
4878
+ srtp_stream_ctx_t * * streams ;
4879
+ uint32_t size ;
4880
+ uint32_t capacity ;
4848
4881
} srtp_stream_list_ctx_t_ ;
4849
4882
4850
4883
srtp_err_status_t srtp_stream_list_alloc (srtp_stream_list_t * list_ptr )
@@ -4855,73 +4888,204 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
4855
4888
return srtp_err_status_alloc_fail ;
4856
4889
}
4857
4890
4858
- list -> data .next = NULL ;
4859
- list -> data .prev = NULL ;
4860
-
4861
4891
* list_ptr = list ;
4862
4892
return srtp_err_status_ok ;
4863
4893
}
4864
4894
4865
4895
srtp_err_status_t srtp_stream_list_dealloc (srtp_stream_list_t list )
4866
4896
{
4867
4897
/* list must be empty */
4868
- if (list -> data . next ) {
4898
+ if (list -> size != 0u ) {
4869
4899
return srtp_err_status_fail ;
4870
4900
}
4901
+ srtp_crypto_free (list -> streams );
4902
+ srtp_crypto_free (list -> ssrcs );
4871
4903
srtp_crypto_free (list );
4872
4904
return srtp_err_status_ok ;
4873
4905
}
4874
4906
4907
+ static inline uint32_t srtp_stream_list_size (srtp_stream_list_t list )
4908
+ {
4909
+ return list -> size ;
4910
+ }
4911
+
4912
+ static srtp_err_status_t srtp_stream_list_reserve (srtp_stream_list_t list ,
4913
+ uint32_t new_capacity )
4914
+ {
4915
+ if (new_capacity > list -> capacity ) {
4916
+ uint32_t * ssrcs ;
4917
+ srtp_stream_ctx_t * * stream_ptrs ;
4918
+
4919
+ if (new_capacity > (UINT32_MAX - 15u ))
4920
+ return srtp_err_status_alloc_fail ;
4921
+
4922
+ new_capacity = (new_capacity + 15u ) & ~((uint32_t )15u );
4923
+
4924
+ ssrcs = (uint32_t * )srtp_crypto_alloc ((size_t )new_capacity *
4925
+ sizeof (uint32_t ));
4926
+ if (!ssrcs )
4927
+ return srtp_err_status_alloc_fail ;
4928
+ stream_ptrs = (srtp_stream_ctx_t * * )srtp_crypto_alloc (
4929
+ (size_t )new_capacity * sizeof (srtp_stream_ctx_t * ));
4930
+ if (!stream_ptrs ) {
4931
+ srtp_crypto_free (ssrcs );
4932
+ return srtp_err_status_alloc_fail ;
4933
+ }
4934
+
4935
+ if (list -> size > 0u ) {
4936
+ memcpy (ssrcs , list -> ssrcs , (size_t )list -> size * sizeof (uint32_t ));
4937
+ memcpy (stream_ptrs , list -> streams ,
4938
+ (size_t )list -> size * sizeof (srtp_stream_ctx_t * ));
4939
+ }
4940
+
4941
+ srtp_crypto_free (list -> ssrcs );
4942
+ srtp_crypto_free (list -> streams );
4943
+ list -> streams = stream_ptrs ;
4944
+ list -> ssrcs = ssrcs ;
4945
+
4946
+ list -> capacity = new_capacity ;
4947
+ }
4948
+
4949
+ return srtp_err_status_ok ;
4950
+ }
4951
+
4875
4952
srtp_err_status_t srtp_stream_list_insert (srtp_stream_list_t list ,
4876
4953
srtp_stream_t stream )
4877
4954
{
4878
- /* insert at the head of the list */
4879
- stream -> next = list -> data . next ;
4880
- if (stream -> next != NULL ) {
4881
- stream -> next -> prev = stream ;
4882
- }
4883
- list -> data . next = stream ;
4884
- stream -> prev = & ( list -> data ) ;
4955
+ uint32_t pos ;
4956
+ srtp_err_status_t status = srtp_stream_list_reserve ( list , list -> size + 1u ) ;
4957
+ if (status )
4958
+ return status ;
4959
+ pos = list -> size ++ ;
4960
+ list -> ssrcs [ pos ] = stream -> ssrc ;
4961
+ list -> streams [ pos ] = stream ;
4885
4962
4886
4963
return srtp_err_status_ok ;
4887
4964
}
4888
4965
4889
- srtp_stream_t srtp_stream_list_get (srtp_stream_list_t list , uint32_t ssrc )
4966
+ static uint32_t srtp_stream_list_find (srtp_stream_list_t list , uint32_t ssrc )
4890
4967
{
4891
- /* walk down list until ssrc is found */
4892
- srtp_stream_t stream = list -> data .next ;
4893
- while (stream != NULL ) {
4894
- if (stream -> ssrc == ssrc ) {
4895
- return stream ;
4968
+ #if defined(__SSE2__ )
4969
+ const uint32_t * const ssrcs = list -> ssrcs ;
4970
+ const __m128i mm_ssrc = _mm_set1_epi32 (ssrc );
4971
+ uint32_t pos = 0u , n = (list -> size + 7u ) & ~(uint32_t )(7u );
4972
+ for (uint32_t m = n & ~(uint32_t )(15u ); pos < m ; pos += 16u ) {
4973
+ __m128i mm1 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos ));
4974
+ __m128i mm2 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 4u ));
4975
+ __m128i mm3 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 8u ));
4976
+ __m128i mm4 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 12u ));
4977
+ mm1 = _mm_cmpeq_epi32 (mm1 , mm_ssrc );
4978
+ mm2 = _mm_cmpeq_epi32 (mm2 , mm_ssrc );
4979
+ mm3 = _mm_cmpeq_epi32 (mm3 , mm_ssrc );
4980
+ mm4 = _mm_cmpeq_epi32 (mm4 , mm_ssrc );
4981
+ mm1 = _mm_packs_epi32 (mm1 , mm2 );
4982
+ mm3 = _mm_packs_epi32 (mm3 , mm4 );
4983
+ mm1 = _mm_packs_epi16 (mm1 , mm3 );
4984
+ uint32_t mask = _mm_movemask_epi8 (mm1 );
4985
+ if (mask ) {
4986
+ #if defined(_MSC_VER )
4987
+ unsigned long bit_pos ;
4988
+ _BitScanForward (& bit_pos , mask );
4989
+ pos += bit_pos ;
4990
+ #else
4991
+ pos += __builtin_ctz (mask );
4992
+ #endif
4993
+
4994
+ goto done ;
4995
+ }
4996
+ }
4997
+
4998
+ if (pos < n ) {
4999
+ __m128i mm1 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos ));
5000
+ __m128i mm2 = _mm_loadu_si128 ((const __m128i * )(ssrcs + pos + 4u ));
5001
+ mm1 = _mm_cmpeq_epi32 (mm1 , mm_ssrc );
5002
+ mm2 = _mm_cmpeq_epi32 (mm2 , mm_ssrc );
5003
+ mm1 = _mm_packs_epi32 (mm1 , mm2 );
5004
+
5005
+ uint32_t mask = _mm_movemask_epi8 (mm1 );
5006
+ if (mask ) {
5007
+ #if defined(_MSC_VER )
5008
+ unsigned long bit_pos ;
5009
+ _BitScanForward (& bit_pos , mask );
5010
+ pos += bit_pos / 2u ;
5011
+ #else
5012
+ pos += __builtin_ctz (mask ) / 2u ;
5013
+ #endif
5014
+ goto done ;
4896
5015
}
4897
- stream = stream -> next ;
5016
+
5017
+ pos += 8u ;
5018
+ }
5019
+
5020
+ done :
5021
+ return pos ;
5022
+ #else
5023
+ /* walk down list until ssrc is found */
5024
+ uint32_t pos = 0u , n = list -> size ;
5025
+ for (; pos < n ; ++ pos ) {
5026
+ if (list -> ssrcs [pos ] == ssrc )
5027
+ break ;
4898
5028
}
4899
5029
5030
+ return pos ;
5031
+ #endif
5032
+ }
5033
+
5034
+ static inline srtp_stream_t srtp_stream_list_get_at (srtp_stream_list_t list ,
5035
+ uint32_t pos )
5036
+ {
5037
+ return list -> streams [pos ];
5038
+ }
5039
+
5040
+ srtp_stream_t srtp_stream_list_get (srtp_stream_list_t list , uint32_t ssrc )
5041
+ {
5042
+ uint32_t pos = srtp_stream_list_find (list , ssrc );
5043
+ if (pos < list -> size )
5044
+ return list -> streams [pos ];
5045
+
4900
5046
/* we haven't found our ssrc, so return a null */
4901
5047
return NULL ;
4902
5048
}
4903
5049
4904
- void srtp_stream_list_remove (srtp_stream_list_t list ,
4905
- srtp_stream_t stream_to_remove )
5050
+ static void srtp_stream_list_remove_at (srtp_stream_list_t list , uint32_t pos )
4906
5051
{
4907
- ( void ) list ;
5052
+ uint32_t tail_size , last_pos ;
4908
5053
4909
- stream_to_remove -> prev -> next = stream_to_remove -> next ;
4910
- if (stream_to_remove -> next != NULL ) {
4911
- stream_to_remove -> next -> prev = stream_to_remove -> prev ;
5054
+ last_pos = -- list -> size ;
5055
+ tail_size = last_pos - pos ;
5056
+ if (tail_size > 0u ) {
5057
+ memmove (list -> streams + pos , list -> streams + pos + 1 ,
5058
+ (size_t )tail_size * sizeof (* list -> streams ));
5059
+ memmove (list -> ssrcs + pos , list -> ssrcs + pos + 1 ,
5060
+ (size_t )tail_size * sizeof (* list -> ssrcs ));
4912
5061
}
5062
+
5063
+ list -> streams [last_pos ] = NULL ;
5064
+ list -> ssrcs [last_pos ] = 0u ;
5065
+ }
5066
+
5067
+ void srtp_stream_list_remove (srtp_stream_list_t list ,
5068
+ srtp_stream_t stream_to_remove )
5069
+ {
5070
+ uint32_t pos = srtp_stream_list_find (list , stream_to_remove -> ssrc );
5071
+ if (pos < list -> size )
5072
+ srtp_stream_list_remove_at (list , pos );
4913
5073
}
4914
5074
4915
5075
void srtp_stream_list_for_each (srtp_stream_list_t list ,
4916
5076
int (* callback )(srtp_stream_t , void * ),
4917
5077
void * data )
4918
5078
{
4919
- srtp_stream_t stream = list -> data .next ;
4920
- while (stream != NULL ) {
4921
- srtp_stream_t tmp = stream ;
4922
- stream = stream -> next ;
4923
- if (callback (tmp , data ))
5079
+ uint32_t size = list -> size ;
5080
+ for (uint32_t i = 0u ; i < size ;) {
5081
+ if (callback (list -> streams [i ], data ))
4924
5082
break ;
5083
+
5084
+ /* check if the callback removed the current element */
5085
+ if (size == list -> size )
5086
+ ++ i ;
5087
+ else
5088
+ size = list -> size ;
4925
5089
}
4926
5090
}
4927
5091
0 commit comments