
Commit 398f299

Merge pull request #507 from Lastique/optimize_datatypes

Add x86 SIMD optimizations to crypto datatypes

2 parents: 5cee9ea + faa7998

File tree: 3 files changed, +258 −8 lines


config_in_cmake.h (+11)

@@ -122,3 +122,14 @@
 #define inline
 #endif
 #endif
+
+/* Define gcc/clang-style SSE macros on compilers that don't define them (primarily, MSVC). */
+#if !defined(__SSE2__) && (defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2))
+#define __SSE2__
+#endif
+#if !defined(__SSSE3__) && defined(__AVX__)
+#define __SSSE3__
+#endif
+#if !defined(__SSE4_1__) && defined(__AVX__)
+#define __SSE4_1__
+#endif

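Background for this shim: MSVC does not predefine the GCC/clang-style __SSE2__/__SSSE3__/__SSE4_1__ macros; it exposes _M_X64, _M_IX86_FP, and (under /arch:AVX) __AVX__ instead, so the block above derives the former from the latter. A minimal, hypothetical sketch of how code can then gate a SIMD path on the single macro family (the file and function names here are illustrative, not part of the commit):

/* sketch only: assumes the generated config header containing the shim
 * above has already been included at this point. */
#if defined(__SSE2__)
#include <emmintrin.h> /* SSE2 intrinsics */

static void zero16(void *p)
{
    /* one unaligned 16-byte store */
    _mm_storeu_si128((__m128i *)p, _mm_setzero_si128());
}
#else
#include <string.h>

static void zero16(void *p)
{
    memset(p, 0, 16); /* portable fallback */
}
#endif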

crypto/include/datatypes.h (+31 −5)

@@ -62,6 +62,10 @@
 #error "Platform not recognized"
 #endif
 
+#if defined(__SSE2__)
+#include <emmintrin.h>
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif

@@ -90,6 +94,26 @@ void v128_left_shift(v128_t *x, int shift_index);
  *
  */
 
+#if defined(__SSE2__)
+
+#define v128_set_to_zero(x) \
+    (_mm_storeu_si128((__m128i *)(x), _mm_setzero_si128()))
+
+#define v128_copy(x, y) \
+    (_mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(y))))
+
+#define v128_xor(z, x, y) \
+    (_mm_storeu_si128((__m128i *)(z), \
+        _mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
+                      _mm_loadu_si128((const __m128i *)(y)))))
+
+#define v128_xor_eq(z, x) \
+    (_mm_storeu_si128((__m128i *)(z), \
+        _mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
+                      _mm_loadu_si128((const __m128i *)(z)))))
+
+#else /* defined(__SSE2__) */
+
 #define v128_set_to_zero(x) \
     ((x)->v32[0] = 0, (x)->v32[1] = 0, (x)->v32[2] = 0, (x)->v32[3] = 0)

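The SSE2 variants keep the same call-level semantics as the scalar macros below them, but use unaligned 128-bit loads and stores, so they are valid for however a v128_t pointer happens to be aligned. A minimal usage sketch, assuming only what the header above declares (the function name is illustrative):

#include "datatypes.h"

/* combine two 16-byte blocks and wipe the result: each call is one
 * 128-bit operation when __SSE2__ is defined */
static void demo(v128_t *out, const v128_t *k1, const v128_t *k2)
{
    v128_xor(out, k1, k2); /* out = k1 ^ k2 */
    v128_xor_eq(out, k1);  /* out ^= k1 */
    v128_set_to_zero(out); /* clear the temporary */
}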

@@ -113,6 +137,8 @@ void v128_left_shift(v128_t *x, int shift_index);
     ((z)->v64[0] ^= (x)->v64[0], (z)->v64[1] ^= (x)->v64[1])
 #endif
 
+#endif /* defined(__SSE2__) */
+
 /* NOTE! This assumes an odd ordering! */
 /* This will not be compatible directly with math on some processors */
 /* bit 0 is first 32-bit word, low order bit. in little-endian, that's

@@ -173,13 +199,11 @@ void octet_string_set_to_zero(void *s, size_t len);
 #define be64_to_cpu(x) OSSwapInt64(x)
 #else /* WORDS_BIGENDIAN */
 
-#if defined(__GNUC__) && (defined(HAVE_X86) || defined(__x86_64__))
+#if defined(__GNUC__)
 /* Fall back. */
 static inline uint32_t be32_to_cpu(uint32_t v)
 {
-    /* optimized for x86. */
-    asm("bswap %0" : "=r"(v) : "0"(v));
-    return v;
+    return __builtin_bswap32(v);
 }
 #else /* HAVE_X86 */
 #ifdef HAVE_NETINET_IN_H

@@ -192,7 +216,9 @@ static inline uint32_t be32_to_cpu(uint32_t v)
 
 static inline uint64_t be64_to_cpu(uint64_t v)
 {
-#ifdef NO_64BIT_MATH
+#if defined(__GNUC__)
+    v = __builtin_bswap64(v);
+#elif defined(NO_64BIT_MATH)
     /* use the make64 functions to do 64-bit math */
     v = make64(htonl(low32(v)), htonl(high32(v)));
 #else /* NO_64BIT_MATH */

crypto/math/datatypes.c (+216 −3)

@@ -53,6 +53,16 @@
 
 #include "datatypes.h"
 
+#if defined(__SSE2__)
+#include <tmmintrin.h>
+#endif
+
+#if defined(_MSC_VER)
+#define ALIGNMENT(N) __declspec(align(N))
+#else
+#define ALIGNMENT(N) __attribute__((aligned(N)))
+#endif
+
 /*
  * bit_string is a buffer that is used to hold output strings, e.g.
  * for printing.

@@ -123,6 +133,9 @@ char *v128_bit_string(v128_t *x)
 
 void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
 {
+#if defined(__SSE2__)
+    _mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(s)));
+#else
 #ifdef ALIGNMENT_32BIT_REQUIRED
     if ((((uint32_t)&s[0]) & 0x3) != 0)
 #endif

@@ -151,8 +164,67 @@ void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
         v128_copy(x, v);
     }
 #endif
+#endif /* defined(__SSE2__) */
 }
 
+#if defined(__SSSE3__)
+
+/* clang-format off */
+
+ALIGNMENT(16)
+static const uint8_t right_shift_masks[5][16] = {
+    { 0u,   1u,   2u,   3u,   4u,   5u,   6u,   7u,
+      8u,   9u,   10u,  11u,  12u,  13u,  14u,  15u },
+    { 0x80, 0x80, 0x80, 0x80, 0u,   1u,   2u,   3u,
+      4u,   5u,   6u,   7u,   8u,   9u,   10u,  11u },
+    { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
+      0u,   1u,   2u,   3u,   4u,   5u,   6u,   7u },
+    { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
+      0x80, 0x80, 0x80, 0x80, 0u,   1u,   2u,   3u },
+    /* needed for bitvector_left_shift */
+    { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
+      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
+};
+
+ALIGNMENT(16)
+static const uint8_t left_shift_masks[4][16] = {
+    { 0u,   1u,   2u,   3u,   4u,   5u,   6u,   7u,
+      8u,   9u,   10u,  11u,  12u,  13u,  14u,  15u },
+    { 4u,   5u,   6u,   7u,   8u,   9u,   10u,  11u,
+      12u,  13u,  14u,  15u,  0x80, 0x80, 0x80, 0x80 },
+    { 8u,   9u,   10u,  11u,  12u,  13u,  14u,  15u,
+      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 },
+    { 12u,  13u,  14u,  15u,  0x80, 0x80, 0x80, 0x80,
+      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
+};
+
+/* clang-format on */
+
+void v128_left_shift(v128_t *x, int shift)
+{
+    if (shift > 127) {
+        v128_set_to_zero(x);
+        return;
+    }
+
+    const int base_index = shift >> 5;
+    const int bit_index = shift & 31;
+
+    __m128i mm = _mm_loadu_si128((const __m128i *)x);
+    __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
+    __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);
+    mm = _mm_shuffle_epi8(mm, ((const __m128i *)left_shift_masks)[base_index]);
+
+    __m128i mm1 = _mm_srl_epi32(mm, mm_shift_right);
+    __m128i mm2 = _mm_sll_epi32(mm, mm_shift_left);
+    mm2 = _mm_srli_si128(mm2, 4);
+    mm1 = _mm_or_si128(mm1, mm2);
+
+    _mm_storeu_si128((__m128i *)x, mm1);
+}
+
+#else /* defined(__SSSE3__) */
+
 void v128_left_shift(v128_t *x, int shift)
 {
     int i;

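How the masks work: _mm_shuffle_epi8 (SSSE3 pshufb) builds each output byte from the source byte selected by the low 4 bits of the corresponding mask byte, and writes zero whenever the mask byte has its high bit set (the 0x80 entries). left_shift_masks[base_index] therefore moves the vector by base_index whole 32-bit words toward the low end, zero-filling the vacated bytes, and the residual bit_index bits are handled by the 32-bit shift/OR steps that follow. A scalar model of the shuffle, for reference (illustrative, not part of the commit):

#include <stdint.h>

/* byte-for-byte model of _mm_shuffle_epi8(src, mask) */
static void pshufb_model(uint8_t dst[16], const uint8_t src[16],
                         const uint8_t mask[16])
{
    for (int i = 0; i < 16; ++i)
        dst[i] = (mask[i] & 0x80) ? 0 : src[mask[i] & 0x0F];
}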
@@ -179,6 +251,8 @@ void v128_left_shift(v128_t *x, int shift)
         x->v32[i] = 0;
 }
 
+#endif /* defined(__SSSE3__) */
+
 /* functions manipulating bitvector_t */
 
 int bitvector_alloc(bitvector_t *v, unsigned long length)

@@ -190,6 +264,7 @@ int bitvector_alloc(bitvector_t *v, unsigned long length)
         (length + bits_per_word - 1) & ~(unsigned long)((bits_per_word - 1));
 
     l = length / bits_per_word * bytes_per_word;
+    l = (l + 15ul) & ~15ul;
 
     /* allocate memory, then set parameters */
     if (l == 0) {

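The added rounding makes every bitvector allocation a whole number of 16-byte blocks, so the SSSE3 bitvector_left_shift below (whose vec_length is computed in 128-bit units) can use full __m128i loads and stores, including for the final partial block, without reading or writing past the buffer. A worked example, assuming the existing constants are bits_per_word == 32 and bytes_per_word == 4:

/* requested length  = 40 bits
 * rounded to words  -> 64 bits
 * l = 64 / 32 * 4   -> 8 bytes
 * l = (8 + 15) & ~15 -> 16 bytes, i.e. one full __m128i block
 */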
@@ -225,6 +300,73 @@ void bitvector_set_to_zero(bitvector_t *x)
     memset(x->word, 0, x->length >> 3);
 }
 
+#if defined(__SSSE3__)
+
+void bitvector_left_shift(bitvector_t *x, int shift)
+{
+    if ((uint32_t)shift >= x->length) {
+        bitvector_set_to_zero(x);
+        return;
+    }
+
+    const int base_index = shift >> 5;
+    const int bit_index = shift & 31;
+    const int vec_length = (x->length + 127u) >> 7;
+    const __m128i *from = ((const __m128i *)x->word) + (base_index >> 2);
+    __m128i *to = (__m128i *)x->word;
+    __m128i *const end = to + vec_length;
+
+    __m128i mm_right_shift_mask =
+        ((const __m128i *)right_shift_masks)[4u - (base_index & 3u)];
+    __m128i mm_left_shift_mask =
+        ((const __m128i *)left_shift_masks)[base_index & 3u];
+    __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
+    __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);
+
+    __m128i mm_current = _mm_loadu_si128(from);
+    __m128i mm_current_r = _mm_srl_epi32(mm_current, mm_shift_right);
+    __m128i mm_current_l = _mm_sll_epi32(mm_current, mm_shift_left);
+
+    while ((end - from) >= 2) {
+        ++from;
+        __m128i mm_next = _mm_loadu_si128(from);
+
+        __m128i mm_next_r = _mm_srl_epi32(mm_next, mm_shift_right);
+        __m128i mm_next_l = _mm_sll_epi32(mm_next, mm_shift_left);
+        mm_current_l = _mm_alignr_epi8(mm_next_l, mm_current_l, 4);
+        mm_current = _mm_or_si128(mm_current_r, mm_current_l);
+
+        mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);
+
+        __m128i mm_temp_next = _mm_srli_si128(mm_next_l, 4);
+        mm_temp_next = _mm_or_si128(mm_next_r, mm_temp_next);
+
+        mm_temp_next = _mm_shuffle_epi8(mm_temp_next, mm_right_shift_mask);
+        mm_current = _mm_or_si128(mm_temp_next, mm_current);
+
+        _mm_storeu_si128(to, mm_current);
+        ++to;
+
+        mm_current_r = mm_next_r;
+        mm_current_l = mm_next_l;
+    }
+
+    mm_current_l = _mm_srli_si128(mm_current_l, 4);
+    mm_current = _mm_or_si128(mm_current_r, mm_current_l);
+
+    mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);
+
+    _mm_storeu_si128(to, mm_current);
+    ++to;
+
+    while (to < end) {
+        _mm_storeu_si128(to, _mm_setzero_si128());
+        ++to;
+    }
+}
+
+#else /* defined(__SSSE3__) */
+
 void bitvector_left_shift(bitvector_t *x, int shift)
 {
     int i;

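The loop above carries bits across 128-bit block boundaries with _mm_alignr_epi8(hi, lo, 4), which concatenates the two operands (hi in the upper half) and extracts 16 bytes starting 4 bytes into lo, i.e. the top 12 bytes of lo followed by the bottom 4 bytes of hi. A byte-level model of that instruction, for reference (illustrative only):

#include <stdint.h>
#include <string.h>

/* model of _mm_alignr_epi8(hi, lo, n) for 0 <= n <= 16 */
static void alignr_model(uint8_t dst[16], const uint8_t hi[16],
                         const uint8_t lo[16], int n)
{
    uint8_t cat[32];
    memcpy(cat, lo, 16);      /* low half of the 32-byte concatenation */
    memcpy(cat + 16, hi, 16); /* high half */
    memcpy(dst, cat + n, 16); /* 16 bytes starting n bytes up */
}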
@@ -253,16 +395,82 @@ void bitvector_left_shift(bitvector_t *x, int shift)
         x->word[i] = 0;
 }
 
+#endif /* defined(__SSSE3__) */
+
 int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
 {
-    uint8_t *end = b + len;
-    uint8_t accumulator = 0;
-
     /*
      * We use this somewhat obscure implementation to try to ensure the running
      * time only depends on len, even accounting for compiler optimizations.
      * The accumulator ends up zero iff the strings are equal.
      */
+    uint8_t *end = b + len;
+    uint32_t accumulator = 0;
+
+#if defined(__SSE2__)
+    __m128i mm_accumulator1 = _mm_setzero_si128();
+    __m128i mm_accumulator2 = _mm_setzero_si128();
+    for (int i = 0, n = len >> 5; i < n; ++i, a += 32, b += 32) {
+        __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
+        __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
+        __m128i mm_a2 = _mm_loadu_si128((const __m128i *)(a + 16));
+        __m128i mm_b2 = _mm_loadu_si128((const __m128i *)(b + 16));
+        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
+        mm_a2 = _mm_xor_si128(mm_a2, mm_b2);
+        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
+        mm_accumulator2 = _mm_or_si128(mm_accumulator2, mm_a2);
+    }
+
+    mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_accumulator2);
+
+    if ((end - b) >= 16) {
+        __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
+        __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
+        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
+        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
+        a += 16;
+        b += 16;
+    }
+
+    if ((end - b) >= 8) {
+        __m128i mm_a1 = _mm_loadl_epi64((const __m128i *)a);
+        __m128i mm_b1 = _mm_loadl_epi64((const __m128i *)b);
+        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
+        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
+        a += 8;
+        b += 8;
+    }
+
+    mm_accumulator1 = _mm_or_si128(
+        mm_accumulator1, _mm_unpackhi_epi64(mm_accumulator1, mm_accumulator1));
+    mm_accumulator1 =
+        _mm_or_si128(mm_accumulator1, _mm_srli_si128(mm_accumulator1, 4));
+    accumulator = _mm_cvtsi128_si32(mm_accumulator1);
+#else
+    uint32_t accumulator2 = 0;
+    for (int i = 0, n = len >> 3; i < n; ++i, a += 8, b += 8) {
+        uint32_t a_val1, b_val1;
+        uint32_t a_val2, b_val2;
+        memcpy(&a_val1, a, sizeof(a_val1));
+        memcpy(&b_val1, b, sizeof(b_val1));
+        memcpy(&a_val2, a + 4, sizeof(a_val2));
+        memcpy(&b_val2, b + 4, sizeof(b_val2));
+        accumulator |= a_val1 ^ b_val1;
+        accumulator2 |= a_val2 ^ b_val2;
+    }
+
+    accumulator |= accumulator2;
+
+    if ((end - b) >= 4) {
+        uint32_t a_val, b_val;
+        memcpy(&a_val, a, sizeof(a_val));
+        memcpy(&b_val, b, sizeof(b_val));
+        accumulator |= a_val ^ b_val;
+        a += 4;
+        b += 4;
+    }
+#endif
+
     while (b < end)
         accumulator |= (*a++ ^ *b++);

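The SSE2 path ORs the XOR of the two buffers into 128-bit accumulators (two of them in the main loop, so the OR chains can overlap), then folds the result horizontally: _mm_unpackhi_epi64 ORs the high 64-bit half into the low half, the 4-byte _mm_srli_si128 ORs the remaining 32-bit lane down, and _mm_cvtsi128_si32 extracts a value that is nonzero iff any byte compared so far differed; the scalar tail loop then handles the leftover bytes. A scalar model of the fold, for reference (illustrative only):

#include <stdint.h>

/* fold four 32-bit lanes into one value that is nonzero iff any lane is nonzero */
static uint32_t fold_lanes(const uint32_t lanes[4])
{
    uint32_t lo = lanes[0] | lanes[2]; /* OR the high 64-bit half into the low half */
    uint32_t hi = lanes[1] | lanes[3];
    return lo | hi;                    /* final 32-bit accumulator */
}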

@@ -272,9 +480,14 @@ int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
 
 void srtp_cleanse(void *s, size_t len)
 {
+#if defined(__GNUC__)
+    memset(s, 0, len);
+    __asm__ __volatile__("" : : "r"(s) : "memory");
+#else
     volatile unsigned char *p = (volatile unsigned char *)s;
     while (len--)
         *p++ = 0;
+#endif
 }
 
 void octet_string_set_to_zero(void *s, size_t len)
