commit a85896c83199ceb1e6ee6901b5e2d736a3d974a3
parent 256fafd82571a2a2f4f1203b03425c805910b5ea
Author: superpozycja <anna@superpozycja.net>
Date: Tue, 25 Feb 2025 18:53:58 +0100
lib/aes: add aes decryption
Diffstat:
M | lib/aes.c | | | 187 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------- |
M | lib/aes.h | | | 11 | ++++++++--- |
2 files changed, 149 insertions(+), 49 deletions(-)
diff --git a/lib/aes.c b/lib/aes.c
@@ -21,7 +21,7 @@ static uint8_t sbox[0x100] = {
0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
};
-static uint8_t sbox_inv[0x100] = {
+static uint8_t inv_sbox[0x100] = {
0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
@@ -119,6 +119,27 @@ static int key_expansion(ba *key, unsigned int keylen, unsigned int rounds,
return 0;
}
+static uint8_t galois_mul(uint8_t a, uint8_t b)
+{
+ uint8_t p;
+ int i;
+
+ p = 0;
+ for (i = 0; i < 8; i++) {
+ uint8_t h;
+
+ if (b & 1)
+ p ^= a;
+ h = a & 0x80;
+ a <<= 1;
+ if (h)
+ a ^= 0x1b;
+ b >>= 1;
+ }
+
+ return p;
+}
+
static int add_round_key(ba *state[4], ba *round_keys[4])
{
int i;
@@ -159,28 +180,62 @@ static int shift_rows(ba *state[4])
}
}
-static uint8_t galois_mul(uint8_t a, uint8_t b)
+static int mix_columns(ba *state[4])
{
- uint8_t p;
+ ba *tmp[4];
+ int c;
int i;
- p = 0;
- for (i = 0; i < 8; i++) {
- uint8_t h;
+ for (i = 0; i < 4; i++) {
+ tmp[i] = ba_alloc(4);
+ ba_copy(tmp[i], state[i]);
+ }
- if (b & 1)
- p ^= a;
- h = a & 0x80;
- a <<= 1;
- if (h)
- a ^= 0x1b;
- b >>= 1;
+ for (c = 0; c < 4; c++) {
+ int j;
+
+ for (j = 0; j < 4; j++) {
+ state[j]->val[c] = galois_mul(tmp[(j)%4]->val[c], 2)
+ ^ galois_mul(tmp[(j+1)%4]->val[c], 3)
+ ^ tmp[(j+2)%4]->val[c]
+ ^ tmp[(j+3)%4]->val[c];
+
+ }
}
- return p;
+ for (i = 0; i < 4; i++)
+ ba_free(tmp[i]);
+
+ return 0;
}
-static int mix_columns(ba *state[4])
+static int inv_sub_bytes(ba *state[4])
+{
+ int i;
+ int j;
+
+ for (i = 0; i < 4; i++)
+ for (j = 0; j < 4; j++)
+ state[i]->val[j] = inv_sbox[state[i]->val[j]];
+
+ return 0;
+}
+
+static int inv_shift_rows(ba *state[4])
+{
+ int i;
+
+ for (i = 1; i < 4; i++) {
+ ba *tmp;
+
+ tmp = ba_alloc(4);
+ ba_copy(tmp, state[i]);
+ rot_word(tmp, &state[i], 4-i);
+ ba_free(tmp);
+ }
+}
+
+static int inv_mix_columns(ba *state[4])
{
ba *tmp[4];
int c;
@@ -195,10 +250,10 @@ static int mix_columns(ba *state[4])
int j;
for (j = 0; j < 4; j++) {
- state[j]->val[c] = galois_mul(tmp[(j)%4]->val[c], 2)
- ^ galois_mul(tmp[(j+1)%4]->val[c], 3)
- ^ tmp[(j+2)%4]->val[c]
- ^ tmp[(j+3)%4]->val[c];
+ state[j]->val[c] = galois_mul(tmp[(j)%4]->val[c], 0x0e)
+ ^ galois_mul(tmp[(j+1)%4]->val[c], 0x0b)
+ ^ galois_mul(tmp[(j+2)%4]->val[c], 0x0d)
+ ^ galois_mul(tmp[(j+3)%4]->val[c], 0x9);
}
}
@@ -222,7 +277,7 @@ static void print_state(ba *state[4])
}
static int aes_generic(unsigned int rounds, unsigned int keylen,
- ba *plaintext, ba *key, ba* ciphertext)
+ ba *plaintext, ba *key, ba **ciphertext)
{
ba *round_keys[4*(rounds + 1)];
ba *state[4] = {
@@ -248,61 +303,101 @@ static int aes_generic(unsigned int rounds, unsigned int keylen,
key_expansion(key, keylen, rounds, round_keys);
add_round_key(state, round_keys);
- print_state(state);
for (i = 1; i < rounds; i++) {
sub_bytes(state);
- print_state(state);
shift_rows(state);
- print_state(state);
mix_columns(state);
- print_state(state);
add_round_key(state, round_keys + (4 * i));
- print_state(state);
}
sub_bytes(state);
- print_state(state);
shift_rows(state);
- print_state(state);
add_round_key(state, round_keys + (4 * rounds));
- print_state(state);
- /*
- sub_bytes
- shift_rows
- mix_columns
- add_round_key
- sub_bytes
- shift_rows
- add_round_key(state, round_keys[rounds]);
- */
-
- ciphertext = ba_alloc(16);
+
+ *ciphertext = ba_alloc(16);
for (i = 0; i < 16; i++)
- ciphertext->val[i] = (state[i/4]->val)[i%4];
+ (*ciphertext)->val[i] = (state[i%4]->val)[i/4];
+ return 0;
+}
- printf("%x\n", galois_mul(0xd4, 2) ^ galois_mul(0xbf, 3) ^ 0x5d ^ 0x30);
- ba_fprint(ciphertext, stdout, 0);
- printf("\n");
+static int aes_inv_generic(unsigned int rounds, unsigned int keylen,
+ ba *ciphertext, ba *key, ba **plaintext)
+{
+ ba *round_keys[4*(rounds + 1)];
+ ba *state[4] = {
+ ba_alloc(4),
+ ba_alloc(4),
+ ba_alloc(4),
+ ba_alloc(4),
+ };
+ int i;
+
+ if (ciphertext->len != 16) {
+ printf("invalid block len\n");
+ return -EINVAL;
+ }
+ if (key->len != keylen) {
+ printf("invalid keylen\n");
+ return -EINVAL;
+ }
+
+ for (i = 0; i < 16; i++)
+ (state[i%4]->val)[i/4] = ciphertext->val[i];
+
+ key_expansion(key, keylen, rounds, round_keys);
+
+ add_round_key(state, round_keys + (4 * rounds));
+
+ for (i = rounds - 1; i > 0; i--) {
+ inv_shift_rows(state);
+ inv_sub_bytes(state);
+ add_round_key(state, round_keys + (4 * i));
+ inv_mix_columns(state);
+ }
+
+ inv_shift_rows(state);
+ inv_sub_bytes(state);
+
+ add_round_key(state, round_keys);
+
+ *plaintext = ba_alloc(16);
+
+ for (i = 0; i < 16; i++)
+ (*plaintext)->val[i] = (state[i%4]->val)[i/4];
return 0;
}
-int aes_128_encrypt(ba *plaintext, ba *key, ba *ciphertext)
+int aes_128_encrypt(ba *plaintext, ba *key, ba **ciphertext)
{
return aes_generic(10, 16, plaintext, key, ciphertext);
}
-int aes_192_encrypt(ba *plaintext, ba *key, ba *ciphertext)
+int aes_128_decrypt(ba *ciphertext, ba *key, ba **plaintext)
+{
+ return aes_inv_generic(10, 16, ciphertext, key, plaintext);
+}
+
+int aes_192_encrypt(ba *plaintext, ba *key, ba **ciphertext)
{
return aes_generic(12, 24, plaintext, key, ciphertext);
}
-int aes_256_encrypt(ba *plaintext, ba *key, ba *ciphertext)
+int aes_192_decrypt(ba *ciphertext, ba *key, ba **plaintext)
+{
+ return aes_inv_generic(12, 24, ciphertext, key, plaintext);
+}
+
+int aes_256_encrypt(ba *plaintext, ba *key, ba **ciphertext)
{
return aes_generic(14, 32, plaintext, key, ciphertext);
}
+int aes_256_decrypt(ba *ciphertext, ba *key, ba **plaintext)
+{
+ return aes_inv_generic(14, 32, ciphertext, key, plaintext);
+}
diff --git a/lib/aes.h b/lib/aes.h
@@ -7,8 +7,13 @@
#include <errno.h>
#include "ba.h"
-int aes_128_encrypt(ba *plaintext, ba *key, ba *ciphertext);
-int aes_192_encrypt(ba *plaintext, ba *key, ba *ciphertext);
-int aes_256_encrypt(ba *plaintext, ba *key, ba *ciphertext);
+int aes_128_encrypt(ba *plaintext, ba *key, ba **ciphertext);
+int aes_128_decrypt(ba *ciphertext, ba *key, ba **plaintext);
+
+int aes_192_encrypt(ba *plaintext, ba *key, ba **ciphertext);
+int aes_192_decrypt(ba *ciphertext, ba *key, ba **plaintext);
+
+int aes_256_encrypt(ba *plaintext, ba *key, ba **ciphertext);
+int aes_256_decrypt(ba *ciphertext, ba *key, ba **plaintext);
#endif