aes.c (9942B)
1 /* adapted from https://csrc.nist.gov/pubs/fips/197/final */ 2 3 #include "aes.h" 4 5 static uint8_t sbox[0x100] = { 6 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 7 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 8 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 9 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 10 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 11 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 12 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 13 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 14 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 15 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 16 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 17 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 18 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 19 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 20 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 21 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 22 }; 23 24 static uint8_t inv_sbox[0x100] = { 25 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, 26 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, 27 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, 28 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, 29 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, 30 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, 31 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 32 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 33 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, 34 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, 35 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 36 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, 37 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, 38 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 39 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, 40 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d 41 }; 42 43 static int rot_word(ba *input, ba **output, unsigned int offset) { 44 int i; 45 46 for (i = 0; i < 4; i++) { 47 (*output)->val[i] = input->val[(i+offset)%4]; 48 } 49 } 50 51 static int sub_word(ba *input, ba **output) { 52 int i; 53 54 for (i = 0; i < 4; i++) { 55 (*output)->val[i] = sbox[input->val[i]]; 56 } 57 } 58 59 static int key_expansion(ba *key, unsigned int keylen, unsigned int rounds, 60 ba *round_keys[4*(rounds + 1)]) 61 { 62 /* number of 4-byte words, help to implement spec */ 63 const int nk = keylen / 4; 64 int i; 65 66 ba *rcon[10] = { 67 ba_from_hex("01000000"), 68 ba_from_hex("02000000"), 69 ba_from_hex("04000000"), 70 ba_from_hex("08000000"), 71 ba_from_hex("10000000"), 72 ba_from_hex("20000000"), 73 ba_from_hex("40000000"), 74 ba_from_hex("80000000"), 75 ba_from_hex("1b000000"), 76 ba_from_hex("36000000") 77 }; 78 79 80 for (i = 0; i < nk; i++) { 81 int j; 82 83 round_keys[i] = ba_alloc(4); 84 for (j = 0; j < 4; j++) 85 (round_keys[i])->val[j] = key->val[i*4+j]; 86 } 87 88 for (i = nk; i < 4 * (rounds + 1); i++) { 89 ba *tmp = ba_alloc(4); 90 ba *tmp2 = ba_alloc(4); 91 92 round_keys[i] = ba_alloc(4); 93 ba_copy(tmp, round_keys[i - 1]); 94 ba_copy(tmp2, round_keys[i - nk]); 95 96 if (i % nk == 0) { 97 ba *tmp3 = ba_alloc(4); 98 99 rot_word(tmp, &tmp3, 1); 100 sub_word(tmp3, &tmp); 101 ba_xor(tmp, rcon[(i-1) / nk]); 102 ba_free(tmp3); 103 } else if (nk > 6 && i % nk == 4) { 104 ba *tmp3 = ba_alloc(4); 105 106 ba_copy(tmp3, tmp); 107 sub_word(tmp3, &tmp); 108 ba_free(tmp3); 109 } 110 111 ba_xor(tmp, tmp2); 112 ba_copy(round_keys[i], tmp); 113 ba_free(tmp); 114 ba_free(tmp2); 115 } 116 117 for (int i = 0; i < 10; i++) 118 ba_free(rcon[i]); 119 120 return 0; 121 } 122 123 static uint8_t galois_mul(uint8_t a, uint8_t b) 124 { 125 uint8_t p; 126 int i; 127 128 p = 0; 129 for (i = 0; i < 8; i++) { 130 uint8_t h; 131 132 if (b & 1) 133 p ^= a; 134 h = a & 0x80; 135 a <<= 1; 136 if (h) 137 a ^= 0x1b; 138 b >>= 1; 139 } 140 141 return p; 142 } 143 144 static int add_round_key(ba *state[4], ba *round_keys[4]) 145 { 146 int i; 147 148 for (i = 0; i < 4; i++) { 149 int j; 150 151 for (j = 0; j < 4; j++) 152 state[i]->val[j] ^= round_keys[j]->val[i]; 153 } 154 155 return 0; 156 } 157 158 static int sub_bytes(ba *state[4]) 159 { 160 int i; 161 int j; 162 163 for (i = 0; i < 4; i++) 164 for (j = 0; j < 4; j++) 165 state[i]->val[j] = sbox[state[i]->val[j]]; 166 167 return 0; 168 } 169 170 static int shift_rows(ba *state[4]) 171 { 172 int i; 173 174 for (i = 1; i < 4; i++) { 175 ba *tmp; 176 177 tmp = ba_alloc(4); 178 ba_copy(tmp, state[i]); 179 rot_word(tmp, &state[i], i); 180 ba_free(tmp); 181 } 182 } 183 184 static int mix_columns(ba *state[4]) 185 { 186 ba *tmp[4]; 187 int c; 188 int i; 189 190 for (i = 0; i < 4; i++) { 191 tmp[i] = ba_alloc(4); 192 ba_copy(tmp[i], state[i]); 193 } 194 195 for (c = 0; c < 4; c++) { 196 int j; 197 198 for (j = 0; j < 4; j++) { 199 state[j]->val[c] = galois_mul(tmp[(j)%4]->val[c], 2) 200 ^ galois_mul(tmp[(j+1)%4]->val[c], 3) 201 ^ tmp[(j+2)%4]->val[c] 202 ^ tmp[(j+3)%4]->val[c]; 203 204 } 205 } 206 207 for (i = 0; i < 4; i++) 208 ba_free(tmp[i]); 209 210 return 0; 211 } 212 213 static int inv_sub_bytes(ba *state[4]) 214 { 215 int i; 216 int j; 217 218 for (i = 0; i < 4; i++) 219 for (j = 0; j < 4; j++) 220 state[i]->val[j] = inv_sbox[state[i]->val[j]]; 221 222 return 0; 223 } 224 225 static int inv_shift_rows(ba *state[4]) 226 { 227 int i; 228 229 for (i = 1; i < 4; i++) { 230 ba *tmp; 231 232 tmp = ba_alloc(4); 233 ba_copy(tmp, state[i]); 234 rot_word(tmp, &state[i], 4-i); 235 ba_free(tmp); 236 } 237 } 238 239 static int inv_mix_columns(ba *state[4]) 240 { 241 ba *tmp[4]; 242 int c; 243 int i; 244 245 for (i = 0; i < 4; i++) { 246 tmp[i] = ba_alloc(4); 247 ba_copy(tmp[i], state[i]); 248 } 249 250 for (c = 0; c < 4; c++) { 251 int j; 252 253 for (j = 0; j < 4; j++) { 254 state[j]->val[c] = galois_mul(tmp[(j)%4]->val[c], 0x0e) 255 ^ galois_mul(tmp[(j+1)%4]->val[c], 0x0b) 256 ^ galois_mul(tmp[(j+2)%4]->val[c], 0x0d) 257 ^ galois_mul(tmp[(j+3)%4]->val[c], 0x9); 258 259 } 260 } 261 262 for (i = 0; i < 4; i++) 263 ba_free(tmp[i]); 264 265 return 0; 266 } 267 268 static void print_state(ba *state[4]) 269 { 270 int i; 271 272 for (i = 0; i < 4; i++) { 273 ba_fprint(state[i], stdout, 0); 274 printf("\n"); 275 } 276 277 printf("\n"); 278 } 279 280 static int aes_generic(unsigned int rounds, unsigned int keylen, 281 ba *plaintext, ba *key, ba **ciphertext) 282 { 283 ba *round_keys[4*(rounds + 1)]; 284 ba *state[4] = { 285 ba_alloc(4), 286 ba_alloc(4), 287 ba_alloc(4), 288 ba_alloc(4), 289 }; 290 int i; 291 292 if (plaintext->len != 16) { 293 printf("invalid block len\n"); 294 return -EINVAL; 295 } 296 if (key->len != keylen) { 297 printf("invalid keylen\n"); 298 return -EINVAL; 299 } 300 301 for (i = 0; i < 16; i++) 302 (state[i%4]->val)[i/4] = plaintext->val[i]; 303 304 key_expansion(key, keylen, rounds, round_keys); 305 306 add_round_key(state, round_keys); 307 308 for (i = 1; i < rounds; i++) { 309 sub_bytes(state); 310 shift_rows(state); 311 mix_columns(state); 312 add_round_key(state, round_keys + (4 * i)); 313 } 314 315 sub_bytes(state); 316 shift_rows(state); 317 318 add_round_key(state, round_keys + (4 * rounds)); 319 320 *ciphertext = ba_alloc(16); 321 322 for (i = 0; i < 16; i++) 323 (*ciphertext)->val[i] = (state[i%4]->val)[i/4]; 324 325 for (i = 0; i < 4; i++) 326 ba_free(state[i]); 327 328 for (i = 0; i < 4 * (rounds + 1); i++) 329 ba_free(round_keys[i]); 330 331 return 0; 332 } 333 334 static int aes_inv_generic(unsigned int rounds, unsigned int keylen, 335 ba *ciphertext, ba *key, ba **plaintext) 336 { 337 ba *round_keys[4*(rounds + 1)]; 338 ba *state[4] = { 339 ba_alloc(4), 340 ba_alloc(4), 341 ba_alloc(4), 342 ba_alloc(4), 343 }; 344 int i; 345 346 if (ciphertext->len != 16) { 347 printf("invalid block len\n"); 348 return -EINVAL; 349 } 350 if (key->len != keylen) { 351 printf("invalid keylen\n"); 352 return -EINVAL; 353 } 354 355 for (i = 0; i < 16; i++) 356 (state[i%4]->val)[i/4] = ciphertext->val[i]; 357 358 key_expansion(key, keylen, rounds, round_keys); 359 360 add_round_key(state, round_keys + (4 * rounds)); 361 362 for (i = rounds - 1; i > 0; i--) { 363 inv_shift_rows(state); 364 inv_sub_bytes(state); 365 add_round_key(state, round_keys + (4 * i)); 366 inv_mix_columns(state); 367 } 368 369 inv_shift_rows(state); 370 inv_sub_bytes(state); 371 372 add_round_key(state, round_keys); 373 374 *plaintext = ba_alloc(16); 375 376 for (i = 0; i < 16; i++) 377 (*plaintext)->val[i] = (state[i%4]->val)[i/4]; 378 379 for (i = 0; i < 4; i++) 380 ba_free(state[i]); 381 382 for (i = 0; i < 4 * (rounds + 1); i++) 383 ba_free(round_keys[i]); 384 385 return 0; 386 } 387 388 int aes_128_encrypt(ba *plaintext, ba *key, ba **ciphertext) 389 { 390 return aes_generic(10, 16, plaintext, key, ciphertext); 391 } 392 393 int aes_128_decrypt(ba *ciphertext, ba *key, ba **plaintext) 394 { 395 return aes_inv_generic(10, 16, ciphertext, key, plaintext); 396 } 397 398 int aes_192_encrypt(ba *plaintext, ba *key, ba **ciphertext) 399 { 400 return aes_generic(12, 24, plaintext, key, ciphertext); 401 } 402 403 int aes_192_decrypt(ba *ciphertext, ba *key, ba **plaintext) 404 { 405 return aes_inv_generic(12, 24, ciphertext, key, plaintext); 406 } 407 408 int aes_256_encrypt(ba *plaintext, ba *key, ba **ciphertext) 409 { 410 return aes_generic(14, 32, plaintext, key, ciphertext); 411 } 412 413 int aes_256_decrypt(ba *ciphertext, ba *key, ba **plaintext) 414 { 415 return aes_inv_generic(14, 32, ciphertext, key, plaintext); 416 }