Skip to content

Commit

Permalink
Reduce number of malloc/free call in XMSS/external (#1724)
Browse files Browse the repository at this point in the history
* remove unused file

* move malloc from prf and prf_keygen to external, reduce number of malloc/free calls

* push malloc/free to top level function

* continue to move malloc/free to upper level

* clean up

* modify TODO to TODO(from upstream)

* make astyle happy

* clean up

* use malloc and NULL check
  • Loading branch information
ducnguyen-sb authored and SWilson4 committed May 14, 2024
1 parent b45415c commit ba63672
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 363 deletions.
3 changes: 0 additions & 3 deletions src/sig_stfl/sig_stfl.c
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,6 @@ OQS_API OQS_STATUS OQS_SIG_STFL_sign(const OQS_SIG_STFL *sig, uint8_t *signature
#endif
}


OQS_API OQS_STATUS OQS_SIG_STFL_verify(const OQS_SIG_STFL *sig, const uint8_t *message, size_t message_len, const uint8_t *signature, size_t signature_len, const uint8_t *public_key) {
if (sig == NULL || sig->verify == NULL || sig->verify(message, message_len, signature, signature_len, public_key) != 0) {
return OQS_ERROR;
Expand All @@ -921,7 +920,6 @@ OQS_API OQS_STATUS OQS_SIG_STFL_verify(const OQS_SIG_STFL *sig, const uint8_t *m
}
}


OQS_API OQS_STATUS OQS_SIG_STFL_sigs_remaining(const OQS_SIG_STFL *sig, unsigned long long *remain, const OQS_SIG_STFL_SECRET_KEY *secret_key) {
#ifndef OQS_ALLOW_SFTL_KEY_AND_SIG_GEN
(void)sig;
Expand All @@ -937,7 +935,6 @@ OQS_API OQS_STATUS OQS_SIG_STFL_sigs_remaining(const OQS_SIG_STFL *sig, unsigned
#endif //OQS_ALLOW_SFTL_KEY_AND_SIG_GEN
}


OQS_API OQS_STATUS OQS_SIG_STFL_sigs_total(const OQS_SIG_STFL *sig, unsigned long long *max, const OQS_SIG_STFL_SECRET_KEY *secret_key) {
#ifndef OQS_ALLOW_SFTL_KEY_AND_SIG_GEN
(void)sig;
Expand Down
44 changes: 17 additions & 27 deletions src/sig_stfl/xmss/external/hash.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@ void addr_to_bytes(unsigned char *bytes, const uint32_t addr[8])
*/
int prf(const xmss_params *params,
unsigned char *out, const unsigned char in[32],
const unsigned char *key)
const unsigned char *key,
unsigned char *buf)
{
unsigned char* buf = malloc(params->padding_len + params->n + 32);

ull_to_bytes(buf, params->padding_len, XMSS_HASH_PADDING_PRF);
memcpy(buf + params->padding_len, key, params->n);
memcpy(buf + params->padding_len + params->n, in, 32);

int ret = core_hash(params, out, buf, params->padding_len + params->n + 32);

OQS_MEM_insecure_free(buf);

return ret;
}
Expand All @@ -50,18 +48,15 @@ int prf(const xmss_params *params,
*/
int prf_keygen(const xmss_params *params,
unsigned char *out, const unsigned char *in,
const unsigned char *key)
const unsigned char *key,
unsigned char *buf)
{
unsigned char *buf = malloc(params->padding_len + 2*params->n + 32);

ull_to_bytes(buf, params->padding_len, XMSS_HASH_PADDING_PRF_KEYGEN);
memcpy(buf + params->padding_len, key, params->n);
memcpy(buf + params->padding_len + params->n, in, params->n + 32);

int ret = core_hash(params, out, buf, params->padding_len + 2*params->n + 32);

OQS_MEM_insecure_free(buf);

return ret;
}

Expand Down Expand Up @@ -92,12 +87,11 @@ int hash_message(const xmss_params *params, unsigned char *out,
*/
int thash_h(const xmss_params *params,
unsigned char *out, const unsigned char *in,
const unsigned char *pub_seed, uint32_t addr[8])
const unsigned char *pub_seed, uint32_t addr[8],
unsigned char *buf)
{
unsigned char *tmp = malloc(params->padding_len + 3 * params->n + 2 * params->n);

unsigned char *buf = tmp;
unsigned char *bitmask = tmp + (params->padding_len + 3 * params->n);
unsigned char *bitmask = buf + (params->padding_len + 3 * params->n);
unsigned char *prf_buf = bitmask + 2*params->n;

unsigned char addr_as_bytes[32];
unsigned int i;
Expand All @@ -108,34 +102,32 @@ int thash_h(const xmss_params *params,
/* Generate the n-byte key. */
set_key_and_mask(addr, 0);
addr_to_bytes(addr_as_bytes, addr);
prf(params, buf + params->padding_len, addr_as_bytes, pub_seed);
prf(params, buf + params->padding_len, addr_as_bytes, pub_seed, prf_buf);

/* Generate the 2n-byte mask. */
set_key_and_mask(addr, 1);
addr_to_bytes(addr_as_bytes, addr);
prf(params, bitmask, addr_as_bytes, pub_seed);
prf(params, bitmask, addr_as_bytes, pub_seed, prf_buf);

set_key_and_mask(addr, 2);
addr_to_bytes(addr_as_bytes, addr);
prf(params, bitmask + params->n, addr_as_bytes, pub_seed);
prf(params, bitmask + params->n, addr_as_bytes, pub_seed, prf_buf);

for (i = 0; i < 2 * params->n; i++) {
buf[params->padding_len + params->n + i] = in[i] ^ bitmask[i];
}
int ret = core_hash(params, out, buf, params->padding_len + 3 * params->n);

OQS_MEM_insecure_free(tmp);

return ret;
}

int thash_f(const xmss_params *params,
unsigned char *out, const unsigned char *in,
const unsigned char *pub_seed, uint32_t addr[8])
const unsigned char *pub_seed, uint32_t addr[8],
unsigned char *buf)
{
unsigned char *tmp = malloc(params->padding_len + 2 * params->n + params->n);
unsigned char *buf = tmp;
unsigned char *bitmask = tmp + (params->padding_len + 2 * params->n);
unsigned char *bitmask = buf + (params->padding_len + 2 * params->n);
unsigned char *prf_buf = bitmask + params->n;

unsigned char addr_as_bytes[32];
unsigned int i;
Expand All @@ -146,19 +138,17 @@ int thash_f(const xmss_params *params,
/* Generate the n-byte key. */
set_key_and_mask(addr, 0);
addr_to_bytes(addr_as_bytes, addr);
prf(params, buf + params->padding_len, addr_as_bytes, pub_seed);
prf(params, buf + params->padding_len, addr_as_bytes, pub_seed, prf_buf);

/* Generate the n-byte mask. */
set_key_and_mask(addr, 1);
addr_to_bytes(addr_as_bytes, addr);
prf(params, bitmask, addr_as_bytes, pub_seed);
prf(params, bitmask, addr_as_bytes, pub_seed, prf_buf);

for (i = 0; i < params->n; i++) {
buf[params->padding_len + params->n + i] = in[i] ^ bitmask[i];
}
int ret = core_hash(params, out, buf, params->padding_len + 2 * params->n);

OQS_MEM_insecure_free(tmp);

return ret;
}
12 changes: 8 additions & 4 deletions src/sig_stfl/xmss/external/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ void addr_to_bytes(unsigned char *bytes, const uint32_t addr[8]);
#define prf XMSS_INNER_NAMESPACE(prf)
int prf(const xmss_params *params,
unsigned char *out, const unsigned char in[32],
const unsigned char *key);
const unsigned char *key,
unsigned char *buf);

#define prf_keygen XMSS_INNER_NAMESPACE(prf_keygen)
int prf_keygen(const xmss_params *params,
unsigned char *out, const unsigned char *in,
const unsigned char *key);
const unsigned char *key,
unsigned char *buf);

#define h_msg XMSS_INNER_NAMESPACE(h_msg)
int h_msg(const xmss_params *params,
Expand All @@ -28,12 +30,14 @@ int h_msg(const xmss_params *params,
#define thash_h XMSS_INNER_NAMESPACE(thash_h)
int thash_h(const xmss_params *params,
unsigned char *out, const unsigned char *in,
const unsigned char *pub_seed, uint32_t addr[8]);
const unsigned char *pub_seed, uint32_t addr[8],
unsigned char *buf);

#define thash_f XMSS_INNER_NAMESPACE(thash_f)
int thash_f(const xmss_params *params,
unsigned char *out, const unsigned char *in,
const unsigned char *pub_seed, uint32_t addr[8]);
const unsigned char *pub_seed, uint32_t addr[8],
unsigned char *buf);

#define hash_message XMSS_INNER_NAMESPACE(hash_message)
int hash_message(const xmss_params *params, unsigned char *out,
Expand Down
4 changes: 2 additions & 2 deletions src/sig_stfl/xmss/external/params.c
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ int xmss_parse_oid(xmss_params *params, const uint32_t oid)
params->d = 1;
params->wots_w = 16;

// TODO figure out sensible and legal values for this based on the above
// TODO (from upstream) figure out sensible and legal values for this based on the above
params->bds_k = 0;

return xmss_xmssmt_initialize_params(params);
Expand Down Expand Up @@ -692,7 +692,7 @@ int xmssmt_parse_oid(xmss_params *params, const uint32_t oid)

params->wots_w = 16;

// TODO figure out sensible and legal values for this based on the above
// TODO (from upstream) figure out sensible and legal values for this based on the above
params->bds_k = 0;

return xmss_xmssmt_initialize_params(params);
Expand Down
46 changes: 33 additions & 13 deletions src/sig_stfl/xmss/external/wots.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,20 @@
*/
static void expand_seed(const xmss_params *params,
unsigned char *outseeds, const unsigned char *inseed,
const unsigned char *pub_seed, uint32_t addr[8])
const unsigned char *pub_seed, uint32_t addr[8],
unsigned char *buf)
{
unsigned int i;
unsigned char *buf = malloc(params->n + 32);
unsigned char *prf_buf = buf + params->n + 32;

set_hash_addr(addr, 0);
set_key_and_mask(addr, 0);
memcpy(buf, pub_seed, params->n);
for (i = 0; i < params->wots_len; i++) {
set_chain_addr(addr, i);
addr_to_bytes(buf + params->n, addr);
prf_keygen(params, outseeds + i*params->n, buf, inseed);
prf_keygen(params, outseeds + i*params->n, buf, inseed, prf_buf);
}

OQS_MEM_insecure_free(buf);
}

/**
Expand All @@ -41,7 +40,8 @@ static void expand_seed(const xmss_params *params,
static void gen_chain(const xmss_params *params,
unsigned char *out, const unsigned char *in,
unsigned int start, unsigned int steps,
const unsigned char *pub_seed, uint32_t addr[8])
const unsigned char *pub_seed, uint32_t addr[8],
unsigned char *thash_buf)
{
unsigned int i;

Expand All @@ -51,7 +51,7 @@ static void gen_chain(const xmss_params *params,
/* Iterate 'steps' calls to the hash function. */
for (i = start; i < (start+steps) && i < params->wots_w; i++) {
set_hash_addr(addr, i);
thash_f(params, out, out, pub_seed, addr);
thash_f(params, out, out, pub_seed, addr, thash_buf);
}
}

Expand Down Expand Up @@ -88,6 +88,9 @@ static void wots_checksum(const xmss_params *params,
int csum = 0;
unsigned int csum_bytes_length = (params->wots_len2 * params->wots_log_w + 7) / 8;
unsigned char *csum_bytes = malloc(csum_bytes_length);
if (csum_bytes == NULL) {
return;
}
unsigned int i;

/* Compute checksum. */
Expand Down Expand Up @@ -125,15 +128,21 @@ void wots_pkgen(const xmss_params *params,
const unsigned char *pub_seed, uint32_t addr[8])
{
unsigned int i;

unsigned char *buf = malloc(2 * params->padding_len + 4 * params->n + 64);
if (buf == NULL) {
return;
}

/* The WOTS+ private key is derived from the seed. */
expand_seed(params, pk, seed, pub_seed, addr);
expand_seed(params, pk, seed, pub_seed, addr, buf);

for (i = 0; i < params->wots_len; i++) {
set_chain_addr(addr, i);
gen_chain(params, pk + i*params->n, pk + i*params->n,
0, params->wots_w - 1, pub_seed, addr);
0, params->wots_w - 1, pub_seed, addr, buf);
}

OQS_MEM_insecure_free(buf);
}

/**
Expand All @@ -146,20 +155,25 @@ void wots_sign(const xmss_params *params,
uint32_t addr[8])
{
unsigned int *lengths = calloc(params->wots_len, sizeof(unsigned int));
unsigned char *buf = malloc(2 * params->padding_len + 4 * params->n + 64);
unsigned int i;
if (lengths == NULL || buf == NULL) {
return;
}

chain_lengths(params, lengths, msg);

/* The WOTS+ private key is derived from the seed. */
expand_seed(params, sig, seed, pub_seed, addr);
expand_seed(params, sig, seed, pub_seed, addr, buf);

for (i = 0; i < params->wots_len; i++) {
set_chain_addr(addr, i);
gen_chain(params, sig + i*params->n, sig + i*params->n,
0, lengths[i], pub_seed, addr);
0, lengths[i], pub_seed, addr, buf);
}

OQS_MEM_insecure_free(lengths);
OQS_MEM_insecure_free(buf);
}

/**
Expand All @@ -172,15 +186,21 @@ void wots_pk_from_sig(const xmss_params *params, unsigned char *pk,
const unsigned char *pub_seed, uint32_t addr[8])
{
unsigned int *lengths = calloc(params->wots_len, sizeof(unsigned int ));
const size_t thash_buf_len = 2 * params->padding_len + 4 * params->n + 32;
unsigned char *thash_buf = malloc(thash_buf_len);
unsigned int i;
if (lengths == NULL || thash_buf == NULL) {
return;
}

chain_lengths(params, lengths, msg);

for (i = 0; i < params->wots_len; i++) {
set_chain_addr(addr, i);
gen_chain(params, pk + i*params->n, sig + i*params->n,
lengths[i], params->wots_w - 1 - lengths[i], pub_seed, addr);
lengths[i], params->wots_w - 1 - lengths[i], pub_seed, addr, thash_buf);
}

OQS_MEM_insecure_free(lengths);
OQS_MEM_insecure_free(thash_buf);
}
Loading

0 comments on commit ba63672

Please sign in to comment.