ristretto.cxx 7.87 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/**
 * @file ristretto.cxx
 * @author Mike Hamburg
 *
 * @copyright
 *   Copyright (c) 2015 Cryptography Research, Inc.  \n
 *   Released under the MIT License.  See LICENSE.txt for license information.
 *
 * @brief Ristretto implementation widget
 */

#include <decaf.hxx>
#include <stdio.h>
using namespace decaf;

16
/* Convert a single ASCII hex digit (0-9, a-f, A-F) to its value 0..15.
 * Any other character, including NUL, yields -1. */
static inline int hexi(char c) {
    if ('0' <= c && c <= '9') {
        return c - '0';
    }
    if ('a' <= c && c <= 'f') {
        return 10 + (c - 'a');
    }
    if ('A' <= c && c <= 'F') {
        return 10 + (c - 'A');
    }
    return -1;
}

23
/**
 * Decode a hex string into a fixed-size byte buffer.
 *
 * The buffer is zero-filled first, so a string shorter than 2*sizeof_out
 * leaves the trailing bytes of @p out at zero.
 *
 * @param out        Destination buffer.
 * @param sizeof_out Size of @p out in bytes.
 * @param hex        NUL-terminated string of hex digit pairs (either case).
 * @return 0 on success; -1 (after printing a message to stderr) if the string
 *         has odd length, encodes more than sizeof_out bytes, or contains a
 *         non-hex character. On the last error @p out may be partially written.
 */
static int parsehex(uint8_t *out, size_t sizeof_out, const char *hex) {
    size_t l = strlen(hex);
    if (l % 2 != 0) {
        fprintf(stderr, "String should be hex, but has odd length: %s\n", hex);
        return -1;
    } else if (l / 2 > sizeof_out) {
        fprintf(stderr, "Argument is too long: %s\n", hex);
        return -1;
    }

    memset(out, 0, sizeof_out);
    for (size_t i = 0; i < l / 2; i++) {
        /* l is even, so hex[2*i + 1] is always within the string. */
        int hi = hexi(hex[2*i]);
        int lo = hexi(hex[2*i + 1]);
        if (hi < 0 || lo < 0) {
            fprintf(stderr, "Invalid hex %s\n", hex);
            return -1;
        }
        out[i] = (uint8_t)(hi * 16 + lo);
    }
    return 0;
}

46
/* Write sizeof_in bytes from `in` to stdout as lowercase hex, two digits per
 * byte, with no separators and no trailing newline. */
static void printhex(const uint8_t *in, size_t sizeof_in) {
    for (size_t idx = 0; idx < sizeof_in; idx++) {
        printf("%02x", in[idx]);
    }
}


/* Command-line arguments, stored by main() so every Run<Group>::run() pass
 * can re-parse them. Static storage is zero-initialized (0 / NULL). */
static int g_argc;
static char **g_argv;
/* Set to -1 when argument parsing or point decoding fails. */
static int error;
/* Set once some pass has printed a result. */
static int done;

58
/**
 * Print command-line help to stderr and terminate the process with exit(-2).
 *
 * The program name shown is the last path component of argv[0], falling back
 * to "ristretto" when argv[0] is not available (e.g. a program started with an
 * empty argv).
 */
static void usage() {
    const char *me = (g_argv && g_argv[0]) ? g_argv[0] : "ristretto";
    /* Keep only the text after the last '/' that is followed by a non-'/'
     * character, so "/usr/bin/ristretto" prints as "ristretto". A trailing
     * slash is not treated as a separator. Iterate over `me` (never over a
     * possibly-NULL g_argv[0]). */
    for (const char *p = me; *p; p++) {
        if (p[0] == '/' && p[1] != 0 && p[1] != '/') {
            me = p + 1;
        }
    }

    fprintf(stderr,"Usage: %s [points] [operations] ...\n", me);
    fprintf(stderr,"  -b 255|448: Set which group to use (sometimes inferred from lengths)\n");
    fprintf(stderr,"  -E: Display output as Elligator inverses\n");
    fprintf(stderr,"  -D: Display output in EdDSA format (times clearing ratio)\n");
    fprintf(stderr,"  -R: Display raw xyzt\n");
    fprintf(stderr,"  -C: Display output in X[25519|448] format\n");
    fprintf(stderr,"  -H: With -D or -C, divide by encoding ratio first\n");
    fprintf(stderr,"\n");
    fprintf(stderr,"  Ways to create points:\n");
    fprintf(stderr,"    [hex]: Point from point data as hex\n");
    fprintf(stderr,"    -e [hex]: Create point by hashing to curve using elligator\n");
    fprintf(stderr,"    base: Base point of curve\n");
    fprintf(stderr,"    identity: Identity point of curve\n");
    fprintf(stderr,"\n");
    fprintf(stderr,"  Operations:\n");
    fprintf(stderr,"    -n [point]: negative of point\n");
    fprintf(stderr,"    -T [point]: apply debugging torque to point\n");
    fprintf(stderr,"    -s [scalar] * [point]: Multiply point by scalar\n");
    fprintf(stderr,"    -s [scalar] / [point]: Divide point by scalar\n");
    fprintf(stderr,"    [point] + [point]: Add two points\n");
    fprintf(stderr,"\n");
    fprintf(stderr,"  NB: this is a debugging widget.  It doesn't yet have order of operations.\n");
    fprintf(stderr,"  *** DON'T USE THIS UTILITY FOR ACTUAL CRYPTO! ***\n");
    fprintf(stderr,"  It's only for debugging!\n");
    fprintf(stderr,"\n");

    exit(-2);
}

template<typename Group> class Run {
public:
    static void run() {
        uint8_t tmp[Group::Point::SER_BYTES];
        typename Group::Point a,b;
        typename Group::Scalar s;
100
        bool plus=false, empty=true, elligator=false, mul=false, scalar=false, div=false, torque=false,
101
            scalarempty=true, neg=false, einv=false, like_eddsa=false, like_x=false, decoeff=false, raw=false;
102 103 104 105 106 107 108 109 110 111 112 113
        if (done || error) return;
        for (int i=1; i<g_argc && !error; i++) {
            bool point = false;
            
            if (!strcmp(g_argv[i],"-b") && ++i<g_argc) {
                if (atoi(g_argv[i]) == Group::bits()) continue;
                else return;
            } else if (!strcmp(g_argv[i],"+")) {
                if (elligator || scalar || empty) usage();
                plus = true;
            } else if (!strcmp(g_argv[i],"-n")) {
                neg = !neg;
114 115 116 117 118 119
            } else if (!strcmp(g_argv[i],"-E")) {
                einv = true;
            } else if (!strcmp(g_argv[i],"-R")) {
                raw = true;
            } else if (!strcmp(g_argv[i],"-D")) {
                like_eddsa = true;
120 121
            } else if (!strcmp(g_argv[i],"-C")) {
                like_x = true;
122 123
            } else if (!strcmp(g_argv[i],"-H")) {
                decoeff = true;
124 125
            } else if (!strcmp(g_argv[i],"-T")) {
                torque = true;
126
            } else if (!strcmp(g_argv[i],"*")) {
127
                if (elligator || scalar || scalarempty || div) usage();
128
                mul = true;
129 130 131
            } else if (!strcmp(g_argv[i],"/")) {
                if (elligator || scalar || scalarempty || mul) usage();
                div = true;
132 133 134 135 136 137 138 139 140 141
            } else if (!strcmp(g_argv[i],"-s")) {
                if (elligator || scalar || !scalarempty) usage();
                scalar = true;
            } else if (!strcmp(g_argv[i],"-e")) {
                if (elligator || scalar) usage();
                elligator = true;
            } else if (!strcmp(g_argv[i],"base")) {
                if (elligator || scalar) usage();
                b = b.base();
                point = true;
142 143 144 145
            } else if (!strcmp(g_argv[i],"identity")) {
                if (elligator || scalar) usage();
                b = b.identity();
                point = true;
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
            } else if ((strlen(g_argv[i]) == 2*sizeof(tmp)
                    || ((scalar || elligator) && strlen(g_argv[i]) <= 2*sizeof(tmp)))
                        && !(error=parsehex(tmp,sizeof(tmp),g_argv[i]))) {
                if (scalar) {
                    s = Block(tmp,sizeof(tmp)); scalar=false; scalarempty=false;
                } else if (elligator) {
                    point = true;
                    b.set_to_hash(Block(tmp,sizeof(tmp))); elligator=false;
                } else if (DECAF_SUCCESS != b.decode(Block(tmp,sizeof(tmp)))) {
                    fprintf(stderr,"Error: %s isn't in the group\n",g_argv[i]);
                    error = -1;
                } else {
                    point = true;
                }
            } else if (error || !empty) usage();

            if (point) {
                if (neg) { b = -b; neg = false; }
164 165
                if (div) { b /= s; div=false; }
                if (torque) { b = b.debugging_torque(); torque=false; }
166 167 168 169 170 171 172 173
                if (mul) { b *= s; mul=false; }
                if (empty) { a = b; empty=false; }
                else if (plus) { a += b; plus=false; }
                else usage();
            }
        }
        
        if (!error && !empty) {
174 175 176 177 178 179 180 181 182 183 184 185 186 187
            if (einv) {
                uint8_t buffer[Group::Point::HASH_BYTES];
                for (int h=0; h<1<<Group::Point::INVERT_ELLIGATOR_WHICH_BITS; h++) {
                    if (DECAF_SUCCESS == a.invert_elligator(
                        Buffer(buffer,sizeof(buffer)), h
                    )) {
                        printhex(buffer,sizeof(buffer));
                        printf("\n");
                    }
                }
            } else if (raw) {
                printhex((const uint8_t *)&a, sizeof(a));
                printf("\n");
            } else if (like_eddsa) {
188 189
                if (decoeff) a /= (Group::Point::EDDSA_ENCODE_RATIO);
                SecureBuffer b = a.mul_by_ratio_and_encode_like_eddsa();
190 191
                printhex(b.data(),b.size());
                printf("\n");
192
            } else if (like_x) {
193 194
                if (decoeff) a /= (Group::Point::LADDER_ENCODE_RATIO);
                SecureBuffer b = a.mul_by_ratio_and_encode_like_ladder();
195 196
                printhex(b.data(),b.size());
                printf("\n");
197 198 199 200 201
            } else {
                a.serialize_into(tmp);
                printhex(tmp,sizeof(tmp));
                printf("\n");
            }
202 203 204 205 206 207 208 209 210 211 212 213 214
            done = true;
        }
        
    }
};

int main(int argc, char **argv) {
    g_argc = argc;
    g_argv = argv;
    run_for_all_curves<Run>();
    if (!done) usage();
    return (error<0) ? -error : error;
}