rsa.c 49 KB
Newer Older
1 2 3
/*
 *  The RSA public-key cryptosystem
 *
4
 *  Copyright (C) 2006-2014, ARM Limited, All Rights Reserved
5
 *
6
 *  This file is part of mbed TLS (https://tls.mbed.org)
7
 *
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with this program; if not, write to the Free Software Foundation, Inc.,
 *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
/*
 *  RSA was designed by Ron Rivest, Adi Shamir and Len Adleman.
 *
 *  http://theory.lcs.mit.edu/~rivest/rsapaper.pdf
 *  http://www.cacr.math.uwaterloo.ca/hac/about/chap8.pdf
 */

29
#if !defined(MBEDTLS_CONFIG_FILE)
30
#include "mbedtls/config.h"
31
#else
32
#include MBEDTLS_CONFIG_FILE
33
#endif
34

35
#if defined(MBEDTLS_RSA_C)
36

37 38
#include "mbedtls/rsa.h"
#include "mbedtls/oid.h"
39

40 41
#include <string.h>

42
#if defined(MBEDTLS_PKCS1_V21)
43
#include "mbedtls/md.h"
44
#endif
45

46
#if defined(MBEDTLS_PKCS1_V15) && !defined(__OpenBSD__)
47
#include <stdlib.h>
48
#endif
49

50
#if defined(MBEDTLS_PLATFORM_C)
51
#include "mbedtls/platform.h"
52
#else
53
#include <stdio.h>
54
#define mbedtls_printf printf
55 56
#endif

57 58 59
/*
 * Initialize an RSA context
 */
60
void mbedtls_rsa_init( mbedtls_rsa_context *ctx,
61
               int padding,
62
               int hash_id )
63
{
64
    memset( ctx, 0, sizeof( mbedtls_rsa_context ) );
65

66
    mbedtls_rsa_set_padding( ctx, padding, hash_id );
67

68 69
#if defined(MBEDTLS_THREADING_C)
    mbedtls_mutex_init( &ctx->mutex );
70
#endif
71 72
}

73 74 75
/*
 * Set padding for an existing RSA context
 */
76
void mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding, int hash_id )
77 78 79 80 81
{
    ctx->padding = padding;
    ctx->hash_id = hash_id;
}

82
#if defined(MBEDTLS_GENPRIME)
83 84 85 86

/*
 * Generate an RSA keypair
 */
87
int mbedtls_rsa_gen_key( mbedtls_rsa_context *ctx,
88 89 90
                 int (*f_rng)(void *, unsigned char *, size_t),
                 void *p_rng,
                 unsigned int nbits, int exponent )
91 92
{
    int ret;
93
    mbedtls_mpi P1, Q1, H, G;
94

95
    if( f_rng == NULL || nbits < 128 || exponent < 3 )
96
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
97

98
    mbedtls_mpi_init( &P1 ); mbedtls_mpi_init( &Q1 ); mbedtls_mpi_init( &H ); mbedtls_mpi_init( &G );
99 100 101 102 103

    /*
     * find primes P and Q with Q < P so that:
     * GCD( E, (P-1)*(Q-1) ) == 1
     */
104
    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &ctx->E, exponent ) );
105 106 107

    do
    {
108
        MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->P, ( nbits + 1 ) >> 1, 0,
109
                                f_rng, p_rng ) );
110

111
        MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->Q, ( nbits + 1 ) >> 1, 0,
112
                                f_rng, p_rng ) );
113

114 115
        if( mbedtls_mpi_cmp_mpi( &ctx->P, &ctx->Q ) < 0 )
            mbedtls_mpi_swap( &ctx->P, &ctx->Q );
116

117
        if( mbedtls_mpi_cmp_mpi( &ctx->P, &ctx->Q ) == 0 )
118 119
            continue;

120 121
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->N, &ctx->P, &ctx->Q ) );
        if( mbedtls_mpi_msb( &ctx->N ) != nbits )
122 123
            continue;

124 125 126 127
        MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &P1, &ctx->P, 1 ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &Q1, &ctx->Q, 1 ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &H, &P1, &Q1 ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->E, &H  ) );
128
    }
129
    while( mbedtls_mpi_cmp_int( &G, 1 ) != 0 );
130 131 132 133 134 135 136

    /*
     * D  = E^-1 mod ((P-1)*(Q-1))
     * DP = D mod (P - 1)
     * DQ = D mod (Q - 1)
     * QP = Q^-1 mod P
     */
137 138 139 140
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->D , &ctx->E, &H  ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->DP, &ctx->D, &P1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->DQ, &ctx->D, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->QP, &ctx->Q, &ctx->P ) );
141

142
    ctx->len = ( mbedtls_mpi_msb( &ctx->N ) + 7 ) >> 3;
143 144 145

cleanup:

146
    mbedtls_mpi_free( &P1 ); mbedtls_mpi_free( &Q1 ); mbedtls_mpi_free( &H ); mbedtls_mpi_free( &G );
147 148 149

    if( ret != 0 )
    {
150 151
        mbedtls_rsa_free( ctx );
        return( MBEDTLS_ERR_RSA_KEY_GEN_FAILED + ret );
152 153
    }

154
    return( 0 );
155 156
}

157
#endif /* MBEDTLS_GENPRIME */
158 159 160 161

/*
 * Check a public RSA key
 */
162
int mbedtls_rsa_check_pubkey( const mbedtls_rsa_context *ctx )
163
{
164
    if( !ctx->N.p || !ctx->E.p )
165
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
166

167
    if( ( ctx->N.p[0] & 1 ) == 0 ||
168
        ( ctx->E.p[0] & 1 ) == 0 )
169
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
170

171 172 173
    if( mbedtls_mpi_msb( &ctx->N ) < 128 ||
        mbedtls_mpi_msb( &ctx->N ) > MBEDTLS_MPI_MAX_BITS )
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
174

175 176 177
    if( mbedtls_mpi_msb( &ctx->E ) < 2 ||
        mbedtls_mpi_cmp_mpi( &ctx->E, &ctx->N ) >= 0 )
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
178 179 180 181 182 183 184

    return( 0 );
}

/*
 * Check a private RSA key
 */
185
int mbedtls_rsa_check_privkey( const mbedtls_rsa_context *ctx )
186 187
{
    int ret;
188
    mbedtls_mpi PQ, DE, P1, Q1, H, I, G, G2, L1, L2, DP, DQ, QP;
189

190
    if( ( ret = mbedtls_rsa_check_pubkey( ctx ) ) != 0 )
191 192
        return( ret );

193
    if( !ctx->P.p || !ctx->Q.p || !ctx->D.p )
194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );

    mbedtls_mpi_init( &PQ ); mbedtls_mpi_init( &DE ); mbedtls_mpi_init( &P1 ); mbedtls_mpi_init( &Q1 );
    mbedtls_mpi_init( &H  ); mbedtls_mpi_init( &I  ); mbedtls_mpi_init( &G  ); mbedtls_mpi_init( &G2 );
    mbedtls_mpi_init( &L1 ); mbedtls_mpi_init( &L2 ); mbedtls_mpi_init( &DP ); mbedtls_mpi_init( &DQ );
    mbedtls_mpi_init( &QP );

    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &PQ, &ctx->P, &ctx->Q ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &DE, &ctx->D, &ctx->E ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &P1, &ctx->P, 1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &Q1, &ctx->Q, 1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &H, &P1, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->E, &H  ) );

    MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G2, &P1, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_div_mpi( &L1, &L2, &H, &G2 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &I, &DE, &L1  ) );

    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &DP, &ctx->D, &P1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &DQ, &ctx->D, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &QP, &ctx->Q, &ctx->P ) );
215 216 217
    /*
     * Check for a valid PKCS1v2 private key
     */
218 219 220 221 222 223 224
    if( mbedtls_mpi_cmp_mpi( &PQ, &ctx->N ) != 0 ||
        mbedtls_mpi_cmp_mpi( &DP, &ctx->DP ) != 0 ||
        mbedtls_mpi_cmp_mpi( &DQ, &ctx->DQ ) != 0 ||
        mbedtls_mpi_cmp_mpi( &QP, &ctx->QP ) != 0 ||
        mbedtls_mpi_cmp_int( &L2, 0 ) != 0 ||
        mbedtls_mpi_cmp_int( &I, 1 ) != 0 ||
        mbedtls_mpi_cmp_int( &G, 1 ) != 0 )
225
    {
226
        ret = MBEDTLS_ERR_RSA_KEY_CHECK_FAILED;
227
    }
228

229
cleanup:
230 231 232 233
    mbedtls_mpi_free( &PQ ); mbedtls_mpi_free( &DE ); mbedtls_mpi_free( &P1 ); mbedtls_mpi_free( &Q1 );
    mbedtls_mpi_free( &H  ); mbedtls_mpi_free( &I  ); mbedtls_mpi_free( &G  ); mbedtls_mpi_free( &G2 );
    mbedtls_mpi_free( &L1 ); mbedtls_mpi_free( &L2 ); mbedtls_mpi_free( &DP ); mbedtls_mpi_free( &DQ );
    mbedtls_mpi_free( &QP );
234

235
    if( ret == MBEDTLS_ERR_RSA_KEY_CHECK_FAILED )
236 237
        return( ret );

238
    if( ret != 0 )
239
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED + ret );
240 241

    return( 0 );
242 243
}

244 245 246
/*
 * Check if contexts holding a public and private key match
 */
247
int mbedtls_rsa_check_pub_priv( const mbedtls_rsa_context *pub, const mbedtls_rsa_context *prv )
248
{
249 250
    if( mbedtls_rsa_check_pubkey( pub ) != 0 ||
        mbedtls_rsa_check_privkey( prv ) != 0 )
251
    {
252
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
253 254
    }

255 256
    if( mbedtls_mpi_cmp_mpi( &pub->N, &prv->N ) != 0 ||
        mbedtls_mpi_cmp_mpi( &pub->E, &prv->E ) != 0 )
257
    {
258
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
259 260 261 262 263
    }

    return( 0 );
}

264 265 266
/*
 * Do an RSA public key operation
 */
267
int mbedtls_rsa_public( mbedtls_rsa_context *ctx,
268
                const unsigned char *input,
269 270
                unsigned char *output )
{
271 272
    int ret;
    size_t olen;
273
    mbedtls_mpi T;
274

275
    mbedtls_mpi_init( &T );
276

277
    MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( &T, input, ctx->len ) );
278

279
    if( mbedtls_mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
280
    {
281 282
        mbedtls_mpi_free( &T );
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
283 284
    }

285
#if defined(MBEDTLS_THREADING_C)
286 287
    if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 )
        return( ret );
288 289
#endif

290
    olen = ctx->len;
291 292
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T, &T, &ctx->E, &ctx->N, &ctx->RN ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &T, output, olen ) );
293 294

cleanup:
295
#if defined(MBEDTLS_THREADING_C)
296 297
    if( ( ret = mbedtls_mutex_unlock( &ctx->mutex ) ) != 0 )
        return( ret );
298
#endif
299

300
    mbedtls_mpi_free( &T );
301 302

    if( ret != 0 )
303
        return( MBEDTLS_ERR_RSA_PUBLIC_FAILED + ret );
304 305 306 307

    return( 0 );
}

308
/*
309 310
 * Generate or update blinding values, see section 10 of:
 *  KOCHER, Paul C. Timing attacks on implementations of Diffie-Hellman, RSA,
311
 *  DSS, and other systems. In : Advances in Cryptology-CRYPTO'96. Springer
312
 *  Berlin Heidelberg, 1996. p. 104-113.
313
 */
314
static int rsa_prepare_blinding( mbedtls_rsa_context *ctx, mbedtls_mpi *Vi, mbedtls_mpi *Vf,
315 316
                 int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
{
317
    int ret, count = 0;
318

319
#if defined(MBEDTLS_THREADING_C)
320 321
    if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 )
        return( ret );
322 323
#endif

324 325 326
    if( ctx->Vf.p != NULL )
    {
        /* We already have blinding values, just update them by squaring */
327 328 329 330
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->Vi, &ctx->Vi, &ctx->Vi ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->Vi, &ctx->Vi, &ctx->N ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->Vf, &ctx->Vf, &ctx->Vf ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->Vf, &ctx->Vf, &ctx->N ) );
331

332
        goto done;
333 334
    }

335 336 337
    /* Unblinding value: Vf = random number, invertible mod N */
    do {
        if( count++ > 10 )
338
            return( MBEDTLS_ERR_RSA_RNG_FAILED );
339

340 341 342
        MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( &ctx->Vf, ctx->len - 1, f_rng, p_rng ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &ctx->Vi, &ctx->Vf, &ctx->N ) );
    } while( mbedtls_mpi_cmp_int( &ctx->Vi, 1 ) != 0 );
343 344

    /* Blinding value: Vi =  Vf^(-e) mod N */
345 346
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->Vi, &ctx->Vf, &ctx->N ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &ctx->Vi, &ctx->Vi, &ctx->E, &ctx->N, &ctx->RN ) );
347

348 349 350
done:
    if( Vi != &ctx->Vi )
    {
351 352
        MBEDTLS_MPI_CHK( mbedtls_mpi_copy( Vi, &ctx->Vi ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_copy( Vf, &ctx->Vf ) );
353 354
    }

355
cleanup:
356
#if defined(MBEDTLS_THREADING_C)
357 358
    if( ( ret = mbedtls_mutex_unlock( &ctx->mutex ) ) != 0 )
        return( ret );
359 360
#endif

361 362 363
    return( ret );
}

364 365 366
/*
 * Do an RSA private key operation
 */
367
int mbedtls_rsa_private( mbedtls_rsa_context *ctx,
368 369
                 int (*f_rng)(void *, unsigned char *, size_t),
                 void *p_rng,
370
                 const unsigned char *input,
371 372
                 unsigned char *output )
{
373 374
    int ret;
    size_t olen;
375 376
    mbedtls_mpi T, T1, T2;
    mbedtls_mpi *Vi, *Vf;
377 378 379 380 381 382

    /*
     * When using the Chinese Remainder Theorem, we use blinding values.
     * Without threading, we just read them directly from the context,
     * otherwise we make a local copy in order to reduce locking contention.
     */
383 384
#if defined(MBEDTLS_THREADING_C)
    mbedtls_mpi Vi_copy, Vf_copy;
385

386
    mbedtls_mpi_init( &Vi_copy ); mbedtls_mpi_init( &Vf_copy );
387 388 389 390 391 392
    Vi = &Vi_copy;
    Vf = &Vf_copy;
#else
    Vi = &ctx->Vi;
    Vf = &ctx->Vf;
#endif
393

394
    mbedtls_mpi_init( &T ); mbedtls_mpi_init( &T1 ); mbedtls_mpi_init( &T2 );
395

396 397
    MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( &T, input, ctx->len ) );
    if( mbedtls_mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
398
    {
399 400
        mbedtls_mpi_free( &T );
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
401 402
    }

403 404 405
    if( f_rng != NULL )
    {
        /*
406 407
         * Blinding
         * T = T * Vi mod N
408
         */
409 410 411
        MBEDTLS_MPI_CHK( rsa_prepare_blinding( ctx, Vi, Vf, f_rng, p_rng ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T, &T, Vi ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &T, &T, &ctx->N ) );
412
    }
413

414
#if defined(MBEDTLS_THREADING_C)
415 416
    if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 )
        return( ret );
417 418
#endif

419 420
#if defined(MBEDTLS_RSA_NO_CRT)
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T, &T, &ctx->D, &ctx->N, &ctx->RN ) );
421
#else
422 423 424 425 426 427
    /*
     * faster decryption using the CRT
     *
     * T1 = input ^ dP mod P
     * T2 = input ^ dQ mod Q
     */
428 429
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T1, &T, &ctx->DP, &ctx->P, &ctx->RP ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T2, &T, &ctx->DQ, &ctx->Q, &ctx->RQ ) );
430 431 432 433

    /*
     * T = (T1 - T2) * (Q^-1 mod P) mod P
     */
434 435 436
    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &T, &T1, &T2 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T1, &T, &ctx->QP ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &T, &T1, &ctx->P ) );
437 438

    /*
439
     * T = T2 + T * Q
440
     */
441 442 443
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T1, &T, &ctx->Q ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_add_mpi( &T, &T2, &T1 ) );
#endif /* MBEDTLS_RSA_NO_CRT */
444

445 446 447 448
    if( f_rng != NULL )
    {
        /*
         * Unblind
449
         * T = T * Vf mod N
450
         */
451 452
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T, &T, Vf ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &T, &T, &ctx->N ) );
453
    }
454 455

    olen = ctx->len;
456
    MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &T, output, olen ) );
457 458

cleanup:
459
#if defined(MBEDTLS_THREADING_C)
460 461
    if( ( ret = mbedtls_mutex_unlock( &ctx->mutex ) ) != 0 )
        return( ret );
462
    mbedtls_mpi_free( &Vi_copy ); mbedtls_mpi_free( &Vf_copy );
463
#endif
464
    mbedtls_mpi_free( &T ); mbedtls_mpi_free( &T1 ); mbedtls_mpi_free( &T2 );
465 466

    if( ret != 0 )
467
        return( MBEDTLS_ERR_RSA_PRIVATE_FAILED + ret );
468 469 470 471

    return( 0 );
}

472
#if defined(MBEDTLS_PKCS1_V21)
473 474 475
/**
 * Generate and apply the MGF1 operation (from PKCS#1 v2.1) to a buffer.
 *
476 477 478 479 480
 * \param dst       buffer to mask
 * \param dlen      length of destination buffer
 * \param src       source of the mask generation
 * \param slen      length of the source buffer
 * \param md_ctx    message digest context to use
481
 */
482
static void mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
483
                      size_t slen, mbedtls_md_context_t *md_ctx )
484
{
485
    unsigned char mask[MBEDTLS_MD_MAX_SIZE];
486 487
    unsigned char counter[4];
    unsigned char *p;
488 489
    unsigned int hlen;
    size_t i, use_len;
490

491
    memset( mask, 0, MBEDTLS_MD_MAX_SIZE );
492 493
    memset( counter, 0, 4 );

494
    hlen = mbedtls_md_get_size( md_ctx->md_info );
495 496 497 498 499 500 501 502 503 504 505

    // Generate and apply dbMask
    //
    p = dst;

    while( dlen > 0 )
    {
        use_len = hlen;
        if( dlen < hlen )
            use_len = dlen;

506 507 508 509
        mbedtls_md_starts( md_ctx );
        mbedtls_md_update( md_ctx, src, slen );
        mbedtls_md_update( md_ctx, counter, 4 );
        mbedtls_md_finish( md_ctx, mask );
510 511 512 513 514 515 516 517 518

        for( i = 0; i < use_len; ++i )
            *p++ ^= mask[i];

        counter[3]++;

        dlen -= use_len;
    }
}
519
#endif /* MBEDTLS_PKCS1_V21 */
520

521
#if defined(MBEDTLS_PKCS1_V21)
522
/*
523
 * Implementation of the PKCS#1 v2.1 RSAES-OAEP-ENCRYPT function
524
 */
525
int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx,
526 527
                            int (*f_rng)(void *, unsigned char *, size_t),
                            void *p_rng,
528 529 530
                            int mode,
                            const unsigned char *label, size_t label_len,
                            size_t ilen,
531 532
                            const unsigned char *input,
                            unsigned char *output )
533
{
534
    size_t olen;
535
    int ret;
536
    unsigned char *p = output;
537
    unsigned int hlen;
538 539
    const mbedtls_md_info_t *md_info;
    mbedtls_md_context_t md_ctx;
540

541 542
    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
543 544

    if( f_rng == NULL )
545
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
546

547
    md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
548
    if( md_info == NULL )
549
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
550

551
    olen = ctx->len;
552
    hlen = mbedtls_md_get_size( md_info );
553

554
    if( olen < ilen + 2 * hlen + 2 )
555
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
556

557
    memset( output, 0, olen );
558

559
    *p++ = 0;
560

561 562 563
    // Generate a random octet string seed
    //
    if( ( ret = f_rng( p_rng, p, hlen ) ) != 0 )
564
        return( MBEDTLS_ERR_RSA_RNG_FAILED + ret );
565

566
    p += hlen;
567

568 569
    // Construct DB
    //
570
    mbedtls_md( md_info, label, label_len, p );
571 572 573 574
    p += hlen;
    p += olen - 2 * hlen - 2 - ilen;
    *p++ = 1;
    memcpy( p, input, ilen );
575

576 577
    mbedtls_md_init( &md_ctx );
    mbedtls_md_setup( &md_ctx, md_info, 0 );
578

579 580 581 582
    // maskedDB: Apply dbMask to DB
    //
    mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
               &md_ctx );
583

584 585 586 587
    // maskedSeed: Apply seedMask to seed
    //
    mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
               &md_ctx );
588

589
    mbedtls_md_free( &md_ctx );
590

591 592 593
    return( ( mode == MBEDTLS_RSA_PUBLIC )
            ? mbedtls_rsa_public(  ctx, output, output )
            : mbedtls_rsa_private( ctx, f_rng, p_rng, output, output ) );
594
}
595
#endif /* MBEDTLS_PKCS1_V21 */
596

597
#if defined(MBEDTLS_PKCS1_V15)
598 599 600
/*
 * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-ENCRYPT function
 */
601
int mbedtls_rsa_rsaes_pkcs1_v15_encrypt( mbedtls_rsa_context *ctx,
602 603 604 605 606 607 608 609 610
                                 int (*f_rng)(void *, unsigned char *, size_t),
                                 void *p_rng,
                                 int mode, size_t ilen,
                                 const unsigned char *input,
                                 unsigned char *output )
{
    size_t nb_pad, olen;
    int ret;
    unsigned char *p = output;
611

612 613
    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
614 615

    if( f_rng == NULL )
616
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
617

618
    olen = ctx->len;
619

620
    if( olen < ilen + 11 )
621
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
622

623
    nb_pad = olen - 3 - ilen;
624

625
    *p++ = 0;
626
    if( mode == MBEDTLS_RSA_PUBLIC )
627
    {
628
        *p++ = MBEDTLS_RSA_CRYPT;
629

630 631 632
        while( nb_pad-- > 0 )
        {
            int rng_dl = 100;
633

634 635 636
            do {
                ret = f_rng( p_rng, p, 1 );
            } while( *p == 0 && --rng_dl && ret == 0 );
637

638
            // Check if RNG failed to generate data
639
            //
640
            if( rng_dl == 0 || ret != 0 )
641
                return( MBEDTLS_ERR_RSA_RNG_FAILED + ret );
642

643 644 645 646 647
            p++;
        }
    }
    else
    {
648
        *p++ = MBEDTLS_RSA_SIGN;
649

650 651
        while( nb_pad-- > 0 )
            *p++ = 0xFF;
652 653
    }

654 655 656
    *p++ = 0;
    memcpy( p, input, ilen );

657 658 659
    return( ( mode == MBEDTLS_RSA_PUBLIC )
            ? mbedtls_rsa_public(  ctx, output, output )
            : mbedtls_rsa_private( ctx, f_rng, p_rng, output, output ) );
660
}
661
#endif /* MBEDTLS_PKCS1_V15 */
662 663

/*
664
 * Add the message padding, then do an RSA operation
665
 */
666
int mbedtls_rsa_pkcs1_encrypt( mbedtls_rsa_context *ctx,
667 668 669
                       int (*f_rng)(void *, unsigned char *, size_t),
                       void *p_rng,
                       int mode, size_t ilen,
670
                       const unsigned char *input,
671 672 673 674
                       unsigned char *output )
{
    switch( ctx->padding )
    {
675 676 677
#if defined(MBEDTLS_PKCS1_V15)
        case MBEDTLS_RSA_PKCS_V15:
            return mbedtls_rsa_rsaes_pkcs1_v15_encrypt( ctx, f_rng, p_rng, mode, ilen,
678
                                                input, output );
679
#endif
680

681 682 683
#if defined(MBEDTLS_PKCS1_V21)
        case MBEDTLS_RSA_PKCS_V21:
            return mbedtls_rsa_rsaes_oaep_encrypt( ctx, f_rng, p_rng, mode, NULL, 0,
684 685 686 687
                                           ilen, input, output );
#endif

        default:
688
            return( MBEDTLS_ERR_RSA_INVALID_PADDING );
689 690 691
    }
}

692
#if defined(MBEDTLS_PKCS1_V21)
693 694 695
/*
 * Implementation of the PKCS#1 v2.1 RSAES-OAEP-DECRYPT function
 */
696
int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx,
697 698 699
                            int (*f_rng)(void *, unsigned char *, size_t),
                            void *p_rng,
                            int mode,
700 701
                            const unsigned char *label, size_t label_len,
                            size_t *olen,
702 703 704
                            const unsigned char *input,
                            unsigned char *output,
                            size_t output_max_len )
705
{
706
    int ret;
707 708
    size_t ilen, i, pad_len;
    unsigned char *p, bad, pad_done;
709 710
    unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
    unsigned char lhash[MBEDTLS_MD_MAX_SIZE];
711
    unsigned int hlen;
712 713
    const mbedtls_md_info_t *md_info;
    mbedtls_md_context_t md_ctx;
714

715 716 717
    /*
     * Parameters sanity checks
     */
718 719
    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
720 721 722

    ilen = ctx->len;

723
    if( ilen < 16 || ilen > sizeof( buf ) )
724
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
725

726
    md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
727
    if( md_info == NULL )
728
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
729 730 731 732

    /*
     * RSA operation
     */
733 734 735
    ret = ( mode == MBEDTLS_RSA_PUBLIC )
          ? mbedtls_rsa_public(  ctx, input, buf )
          : mbedtls_rsa_private( ctx, f_rng, p_rng, input, buf );
736 737 738 739

    if( ret != 0 )
        return( ret );

740
    /*
741
     * Unmask data and generate lHash
742
     */
743
    hlen = mbedtls_md_get_size( md_info );
744

745 746
    mbedtls_md_init( &md_ctx );
    mbedtls_md_setup( &md_ctx, md_info, 0 );
747

748
    /* Generate lHash */
749
    mbedtls_md( md_info, label, label_len, lhash );
750

751
    /* seed: Apply seedMask to maskedSeed */
752 753
    mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
               &md_ctx );
754

755
    /* DB: Apply dbMask to maskedDB */
756 757
    mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
               &md_ctx );
758

759
    mbedtls_md_free( &md_ctx );
760

761
    /*
762
     * Check contents, in "constant-time"
763 764
     */
    p = buf;
765
    bad = 0;
766

767
    bad |= *p++; /* First byte must be 0 */
768 769 770 771

    p += hlen; /* Skip seed */

    /* Check lHash */
772 773 774 775 776 777 778 779 780 781
    for( i = 0; i < hlen; i++ )
        bad |= lhash[i] ^ *p++;

    /* Get zero-padding len, but always read till end of buffer
     * (minus one, for the 01 byte) */
    pad_len = 0;
    pad_done = 0;
    for( i = 0; i < ilen - 2 * hlen - 2; i++ )
    {
        pad_done |= p[i];
782
        pad_len += ((pad_done | (unsigned char)-pad_done) >> 7) ^ 1;
783
    }
784

785 786
    p += pad_len;
    bad |= *p++ ^ 0x01;
787

788 789 790 791 792 793 794
    /*
     * The only information "leaked" is whether the padding was correct or not
     * (eg, no data is copied if it was not correct). This meets the
     * recommendations in PKCS#1 v2.2: an opponent cannot distinguish between
     * the different error conditions.
     */
    if( bad != 0 )
795
        return( MBEDTLS_ERR_RSA_INVALID_PADDING );
796

797
    if( ilen - ( p - buf ) > output_max_len )
798
        return( MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE );
799

800 801
    *olen = ilen - (p - buf);
    memcpy( output, p, *olen );
802

803 804
    return( 0 );
}
805
#endif /* MBEDTLS_PKCS1_V21 */
806

807
#if defined(MBEDTLS_PKCS1_V15)
808 809 810
/*
 * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-DECRYPT function
 */
811
int mbedtls_rsa_rsaes_pkcs1_v15_decrypt( mbedtls_rsa_context *ctx,
812 813
                                 int (*f_rng)(void *, unsigned char *, size_t),
                                 void *p_rng,
814 815 816 817 818
                                 int mode, size_t *olen,
                                 const unsigned char *input,
                                 unsigned char *output,
                                 size_t output_max_len)
{
819 820 821
    int ret;
    size_t ilen, pad_count = 0, i;
    unsigned char *p, bad, pad_done = 0;
822
    unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
823

824 825
    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
826 827 828 829

    ilen = ctx->len;

    if( ilen < 16 || ilen > sizeof( buf ) )
830
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
831

832 833 834
    ret = ( mode == MBEDTLS_RSA_PUBLIC )
          ? mbedtls_rsa_public(  ctx, input, buf )
          : mbedtls_rsa_private( ctx, f_rng, p_rng, input, buf );
835 836 837 838 839

    if( ret != 0 )
        return( ret );

    p = buf;
840
    bad = 0;
841

842 843 844 845
    /*
     * Check and get padding len in "constant-time"
     */
    bad |= *p++; /* First byte must be 0 */
846

847
    /* This test does not depend on secret data */
848
    if( mode == MBEDTLS_RSA_PRIVATE )
849
    {
850
        bad |= *p++ ^ MBEDTLS_RSA_CRYPT;
851

852 853 854 855
        /* Get padding len, but always read till end of buffer
         * (minus one, for the 00 byte) */
        for( i = 0; i < ilen - 3; i++ )
        {
856 857
            pad_done  |= ((p[i] | (unsigned char)-p[i]) >> 7) ^ 1;
            pad_count += ((pad_done | (unsigned char)-pad_done) >> 7) ^ 1;
858
        }
859

860 861
        p += pad_count;
        bad |= *p++; /* Must be zero */
862 863 864
    }
    else
    {
865
        bad |= *p++ ^ MBEDTLS_RSA_SIGN;
866

867 868 869 870
        /* Get padding len, but always read till end of buffer
         * (minus one, for the 00 byte) */
        for( i = 0; i < ilen - 3; i++ )
        {
871
            pad_done |= ( p[i] != 0xFF );
872 873
            pad_count += ( pad_done == 0 );
        }
874

875 876
        p += pad_count;
        bad |= *p++; /* Must be zero */
877 878
    }

879
    if( bad )
880
        return( MBEDTLS_ERR_RSA_INVALID_PADDING );
881

882
    if( ilen - ( p - buf ) > output_max_len )
883
        return( MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE );
884

885
    *olen = ilen - (p - buf);
886 887 888 889
    memcpy( output, p, *olen );

    return( 0 );
}
890
#endif /* MBEDTLS_PKCS1_V15 */
891 892

/*
893
 * Do an RSA operation, then remove the message padding
894
 */
895
int mbedtls_rsa_pkcs1_decrypt( mbedtls_rsa_context *ctx,
896 897
                       int (*f_rng)(void *, unsigned char *, size_t),
                       void *p_rng,
898 899 900 901
                       int mode, size_t *olen,
                       const unsigned char *input,
                       unsigned char *output,
                       size_t output_max_len)
902
{
903 904
    switch( ctx->padding )
    {
905 906 907
#if defined(MBEDTLS_PKCS1_V15)
        case MBEDTLS_RSA_PKCS_V15:
            return mbedtls_rsa_rsaes_pkcs1_v15_decrypt( ctx, f_rng, p_rng, mode, olen,
908
                                                input, output, output_max_len );
909
#endif
910

911 912 913
#if defined(MBEDTLS_PKCS1_V21)
        case MBEDTLS_RSA_PKCS_V21:
            return mbedtls_rsa_rsaes_oaep_decrypt( ctx, f_rng, p_rng, mode, NULL, 0,
914 915
                                           olen, input, output,
                                           output_max_len );
916 917 918
#endif

        default:
919
            return( MBEDTLS_ERR_RSA_INVALID_PADDING );
920 921 922
    }
}

923
#if defined(MBEDTLS_PKCS1_V21)
924 925 926
/*
 * Implementation of the PKCS#1 v2.1 RSASSA-PSS-SIGN function
 */
927
int mbedtls_rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
928 929 930
                         int (*f_rng)(void *, unsigned char *, size_t),
                         void *p_rng,
                         int mode,
931
                         mbedtls_md_type_t md_alg,
932 933 934 935 936 937
                         unsigned int hashlen,
                         const unsigned char *hash,
                         unsigned char *sig )
{
    size_t olen;
    unsigned char *p = sig;
938
    unsigned char salt[MBEDTLS_MD_MAX_SIZE];
939 940
    unsigned int slen, hlen, offset = 0;
    int ret;
941
    size_t msb;
942 943
    const mbedtls_md_info_t *md_info;
    mbedtls_md_context_t md_ctx;
944

945 946
    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
947 948

    if( f_rng == NULL )
949
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
950 951 952

    olen = ctx->len;

953
    if( md_alg != MBEDTLS_MD_NONE )
954
    {
955 956
        // Gather length of hash to sign
        //
957
        md_info = mbedtls_md_info_from_type( md_alg );
958
        if( md_info == NULL )
959
            return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
960

961
        hashlen = mbedtls_md_get_size( md_info );
962
    }
963

964
    md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
965
    if( md_info == NULL )
966
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
967

968
    hlen = mbedtls_md_get_size( md_info );
969
    slen = hlen;
970

971
    if( olen < hlen + slen + 2 )
972
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
973

974
    memset( sig, 0, olen );
975

976 977 978
    // Generate salt of length slen
    //
    if( ( ret = f_rng( p_rng, salt, slen ) ) != 0 )
979
        return( MBEDTLS_ERR_RSA_RNG_FAILED + ret );
980

981 982
    // Note: EMSA-PSS encoding is over the length of N - 1 bits
    //
983
    msb = mbedtls_mpi_msb( &ctx->N ) - 1;
984 985 986 987
    p += olen - hlen * 2 - 2;
    *p++ = 0x01;
    memcpy( p, salt, slen );
    p += slen;
988

989 990
    mbedtls_md_init( &md_ctx );
    mbedtls_md_setup( &md_ctx, md_info, 0 );
991

992 993
    // Generate H = Hash( M' )
    //
994 995 996 997 998
    mbedtls_md_starts( &md_ctx );
    mbedtls_md_update( &md_ctx, p, 8 );
    mbedtls_md_update( &md_ctx, hash, hashlen );
    mbedtls_md_update( &md_ctx, salt, slen );
    mbedtls_md_finish( &md_ctx, p );
999

1000 1001 1002 1003
    // Compensate for boundary condition when applying mask
    //
    if( msb % 8 == 0 )
        offset = 1;
1004

1005 1006 1007
    // maskedDB: Apply dbMask to DB
    //
    mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );
1008

1009
    mbedtls_md_free( &md_ctx );
1010

1011
    msb = mbedtls_mpi_msb( &ctx->N ) - 1;
1012
    sig[0] &= 0xFF >> ( olen * 8 - msb );
1013

1014 1015
    p += hlen;
    *p++ = 0xBC;
1016

1017 1018 1019
    return( ( mode == MBEDTLS_RSA_PUBLIC )
            ? mbedtls_rsa_public(  ctx, sig, sig )
            : mbedtls_rsa_private( ctx, f_rng, p_rng, sig, sig ) );
1020
}
1021
#endif /* MBEDTLS_PKCS1_V21 */
1022

1023
#if defined(MBEDTLS_PKCS1_V15)
1024 1025 1026 1027 1028 1029
/*
 * Implementation of the PKCS#1 v2.1 RSASSA-PKCS1-V1_5-SIGN function:
 * do an RSA operation to sign the message digest
 */
1030
int mbedtls_rsa_rsassa_pkcs1_v15_sign( mbedtls_rsa_context *ctx,