rsa.c 49.7 KB
Newer Older
1 2 3
/*
 *  The RSA public-key cryptosystem
 *
 *  Copyright (C) 2006-2015, ARM Limited, All Rights Reserved
 *  SPDX-License-Identifier: Apache-2.0
 *
 *  Licensed under the Apache License, Version 2.0 (the "License"); you may
 *  not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *  This file is part of mbed TLS (https://tls.mbed.org)
 */
/*
 *  The following sources were referenced in the design of this implementation
 *  of the RSA algorithm:
 *
 *  [1] A method for obtaining digital signatures and public-key cryptosystems
 *      R Rivest, A Shamir, and L Adleman
 *      http://people.csail.mit.edu/rivest/pubs.html#RSA78
 *
 *  [2] Handbook of Applied Cryptography - 1997, Chapter 8
 *      Menezes, van Oorschot and Vanstone
 *
 */

34
#if !defined(MBEDTLS_CONFIG_FILE)
35
#include "mbedtls/config.h"
36
#else
37
#include MBEDTLS_CONFIG_FILE
38
#endif
39

40
#if defined(MBEDTLS_RSA_C)
41

42 43
#include "mbedtls/rsa.h"
#include "mbedtls/oid.h"
44

45 46
#include <string.h>

47
#if defined(MBEDTLS_PKCS1_V21)
48
#include "mbedtls/md.h"
49
#endif
50

51
#if defined(MBEDTLS_PKCS1_V15) && !defined(__OpenBSD__)
52
#include <stdlib.h>
53
#endif
54

55
#if defined(MBEDTLS_PLATFORM_C)
56
#include "mbedtls/platform.h"
57
#else
58
#include <stdio.h>
59
#define mbedtls_printf printf
60 61
#define mbedtls_calloc calloc
#define mbedtls_free   free
62 63
#endif

64 65 66
/*
 * Initialize an RSA context
 */
67
void mbedtls_rsa_init( mbedtls_rsa_context *ctx,
68
               int padding,
69
               int hash_id )
70
{
71
    memset( ctx, 0, sizeof( mbedtls_rsa_context ) );
72

73
    mbedtls_rsa_set_padding( ctx, padding, hash_id );
74

75 76
#if defined(MBEDTLS_THREADING_C)
    mbedtls_mutex_init( &ctx->mutex );
77
#endif
78 79
}

80 81 82
/*
 * Set padding for an existing RSA context
 */
83
void mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding, int hash_id )
84 85 86 87 88
{
    ctx->padding = padding;
    ctx->hash_id = hash_id;
}

#if defined(MBEDTLS_GENPRIME)

/*
 * Generate an RSA keypair
 *
 * Fills ctx with N, E, D, P, Q, DP, DQ and QP and sets ctx->len.
 *
 *   f_rng     RNG function, must not be NULL
 *   p_rng     RNG parameter passed to f_rng
 *   nbits     bit length of the modulus N, at least 128
 *   exponent  public exponent E, must be odd and at least 3
 *
 * Returns 0 on success, MBEDTLS_ERR_RSA_BAD_INPUT_DATA for invalid
 * parameters, or MBEDTLS_ERR_RSA_KEY_GEN_FAILED plus the MPI error code on
 * failure, in which case mbedtls_rsa_free() has been called on ctx.
 */
int mbedtls_rsa_gen_key( mbedtls_rsa_context *ctx,
                 int (*f_rng)(void *, unsigned char *, size_t),
                 void *p_rng,
                 unsigned int nbits, int exponent )
{
    int ret;
    mbedtls_mpi P1, Q1, H, G;

    /*
     * (P-1)*(Q-1) is always even, so an even exponent can never satisfy
     * GCD( E, (P-1)*(Q-1) ) == 1 and the search loop below would never
     * terminate: reject it up front.
     */
    if( f_rng == NULL || nbits < 128 || exponent < 3 || exponent % 2 == 0 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    mbedtls_mpi_init( &P1 ); mbedtls_mpi_init( &Q1 );
    mbedtls_mpi_init( &H ); mbedtls_mpi_init( &G );

    /*
     * find distinct primes P and Q such that N = P * Q has exactly nbits
     * bits and GCD( E, (P-1)*(Q-1) ) == 1
     */
    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &ctx->E, exponent ) );

    do
    {
        MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->P, nbits >> 1, 0,
                                f_rng, p_rng ) );

        /* For an odd nbits, Q receives the extra bit */
        if( nbits % 2 )
        {
            MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->Q, ( nbits >> 1 ) + 1, 0,
                                f_rng, p_rng ) );
        }
        else
        {
            MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->Q, nbits >> 1, 0,
                                f_rng, p_rng ) );
        }

        /*
         * 'continue' jumps to the loop condition. At that point G is either
         * still freshly initialized or holds the GCD of an earlier rejected
         * round (which was != 1, otherwise the loop would have ended).
         * NOTE(review): the first case relies on an initialized-but-unset
         * MPI comparing unequal to 1 — confirm against mbedtls_mpi_cmp_int.
         */
        if( mbedtls_mpi_cmp_mpi( &ctx->P, &ctx->Q ) == 0 )
            continue;

        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->N, &ctx->P, &ctx->Q ) );
        if( mbedtls_mpi_bitlen( &ctx->N ) != nbits )
            continue;

        MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &P1, &ctx->P, 1 ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &Q1, &ctx->Q, 1 ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &H, &P1, &Q1 ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->E, &H  ) );
    }
    while( mbedtls_mpi_cmp_int( &G, 1 ) != 0 );

    /*
     * D  = E^-1 mod ((P-1)*(Q-1))
     * DP = D mod (P - 1)
     * DQ = D mod (Q - 1)
     * QP = Q^-1 mod P
     */
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->D , &ctx->E, &H  ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->DP, &ctx->D, &P1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->DQ, &ctx->D, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->QP, &ctx->Q, &ctx->P ) );

    ctx->len = ( mbedtls_mpi_bitlen( &ctx->N ) + 7 ) >> 3;

cleanup:

    mbedtls_mpi_free( &P1 ); mbedtls_mpi_free( &Q1 ); mbedtls_mpi_free( &H ); mbedtls_mpi_free( &G );

    if( ret != 0 )
    {
        mbedtls_rsa_free( ctx );
        return( MBEDTLS_ERR_RSA_KEY_GEN_FAILED + ret );
    }

    return( 0 );
}

#endif /* MBEDTLS_GENPRIME */
171 172 173 174

/*
 * Check a public RSA key
 */
175
int mbedtls_rsa_check_pubkey( const mbedtls_rsa_context *ctx )
176
{
177
    if( !ctx->N.p || !ctx->E.p )
178
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
179

180
    if( ( ctx->N.p[0] & 1 ) == 0 ||
181
        ( ctx->E.p[0] & 1 ) == 0 )
182
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
183

184 185
    if( mbedtls_mpi_bitlen( &ctx->N ) < 128 ||
        mbedtls_mpi_bitlen( &ctx->N ) > MBEDTLS_MPI_MAX_BITS )
186
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
187

188
    if( mbedtls_mpi_bitlen( &ctx->E ) < 2 ||
189 190
        mbedtls_mpi_cmp_mpi( &ctx->E, &ctx->N ) >= 0 )
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
191 192 193 194 195 196 197

    return( 0 );
}

/*
 * Check a private RSA key
 */
198
int mbedtls_rsa_check_privkey( const mbedtls_rsa_context *ctx )
199 200
{
    int ret;
201
    mbedtls_mpi PQ, DE, P1, Q1, H, I, G, G2, L1, L2, DP, DQ, QP;
202

203
    if( ( ret = mbedtls_rsa_check_pubkey( ctx ) ) != 0 )
204 205
        return( ret );

206
    if( !ctx->P.p || !ctx->Q.p || !ctx->D.p )
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );

    mbedtls_mpi_init( &PQ ); mbedtls_mpi_init( &DE ); mbedtls_mpi_init( &P1 ); mbedtls_mpi_init( &Q1 );
    mbedtls_mpi_init( &H  ); mbedtls_mpi_init( &I  ); mbedtls_mpi_init( &G  ); mbedtls_mpi_init( &G2 );
    mbedtls_mpi_init( &L1 ); mbedtls_mpi_init( &L2 ); mbedtls_mpi_init( &DP ); mbedtls_mpi_init( &DQ );
    mbedtls_mpi_init( &QP );

    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &PQ, &ctx->P, &ctx->Q ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &DE, &ctx->D, &ctx->E ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &P1, &ctx->P, 1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &Q1, &ctx->Q, 1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &H, &P1, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->E, &H  ) );

    MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G2, &P1, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_div_mpi( &L1, &L2, &H, &G2 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &I, &DE, &L1  ) );

    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &DP, &ctx->D, &P1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &DQ, &ctx->D, &Q1 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &QP, &ctx->Q, &ctx->P ) );
228 229 230
    /*
     * Check for a valid PKCS1v2 private key
     */
231 232 233 234 235 236 237
    if( mbedtls_mpi_cmp_mpi( &PQ, &ctx->N ) != 0 ||
        mbedtls_mpi_cmp_mpi( &DP, &ctx->DP ) != 0 ||
        mbedtls_mpi_cmp_mpi( &DQ, &ctx->DQ ) != 0 ||
        mbedtls_mpi_cmp_mpi( &QP, &ctx->QP ) != 0 ||
        mbedtls_mpi_cmp_int( &L2, 0 ) != 0 ||
        mbedtls_mpi_cmp_int( &I, 1 ) != 0 ||
        mbedtls_mpi_cmp_int( &G, 1 ) != 0 )
238
    {
239
        ret = MBEDTLS_ERR_RSA_KEY_CHECK_FAILED;
240
    }
241

242
cleanup:
243 244 245 246
    mbedtls_mpi_free( &PQ ); mbedtls_mpi_free( &DE ); mbedtls_mpi_free( &P1 ); mbedtls_mpi_free( &Q1 );
    mbedtls_mpi_free( &H  ); mbedtls_mpi_free( &I  ); mbedtls_mpi_free( &G  ); mbedtls_mpi_free( &G2 );
    mbedtls_mpi_free( &L1 ); mbedtls_mpi_free( &L2 ); mbedtls_mpi_free( &DP ); mbedtls_mpi_free( &DQ );
    mbedtls_mpi_free( &QP );
247

248
    if( ret == MBEDTLS_ERR_RSA_KEY_CHECK_FAILED )
249 250
        return( ret );

251
    if( ret != 0 )
252
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED + ret );
253 254

    return( 0 );
255 256
}

257 258 259
/*
 * Check if contexts holding a public and private key match
 */
260
int mbedtls_rsa_check_pub_priv( const mbedtls_rsa_context *pub, const mbedtls_rsa_context *prv )
261
{
262 263
    if( mbedtls_rsa_check_pubkey( pub ) != 0 ||
        mbedtls_rsa_check_privkey( prv ) != 0 )
264
    {
265
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
266 267
    }

268 269
    if( mbedtls_mpi_cmp_mpi( &pub->N, &prv->N ) != 0 ||
        mbedtls_mpi_cmp_mpi( &pub->E, &prv->E ) != 0 )
270
    {
271
        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
272 273 274 275 276
    }

    return( 0 );
}

277 278 279
/*
 * Do an RSA public key operation
 */
280
int mbedtls_rsa_public( mbedtls_rsa_context *ctx,
281
                const unsigned char *input,
282 283
                unsigned char *output )
{
284 285
    int ret;
    size_t olen;
286
    mbedtls_mpi T;
287

288
    mbedtls_mpi_init( &T );
289

290 291 292 293 294
#if defined(MBEDTLS_THREADING_C)
    if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 )
        return( ret );
#endif

295
    MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( &T, input, ctx->len ) );
296

297
    if( mbedtls_mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
298
    {
299 300
        ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
        goto cleanup;
301 302 303
    }

    olen = ctx->len;
304 305
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T, &T, &ctx->E, &ctx->N, &ctx->RN ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &T, output, olen ) );
306 307

cleanup:
308
#if defined(MBEDTLS_THREADING_C)
309 310
    if( mbedtls_mutex_unlock( &ctx->mutex ) != 0 )
        return( MBEDTLS_ERR_THREADING_MUTEX_ERROR );
311
#endif
312

313
    mbedtls_mpi_free( &T );
314 315

    if( ret != 0 )
316
        return( MBEDTLS_ERR_RSA_PUBLIC_FAILED + ret );
317 318 319 320

    return( 0 );
}

321
/*
322 323
 * Generate or update blinding values, see section 10 of:
 *  KOCHER, Paul C. Timing attacks on implementations of Diffie-Hellman, RSA,
324
 *  DSS, and other systems. In : Advances in Cryptology-CRYPTO'96. Springer
325
 *  Berlin Heidelberg, 1996. p. 104-113.
326
 */
327
static int rsa_prepare_blinding( mbedtls_rsa_context *ctx,
328 329
                 int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
{
330
    int ret, count = 0;
331

332 333 334
    if( ctx->Vf.p != NULL )
    {
        /* We already have blinding values, just update them by squaring */
335 336 337 338
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->Vi, &ctx->Vi, &ctx->Vi ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->Vi, &ctx->Vi, &ctx->N ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->Vf, &ctx->Vf, &ctx->Vf ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->Vf, &ctx->Vf, &ctx->N ) );
339

340
        goto cleanup;
341 342
    }

343 344 345
    /* Unblinding value: Vf = random number, invertible mod N */
    do {
        if( count++ > 10 )
346
            return( MBEDTLS_ERR_RSA_RNG_FAILED );
347

348 349 350
        MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( &ctx->Vf, ctx->len - 1, f_rng, p_rng ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &ctx->Vi, &ctx->Vf, &ctx->N ) );
    } while( mbedtls_mpi_cmp_int( &ctx->Vi, 1 ) != 0 );
351 352

    /* Blinding value: Vi =  Vf^(-e) mod N */
353 354
    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->Vi, &ctx->Vf, &ctx->N ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &ctx->Vi, &ctx->Vi, &ctx->E, &ctx->N, &ctx->RN ) );
355

356

357 358 359 360
cleanup:
    return( ret );
}

361 362 363
/*
 * Do an RSA private key operation
 */
364
int mbedtls_rsa_private( mbedtls_rsa_context *ctx,
365 366
                 int (*f_rng)(void *, unsigned char *, size_t),
                 void *p_rng,
367
                 const unsigned char *input,
368 369
                 unsigned char *output )
{
370 371
    int ret;
    size_t olen;
372
    mbedtls_mpi T, T1, T2;
373

374 375 376 377
    /* Make sure we have private key info, prevent possible misuse */
    if( ctx->P.p == NULL || ctx->Q.p == NULL || ctx->D.p == NULL )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

378
    mbedtls_mpi_init( &T ); mbedtls_mpi_init( &T1 ); mbedtls_mpi_init( &T2 );
379

380 381 382
#if defined(MBEDTLS_THREADING_C)
    if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 )
        return( ret );
383
#endif
384

385 386
    MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( &T, input, ctx->len ) );
    if( mbedtls_mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
387
    {
388 389
        ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
        goto cleanup;
390 391
    }

392 393 394
    if( f_rng != NULL )
    {
        /*
395 396
         * Blinding
         * T = T * Vi mod N
397
         */
398 399
        MBEDTLS_MPI_CHK( rsa_prepare_blinding( ctx, f_rng, p_rng ) );
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T, &T, &ctx->Vi ) );
400
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &T, &T, &ctx->N ) );
401
    }
402

403 404
#if defined(MBEDTLS_RSA_NO_CRT)
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T, &T, &ctx->D, &ctx->N, &ctx->RN ) );
405
#else
406 407 408 409 410 411
    /*
     * faster decryption using the CRT
     *
     * T1 = input ^ dP mod P
     * T2 = input ^ dQ mod Q
     */
412 413
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T1, &T, &ctx->DP, &ctx->P, &ctx->RP ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T2, &T, &ctx->DQ, &ctx->Q, &ctx->RQ ) );
414 415 416 417

    /*
     * T = (T1 - T2) * (Q^-1 mod P) mod P
     */
418 419 420
    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &T, &T1, &T2 ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T1, &T, &ctx->QP ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &T, &T1, &ctx->P ) );
421 422

    /*
423
     * T = T2 + T * Q
424
     */
425 426 427
    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T1, &T, &ctx->Q ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_add_mpi( &T, &T2, &T1 ) );
#endif /* MBEDTLS_RSA_NO_CRT */
428

429 430 431 432
    if( f_rng != NULL )
    {
        /*
         * Unblind
433
         * T = T * Vf mod N
434
         */
435
        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &T, &T, &ctx->Vf ) );
436
        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &T, &T, &ctx->N ) );
437
    }
438 439

    olen = ctx->len;
440
    MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &T, output, olen ) );
441 442

cleanup:
443
#if defined(MBEDTLS_THREADING_C)
444 445
    if( mbedtls_mutex_unlock( &ctx->mutex ) != 0 )
        return( MBEDTLS_ERR_THREADING_MUTEX_ERROR );
446
#endif
447

448
    mbedtls_mpi_free( &T ); mbedtls_mpi_free( &T1 ); mbedtls_mpi_free( &T2 );
449 450

    if( ret != 0 )
451
        return( MBEDTLS_ERR_RSA_PRIVATE_FAILED + ret );
452 453 454 455

    return( 0 );
}

#if defined(MBEDTLS_PKCS1_V21)
/**
 * Generate and apply the MGF1 operation (from PKCS#1 v2.1) to a buffer.
 *
 * XORs dst with Hash( src || C ) for C = 0, 1, 2, ... (C is a 4-byte
 * big-endian counter), truncating the last hash output to fit dlen.
 *
 * \param dst       buffer to mask
 * \param dlen      length of destination buffer
 * \param src       source of the mask generation
 * \param slen      length of the source buffer
 * \param md_ctx    message digest context to use; must have been set up
 *                  with a digest. If it reports a zero digest size, dst is
 *                  left untouched rather than looping forever.
 */
static void mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
                      size_t slen, mbedtls_md_context_t *md_ctx )
{
    unsigned char mask[MBEDTLS_MD_MAX_SIZE];
    unsigned char counter[4];
    unsigned char *p;
    unsigned int hlen;
    size_t i, use_len;
    int c;

    memset( mask, 0, MBEDTLS_MD_MAX_SIZE );
    memset( counter, 0, 4 );

    hlen = mbedtls_md_get_size( md_ctx->md_info );

    /* A zero-sized digest would never consume dlen below */
    if( hlen == 0 )
        return;

    // Generate and apply dbMask
    //
    p = dst;

    while( dlen > 0 )
    {
        use_len = hlen;
        if( dlen < hlen )
            use_len = dlen;

        mbedtls_md_starts( md_ctx );
        mbedtls_md_update( md_ctx, src, slen );
        mbedtls_md_update( md_ctx, counter, 4 );
        mbedtls_md_finish( md_ctx, mask );

        for( i = 0; i < use_len; ++i )
            *p++ ^= mask[i];

        /* Increment the 32-bit big-endian counter with carry propagation,
         * so masks longer than 256 digest blocks stay correct */
        for( c = 3; c >= 0; c-- )
        {
            if( ++counter[c] != 0 )
                break;
        }

        dlen -= use_len;
    }
}
#endif /* MBEDTLS_PKCS1_V21 */
#if defined(MBEDTLS_PKCS1_V21)
/*
 * Implementation of the PKCS#1 v2.1 RSAES-OAEP-ENCRYPT function
 *
 * Builds, in output (ctx->len bytes):
 *   0x00 || maskedSeed (hlen) || maskedDB
 * where DB = Hash(label) || zero padding || 0x01 || input. It then applies
 * the public or private RSA operation in place.
 *
 * Returns 0, MBEDTLS_ERR_RSA_BAD_INPUT_DATA for unusable parameters or an
 * input too long for the key, MBEDTLS_ERR_RSA_RNG_FAILED plus the RNG
 * error, a message-digest error code, or the RSA operation's error code.
 */
int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx,
                            int (*f_rng)(void *, unsigned char *, size_t),
                            void *p_rng,
                            int mode,
                            const unsigned char *label, size_t label_len,
                            size_t ilen,
                            const unsigned char *input,
                            unsigned char *output )
{
    size_t olen;
    int ret;
    unsigned char *p = output;
    unsigned int hlen;
    const mbedtls_md_info_t *md_info;
    mbedtls_md_context_t md_ctx;

    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    if( f_rng == NULL )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
    if( md_info == NULL )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    olen = ctx->len;
    hlen = mbedtls_md_get_size( md_info );

    /* Require olen >= ilen + 2 * hlen + 2. The comparison avoids forming
     * ilen + 2 * hlen + 2, which could wrap around for a huge ilen, pass
     * the check, and overflow output in the memcpy below. */
    if( olen < 2 * (size_t) hlen + 2 || ilen > olen - 2 * (size_t) hlen - 2 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    memset( output, 0, olen );

    *p++ = 0;

    // Generate a random octet string seed
    //
    if( ( ret = f_rng( p_rng, p, hlen ) ) != 0 )
        return( MBEDTLS_ERR_RSA_RNG_FAILED + ret );

    p += hlen;

    // Construct DB
    //
    if( ( ret = mbedtls_md( md_info, label, label_len, p ) ) != 0 )
        return( ret );
    p += hlen;
    p += olen - 2 * hlen - 2 - ilen;    /* zero padding, already cleared */
    *p++ = 1;
    memcpy( p, input, ilen );

    mbedtls_md_init( &md_ctx );
    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
    {
        mbedtls_md_free( &md_ctx );
        return( ret );
    }

    // maskedDB: Apply dbMask to DB
    //
    mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
               &md_ctx );

    // maskedSeed: Apply seedMask to seed
    //
    mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
               &md_ctx );

    mbedtls_md_free( &md_ctx );

    return( ( mode == MBEDTLS_RSA_PUBLIC )
            ? mbedtls_rsa_public(  ctx, output, output )
            : mbedtls_rsa_private( ctx, f_rng, p_rng, output, output ) );
}
#endif /* MBEDTLS_PKCS1_V21 */
#if defined(MBEDTLS_PKCS1_V15)
/*
 * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-ENCRYPT function
 *
 * Builds, in output (ctx->len bytes):
 *   0x00 || BT || PS || 0x00 || input
 * and applies the RSA operation in place.
 *   - Public-key mode: BT is MBEDTLS_RSA_CRYPT and PS consists of non-zero
 *     random bytes.
 *   - Private-key mode: BT is MBEDTLS_RSA_SIGN and PS consists of 0xFF bytes.
 * PS is ctx->len - 3 - ilen bytes long, so at least 8.
 *
 * Returns 0, MBEDTLS_ERR_RSA_BAD_INPUT_DATA for unusable parameters or an
 * input too long for the key, MBEDTLS_ERR_RSA_RNG_FAILED plus the RNG
 * error, or the RSA operation's error code.
 */
int mbedtls_rsa_rsaes_pkcs1_v15_encrypt( mbedtls_rsa_context *ctx,
                                 int (*f_rng)(void *, unsigned char *, size_t),
                                 void *p_rng,
                                 int mode, size_t ilen,
                                 const unsigned char *input,
                                 unsigned char *output )
{
    size_t nb_pad, olen;
    int ret;
    unsigned char *p = output;

    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    // We don't check p_rng because it won't be dereferenced here
    if( f_rng == NULL || input == NULL || output == NULL )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    olen = ctx->len;

    /* Require olen >= ilen + 11 (3 framing bytes plus at least 8 bytes of
     * padding). The comparison avoids forming ilen + 11, which could wrap
     * around for a huge ilen and let an oversized input through. */
    if( olen < 11 || ilen > olen - 11 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    nb_pad = olen - 3 - ilen;

    *p++ = 0;
    if( mode == MBEDTLS_RSA_PUBLIC )
    {
        *p++ = MBEDTLS_RSA_CRYPT;

        while( nb_pad-- > 0 )
        {
            int rng_dl = 100;

            /* Draw until the RNG yields a non-zero byte (at most 100 calls) */
            do {
                ret = f_rng( p_rng, p, 1 );
            } while( *p == 0 && --rng_dl && ret == 0 );

            // Check if RNG failed to generate data
            //
            if( rng_dl == 0 || ret != 0 )
                return( MBEDTLS_ERR_RSA_RNG_FAILED + ret );

            p++;
        }
    }
    else
    {
        *p++ = MBEDTLS_RSA_SIGN;

        while( nb_pad-- > 0 )
            *p++ = 0xFF;
    }

    *p++ = 0;
    memcpy( p, input, ilen );

    return( ( mode == MBEDTLS_RSA_PUBLIC )
            ? mbedtls_rsa_public(  ctx, output, output )
            : mbedtls_rsa_private( ctx, f_rng, p_rng, output, output ) );
}
#endif /* MBEDTLS_PKCS1_V15 */
647 648

/*
649
 * Add the message padding, then do an RSA operation
650
 */
651
int mbedtls_rsa_pkcs1_encrypt( mbedtls_rsa_context *ctx,
652 653 654
                       int (*f_rng)(void *, unsigned char *, size_t),
                       void *p_rng,
                       int mode, size_t ilen,
655
                       const unsigned char *input,
656 657 658 659
                       unsigned char *output )
{
    switch( ctx->padding )
    {
660 661 662
#if defined(MBEDTLS_PKCS1_V15)
        case MBEDTLS_RSA_PKCS_V15:
            return mbedtls_rsa_rsaes_pkcs1_v15_encrypt( ctx, f_rng, p_rng, mode, ilen,
663
                                                input, output );
664
#endif
665

666 667 668
#if defined(MBEDTLS_PKCS1_V21)
        case MBEDTLS_RSA_PKCS_V21:
            return mbedtls_rsa_rsaes_oaep_encrypt( ctx, f_rng, p_rng, mode, NULL, 0,
669 670 671 672
                                           ilen, input, output );
#endif

        default:
673
            return( MBEDTLS_ERR_RSA_INVALID_PADDING );
674 675 676
    }
}

#if defined(MBEDTLS_PKCS1_V21)
/*
 * Implementation of the PKCS#1 v2.1 RSAES-OAEP-DECRYPT function
 *
 * After the RSA operation, the ctx->len-byte block is unmasked into
 *   0x00 || seed (hlen) || lHash' (hlen) || zero padding || 0x01 || M
 * and lHash' is compared with Hash(label). The decrypted padding is
 * checked without data-dependent early exits, so only "valid / invalid" is
 * revealed. M is copied to output only when everything matched.
 *
 * Returns 0 (with *olen set), MBEDTLS_ERR_RSA_BAD_INPUT_DATA for unusable
 * parameters, MBEDTLS_ERR_RSA_INVALID_PADDING, or
 * MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE. It may also return the RSA
 * operation's error code or a message-digest error code.
 */
int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx,
                            int (*f_rng)(void *, unsigned char *, size_t),
                            void *p_rng,
                            int mode,
                            const unsigned char *label, size_t label_len,
                            size_t *olen,
                            const unsigned char *input,
                            unsigned char *output,
                            size_t output_max_len )
{
    int ret;
    size_t ilen, i, pad_len;
    unsigned char *p, bad, pad_done;
    unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
    unsigned char lhash[MBEDTLS_MD_MAX_SIZE];
    unsigned int hlen;
    const mbedtls_md_info_t *md_info;
    mbedtls_md_context_t md_ctx;

    /*
     * Parameters sanity checks
     */
    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    ilen = ctx->len;

    if( ilen < 16 || ilen > sizeof( buf ) )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
    if( md_info == NULL )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    hlen = mbedtls_md_get_size( md_info );

    /* The block must at least hold the leading zero, seed, lHash and the
     * 0x01 separator. Otherwise ilen - 2 * hlen - 2 below wraps around and
     * the padding scan reads far past the end of buf. */
    if( 2 * (size_t) hlen + 2 > ilen )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    /*
     * RSA operation
     */
    ret = ( mode == MBEDTLS_RSA_PUBLIC )
          ? mbedtls_rsa_public(  ctx, input, buf )
          : mbedtls_rsa_private( ctx, f_rng, p_rng, input, buf );

    if( ret != 0 )
        return( ret );

    /*
     * Unmask data and generate lHash
     */

    /* Generate lHash */
    if( ( ret = mbedtls_md( md_info, label, label_len, lhash ) ) != 0 )
        return( ret );

    mbedtls_md_init( &md_ctx );
    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
    {
        mbedtls_md_free( &md_ctx );
        return( ret );
    }

    /* seed: Apply seedMask to maskedSeed */
    mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
               &md_ctx );

    /* DB: Apply dbMask to maskedDB */
    mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
               &md_ctx );

    mbedtls_md_free( &md_ctx );

    /*
     * Check contents, in "constant-time"
     */
    p = buf;
    bad = 0;

    bad |= *p++; /* First byte must be 0 */

    p += hlen; /* Skip seed */

    /* Check lHash */
    for( i = 0; i < hlen; i++ )
        bad |= lhash[i] ^ *p++;

    /* Get zero-padding len, but always read till end of buffer
     * (minus one, for the 01 byte). pad_done becomes non-zero at the first
     * non-zero byte; pad_len counts the bytes before it. */
    pad_len = 0;
    pad_done = 0;
    for( i = 0; i < ilen - 2 * hlen - 2; i++ )
    {
        pad_done |= p[i];
        pad_len += ((pad_done | (unsigned char)-pad_done) >> 7) ^ 1;
    }

    p += pad_len;
    bad |= *p++ ^ 0x01;

    /*
     * The only information "leaked" is whether the padding was correct or not
     * (eg, no data is copied if it was not correct). This meets the
     * recommendations in PKCS#1 v2.2: an opponent cannot distinguish between
     * the different error conditions.
     */
    if( bad != 0 )
        return( MBEDTLS_ERR_RSA_INVALID_PADDING );

    if( ilen - ( p - buf ) > output_max_len )
        return( MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE );

    *olen = ilen - (p - buf);
    memcpy( output, p, *olen );

    return( 0 );
}
#endif /* MBEDTLS_PKCS1_V21 */
#if defined(MBEDTLS_PKCS1_V15)
/*
 * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-DECRYPT function
 *
 * After the RSA operation, the ctx->len-byte block must be
 *   0x00 || BT || PS || 0x00 || M
 * with PS at least 8 bytes long.
 *   - Private-key mode: BT is MBEDTLS_RSA_CRYPT and PS consists of non-zero
 *     bytes.
 *   - Public-key mode: BT is MBEDTLS_RSA_SIGN and PS consists of 0xFF bytes.
 * The padding is scanned to the end of the buffer without early exits.
 *
 * Returns 0 (with *olen set), MBEDTLS_ERR_RSA_BAD_INPUT_DATA,
 * MBEDTLS_ERR_RSA_INVALID_PADDING, MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE, or the
 * RSA operation's error code.
 */
int mbedtls_rsa_rsaes_pkcs1_v15_decrypt( mbedtls_rsa_context *ctx,
                                 int (*f_rng)(void *, unsigned char *, size_t),
                                 void *p_rng,
                                 int mode, size_t *olen,
                                 const unsigned char *input,
                                 unsigned char *output,
                                 size_t output_max_len)
{
    int ret;
    size_t ilen, pad_count = 0, i;
    unsigned char *p, bad, pad_done = 0;
    unsigned char buf[MBEDTLS_MPI_MAX_SIZE];

    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    ilen = ctx->len;

    if( ilen < 16 || ilen > sizeof( buf ) )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    ret = ( mode == MBEDTLS_RSA_PUBLIC )
          ? mbedtls_rsa_public(  ctx, input, buf )
          : mbedtls_rsa_private( ctx, f_rng, p_rng, input, buf );

    if( ret != 0 )
        return( ret );

    p = buf;
    bad = 0;

    /*
     * Check and get padding len in "constant-time"
     */
    bad |= *p++; /* First byte must be 0 */

    /* This test does not depend on secret data */
    if( mode == MBEDTLS_RSA_PRIVATE )
    {
        bad |= *p++ ^ MBEDTLS_RSA_CRYPT;

        /* Get padding len, but always read till end of buffer
         * (minus one, for the 00 byte). pad_done becomes 1 at the first
         * zero byte; pad_count counts the bytes before it. */
        for( i = 0; i < ilen - 3; i++ )
        {
            pad_done  |= ((p[i] | (unsigned char)-p[i]) >> 7) ^ 1;
            pad_count += ((pad_done | (unsigned char)-pad_done) >> 7) ^ 1;
        }

        p += pad_count;
        bad |= *p++; /* Must be zero */
    }
    else
    {
        bad |= *p++ ^ MBEDTLS_RSA_SIGN;

        /* Get padding len, but always read till end of buffer
         * (minus one, for the 00 byte) */
        for( i = 0; i < ilen - 3; i++ )
        {
            pad_done |= ( p[i] != 0xFF );
            pad_count += ( pad_done == 0 );
        }

        p += pad_count;
        bad |= *p++; /* Must be zero */
    }

    /* PKCS#1 v1.5 mandates at least 8 bytes of padding (PS) */
    bad |= ( pad_count < 8 );

    if( bad )
        return( MBEDTLS_ERR_RSA_INVALID_PADDING );

    if( ilen - ( p - buf ) > output_max_len )
        return( MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE );

    *olen = ilen - (p - buf);
    memcpy( output, p, *olen );

    return( 0 );
}
#endif /* MBEDTLS_PKCS1_V15 */
876 877

/*
878
 * Do an RSA operation, then remove the message padding
879
 */
880
int mbedtls_rsa_pkcs1_decrypt( mbedtls_rsa_context *ctx,
881 882
                       int (*f_rng)(void *, unsigned char *, size_t),
                       void *p_rng,
883 884 885 886
                       int mode, size_t *olen,
                       const unsigned char *input,
                       unsigned char *output,
                       size_t output_max_len)
887
{
888 889
    switch( ctx->padding )
    {
890 891 892
#if defined(MBEDTLS_PKCS1_V15)
        case MBEDTLS_RSA_PKCS_V15:
            return mbedtls_rsa_rsaes_pkcs1_v15_decrypt( ctx, f_rng, p_rng, mode, olen,
893
                                                input, output, output_max_len );
894
#endif
895

896 897 898
#if defined(MBEDTLS_PKCS1_V21)
        case MBEDTLS_RSA_PKCS_V21:
            return mbedtls_rsa_rsaes_oaep_decrypt( ctx, f_rng, p_rng, mode, NULL, 0,
899 900
                                           olen, input, output,
                                           output_max_len );
901 902 903
#endif

        default:
904
            return( MBEDTLS_ERR_RSA_INVALID_PADDING );
905 906 907
    }
}

#if defined(MBEDTLS_PKCS1_V21)
/*
 * Implementation of the PKCS#1 v2.1 RSASSA-PSS-SIGN function
 *
 * Builds, in sig (ctx->len bytes), the EMSA-PSS encoding
 *   maskedDB || H || 0xBC,   DB = zero padding || 0x01 || salt
 * where H = Hash( 8 zero bytes || hash || salt ). The salt is hlen random
 * bytes, hlen being the digest size of ctx->hash_id. The public or private
 * RSA operation is then applied in place.
 * If md_alg is not MBEDTLS_MD_NONE, hashlen is replaced by the size of
 * that digest.
 *
 * Returns 0, MBEDTLS_ERR_RSA_BAD_INPUT_DATA for unusable parameters or a
 * key too small for the encoding, MBEDTLS_ERR_RSA_RNG_FAILED plus the RNG
 * error, a message-digest error code, or the RSA operation's error code.
 */
int mbedtls_rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
                         int (*f_rng)(void *, unsigned char *, size_t),
                         void *p_rng,
                         int mode,
                         mbedtls_md_type_t md_alg,
                         unsigned int hashlen,
                         const unsigned char *hash,
                         unsigned char *sig )
{
    size_t olen;
    unsigned char *p = sig;
    unsigned char salt[MBEDTLS_MD_MAX_SIZE];
    unsigned int slen, hlen, offset = 0;
    int ret;
    size_t msb;
    const mbedtls_md_info_t *md_info;
    mbedtls_md_context_t md_ctx;

    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    if( f_rng == NULL )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    olen = ctx->len;

    if( md_alg != MBEDTLS_MD_NONE )
    {
        // Gather length of hash to sign
        //
        md_info = mbedtls_md_info_from_type( md_alg );
        if( md_info == NULL )
            return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

        hashlen = mbedtls_md_get_size( md_info );
    }

    md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
    if( md_info == NULL )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    hlen = mbedtls_md_get_size( md_info );
    slen = hlen;

    // Note: EMSA-PSS encoding is over the length of N - 1 bits.
    // When that is a multiple of 8, the encoding is one byte shorter than
    // N: sig[0] stays zero and the masked area starts at sig + offset.
    //
    msb = mbedtls_mpi_bitlen( &ctx->N ) - 1;
    if( msb % 8 == 0 )
        offset = 1;

    /* The encoding (olen - offset bytes) must hold H, the salt, the 0x01
     * separator and the 0xBC trailer. Without offset here, the 0x01 byte
     * could land in sig[0] and be cleared below, giving a bad signature. */
    if( olen < (size_t) hlen + slen + 2 + offset )
        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );

    memset( sig, 0, olen );

    // Generate salt of length slen
    //
    if( ( ret = f_rng( p_rng, salt, slen ) ) != 0 )
        return( MBEDTLS_ERR_RSA_RNG_FAILED + ret );

    /* DB = zero padding || 0x01 || salt, ending right before H */
    p += olen - hlen - slen - 2;
    *p++ = 0x01;
    memcpy( p, salt, slen );
    p += slen;

    mbedtls_md_init( &md_ctx );
    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
    {
        mbedtls_md_free( &md_ctx );
        return( ret );
    }

    // Generate H = Hash( M' ), M' = 8 zero bytes || hash || salt.
    // The 8 zero bytes are read from sig at p, still cleared by the memset.
    //
    mbedtls_md_starts( &md_ctx );
    mbedtls_md_update( &md_ctx, p, 8 );
    mbedtls_md_update( &md_ctx, hash, hashlen );
    mbedtls_md_update( &md_ctx, salt, slen );
    mbedtls_md_finish( &md_ctx, p );

    // maskedDB: Apply dbMask to DB
    //
    mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );

    mbedtls_md_free( &md_ctx );

    /* Clear the leftmost bits of sig[0] that lie above bit position msb */
    sig[0] &= 0xFF >> ( olen * 8 - msb );

    p += hlen;
    *p++ = 0xBC;

    return( ( mode == MBEDTLS_RSA_PUBLIC )
            ? mbedtls_rsa_public(  ctx, sig, sig )
            : mbedtls_rsa_private( ctx, f_rng, p_rng, sig, sig ) );
}
#endif /* MBEDTLS_PKCS1_V21 */
1007

1008
#if defined(MBEDTLS_PKCS1_V15)
1009 1010 1011 1012 1013 1014
/*
 * Implementation of the PKCS#1 v2.1 RSASSA-PKCS1-V1_5-SIGN function
 */
/*
 * Do an RSA operation to sign the message digest
 */
1015
int mbedtls_rsa_rsassa_pkcs1_v15_sign( mbedtls_rsa_context *ctx,