cryptopals_c

Solutions to the cryptopals crypto challenges, written in pure C
git clone git://git.superpozycja.net/cryptopals_c
Log | Files | Refs | README

aes.c (9942B)


      1 /* adapted from https://csrc.nist.gov/pubs/fips/197/final */
      2 
      3 #include "aes.h"
      4 
      5 static uint8_t sbox[0x100] = {
      6 	0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
      7 	0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
      8 	0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
      9 	0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
     10 	0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
     11 	0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
     12 	0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
     13 	0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
     14 	0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
     15 	0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
     16 	0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
     17 	0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
     18 	0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
     19 	0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
     20 	0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
     21 	0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
     22 };
     23 
     24 static uint8_t inv_sbox[0x100] = {
     25 	0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
     26 	0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
     27 	0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
     28 	0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
     29 	0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
     30 	0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
     31 	0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
     32 	0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
     33 	0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
     34 	0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
     35 	0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
     36 	0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
     37 	0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
     38 	0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
     39 	0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
     40 	0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d
     41 };
     42 
     43 static int rot_word(ba *input, ba **output, unsigned int offset) {
     44 	int i;
     45 
     46 	for (i = 0; i < 4; i++) {
     47 		(*output)->val[i] = input->val[(i+offset)%4];
     48 	}
     49 }
     50 
     51 static int sub_word(ba *input, ba **output) {
     52 	int i;
     53 
     54 	for (i = 0; i < 4; i++) {
     55 		(*output)->val[i] = sbox[input->val[i]];
     56 	}
     57 }
     58 
     59 static int key_expansion(ba *key, unsigned int keylen, unsigned int rounds,
     60 			 ba *round_keys[4*(rounds + 1)])
     61 {
     62 	/* number of 4-byte words, help to implement spec */
     63 	const int nk = keylen / 4;
     64 	int i;
     65 
     66 	ba *rcon[10] = {
     67 		ba_from_hex("01000000"),
     68 		ba_from_hex("02000000"),
     69 		ba_from_hex("04000000"),
     70 		ba_from_hex("08000000"),
     71 		ba_from_hex("10000000"),
     72 		ba_from_hex("20000000"),
     73 		ba_from_hex("40000000"),
     74 		ba_from_hex("80000000"),
     75 		ba_from_hex("1b000000"),
     76 		ba_from_hex("36000000")
     77 	};
     78 
     79 
     80 	for (i = 0; i < nk; i++) {
     81 		int j;
     82 
     83 		round_keys[i] = ba_alloc(4);
     84 		for (j = 0; j < 4; j++)
     85 			(round_keys[i])->val[j] = key->val[i*4+j];
     86 	}
     87 
     88 	for (i = nk; i < 4 * (rounds + 1); i++) {
     89 		ba *tmp = ba_alloc(4);
     90 		ba *tmp2 = ba_alloc(4);
     91 
     92 		round_keys[i] = ba_alloc(4);
     93 		ba_copy(tmp, round_keys[i - 1]);
     94 		ba_copy(tmp2, round_keys[i - nk]);
     95 
     96 		if (i % nk == 0) {
     97 			ba *tmp3 = ba_alloc(4);
     98 
     99 			rot_word(tmp, &tmp3, 1);
    100 			sub_word(tmp3, &tmp);
    101 			ba_xor(tmp, rcon[(i-1) / nk]);
    102 			ba_free(tmp3);
    103 		} else if (nk > 6 && i % nk == 4) {
    104 			ba *tmp3 = ba_alloc(4);
    105 
    106 			ba_copy(tmp3, tmp);
    107 			sub_word(tmp3, &tmp);
    108 			ba_free(tmp3);
    109 		}
    110 
    111 		ba_xor(tmp, tmp2);
    112 		ba_copy(round_keys[i], tmp);
    113 		ba_free(tmp);
    114 		ba_free(tmp2);
    115 	}
    116 
    117 	for (int i = 0; i < 10; i++)
    118 		ba_free(rcon[i]);
    119 
    120 	return 0;
    121 }
    122 
    123 static uint8_t galois_mul(uint8_t a, uint8_t b)
    124 {
    125 	uint8_t p;
    126 	int i;
    127 
    128 	p = 0;
    129 	for (i = 0; i < 8; i++) {
    130 		uint8_t h;
    131 
    132 		if (b & 1)
    133 			p ^= a;
    134 		h = a & 0x80;
    135 		a <<= 1;
    136 		if (h)
    137 			a ^= 0x1b;
    138 		b >>= 1;
    139 	}
    140 
    141 	return p;
    142 }
    143 
    144 static int add_round_key(ba *state[4], ba *round_keys[4])
    145 {
    146 	int i;
    147 
    148 	for (i = 0; i < 4; i++) {
    149 		int j;
    150 
    151 		for (j = 0; j < 4; j++)
    152 			state[i]->val[j] ^= round_keys[j]->val[i];
    153 	}
    154 
    155 	return 0;
    156 }
    157 
    158 static int sub_bytes(ba *state[4])
    159 {
    160 	int i;
    161 	int j;
    162 
    163 	for (i = 0; i < 4; i++)
    164 		for (j = 0; j < 4; j++)
    165 			state[i]->val[j] = sbox[state[i]->val[j]];
    166 
    167 	return 0;
    168 }
    169 
    170 static int shift_rows(ba *state[4])
    171 {
    172 	int i;
    173 
    174 	for (i = 1; i < 4; i++) {
    175 		ba *tmp;
    176 
    177 		tmp = ba_alloc(4);
    178 		ba_copy(tmp, state[i]);
    179 		rot_word(tmp, &state[i], i);
    180 		ba_free(tmp);
    181 	}
    182 }
    183 
    184 static int mix_columns(ba *state[4])
    185 {
    186 	ba *tmp[4];
    187 	int c;
    188 	int i;
    189 
    190 	for (i = 0; i < 4; i++) {
    191 		tmp[i] = ba_alloc(4);
    192 		ba_copy(tmp[i], state[i]);
    193 	}
    194 
    195 	for (c = 0; c < 4; c++) {
    196 		int j;
    197 
    198 		for (j = 0; j < 4; j++) {
    199 			state[j]->val[c] = galois_mul(tmp[(j)%4]->val[c], 2)
    200 			                 ^ galois_mul(tmp[(j+1)%4]->val[c], 3)
    201 					 ^ tmp[(j+2)%4]->val[c]
    202 					 ^ tmp[(j+3)%4]->val[c];
    203 
    204 		}
    205 	}
    206 
    207 	for (i = 0; i < 4; i++)
    208 		ba_free(tmp[i]);
    209 
    210 	return 0;
    211 }
    212 
    213 static int inv_sub_bytes(ba *state[4])
    214 {
    215 	int i;
    216 	int j;
    217 
    218 	for (i = 0; i < 4; i++)
    219 		for (j = 0; j < 4; j++)
    220 			state[i]->val[j] = inv_sbox[state[i]->val[j]];
    221 
    222 	return 0;
    223 }
    224 
    225 static int inv_shift_rows(ba *state[4])
    226 {
    227 	int i;
    228 
    229 	for (i = 1; i < 4; i++) {
    230 		ba *tmp;
    231 
    232 		tmp = ba_alloc(4);
    233 		ba_copy(tmp, state[i]);
    234 		rot_word(tmp, &state[i], 4-i);
    235 		ba_free(tmp);
    236 	}
    237 }
    238 
    239 static int inv_mix_columns(ba *state[4])
    240 {
    241 	ba *tmp[4];
    242 	int c;
    243 	int i;
    244 
    245 	for (i = 0; i < 4; i++) {
    246 		tmp[i] = ba_alloc(4);
    247 		ba_copy(tmp[i], state[i]);
    248 	}
    249 
    250 	for (c = 0; c < 4; c++) {
    251 		int j;
    252 
    253 		for (j = 0; j < 4; j++) {
    254 			state[j]->val[c] = galois_mul(tmp[(j)%4]->val[c], 0x0e)
    255 			                 ^ galois_mul(tmp[(j+1)%4]->val[c], 0x0b)
    256 					 ^ galois_mul(tmp[(j+2)%4]->val[c], 0x0d)
    257 					 ^ galois_mul(tmp[(j+3)%4]->val[c], 0x9);
    258 
    259 		}
    260 	}
    261 
    262 	for (i = 0; i < 4; i++)
    263 		ba_free(tmp[i]);
    264 
    265 	return 0;
    266 }
    267 
    268 static void print_state(ba *state[4])
    269 {
    270 	int i;
    271 
    272 	for (i = 0; i < 4; i++) {
    273 		ba_fprint(state[i], stdout, 0);
    274 		printf("\n");
    275 	}
    276 
    277 	printf("\n");
    278 }
    279 
    280 static int aes_generic(unsigned int rounds, unsigned int keylen,
    281 		       ba *plaintext, ba *key, ba **ciphertext)
    282 {
    283 	ba *round_keys[4*(rounds + 1)];
    284 	ba *state[4] = {
    285 		ba_alloc(4),
    286 		ba_alloc(4),
    287 		ba_alloc(4),
    288 		ba_alloc(4),
    289 	};
    290 	int i;
    291 
    292 	if (plaintext->len != 16) {
    293 		printf("invalid block len\n");
    294 		return -EINVAL;
    295 	}
    296 	if (key->len != keylen) {
    297 		printf("invalid keylen\n");
    298 		return -EINVAL;
    299 	}
    300 
    301 	for (i = 0; i < 16; i++)
    302 		(state[i%4]->val)[i/4] = plaintext->val[i];
    303 
    304 	key_expansion(key, keylen, rounds, round_keys);
    305 
    306 	add_round_key(state, round_keys);
    307 
    308 	for (i = 1; i < rounds; i++) {
    309 		sub_bytes(state);
    310 		shift_rows(state);
    311 		mix_columns(state);
    312 		add_round_key(state, round_keys + (4 * i));
    313 	}
    314 
    315 	sub_bytes(state);
    316 	shift_rows(state);
    317 
    318 	add_round_key(state, round_keys + (4 * rounds));
    319 
    320 	*ciphertext = ba_alloc(16);
    321 
    322 	for (i = 0; i < 16; i++)
    323 		 (*ciphertext)->val[i] = (state[i%4]->val)[i/4];
    324 
    325 	for (i = 0; i < 4; i++)
    326 		ba_free(state[i]);
    327 
    328 	for (i = 0; i < 4 * (rounds + 1); i++)
    329 		ba_free(round_keys[i]);
    330 
    331 	return 0;
    332 }
    333 
    334 static int aes_inv_generic(unsigned int rounds, unsigned int keylen,
    335                            ba *ciphertext, ba *key, ba **plaintext)
    336 {
    337 	ba *round_keys[4*(rounds + 1)];
    338 	ba *state[4] = {
    339 		ba_alloc(4),
    340 		ba_alloc(4),
    341 		ba_alloc(4),
    342 		ba_alloc(4),
    343 	};
    344 	int i;
    345 
    346 	if (ciphertext->len != 16) {
    347 		printf("invalid block len\n");
    348 		return -EINVAL;
    349 	}
    350 	if (key->len != keylen) {
    351 		printf("invalid keylen\n");
    352 		return -EINVAL;
    353 	}
    354 
    355 	for (i = 0; i < 16; i++)
    356 		(state[i%4]->val)[i/4] = ciphertext->val[i];
    357 
    358 	key_expansion(key, keylen, rounds, round_keys);
    359 
    360 	add_round_key(state, round_keys + (4 * rounds));
    361 
    362 	for (i = rounds - 1; i > 0; i--) {
    363 		inv_shift_rows(state);
    364 		inv_sub_bytes(state);
    365 		add_round_key(state, round_keys + (4 * i));
    366 		inv_mix_columns(state);
    367 	}
    368 
    369 	inv_shift_rows(state);
    370 	inv_sub_bytes(state);
    371 
    372 	add_round_key(state, round_keys);
    373 
    374 	*plaintext = ba_alloc(16);
    375 
    376 	for (i = 0; i < 16; i++)
    377 		 (*plaintext)->val[i] = (state[i%4]->val)[i/4];
    378 
    379 	for (i = 0; i < 4; i++)
    380 		ba_free(state[i]);
    381 
    382 	for (i = 0; i < 4 * (rounds + 1); i++)
    383 		ba_free(round_keys[i]);
    384 
    385 	return 0;
    386 }
    387 
    388 int aes_128_encrypt(ba *plaintext, ba *key, ba **ciphertext)
    389 {
    390 	return aes_generic(10, 16, plaintext, key, ciphertext);
    391 }
    392 
    393 int aes_128_decrypt(ba *ciphertext, ba *key, ba **plaintext)
    394 {
    395 	return aes_inv_generic(10, 16, ciphertext, key, plaintext);
    396 }
    397 
    398 int aes_192_encrypt(ba *plaintext, ba *key, ba **ciphertext)
    399 {
    400 	return aes_generic(12, 24, plaintext, key, ciphertext);
    401 }
    402 
    403 int aes_192_decrypt(ba *ciphertext, ba *key, ba **plaintext)
    404 {
    405 	return aes_inv_generic(12, 24, ciphertext, key, plaintext);
    406 }
    407 
    408 int aes_256_encrypt(ba *plaintext, ba *key, ba **ciphertext)
    409 {
    410 	return aes_generic(14, 32, plaintext, key, ciphertext);
    411 }
    412 
    413 int aes_256_decrypt(ba *ciphertext, ba *key, ba **plaintext)
    414 {
    415 	return aes_inv_generic(14, 32, ciphertext, key, plaintext);
    416 }