|
@@ -508,50 +508,79 @@ static u8 rcon[11] = {
|
|
|
0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36
|
|
|
};
|
|
|
|
|
|
+static u32 aes_get_rounds(u32 key_len)
|
|
|
+{
|
|
|
+ u32 rounds = AES128_ROUNDS;
|
|
|
+
|
|
|
+ if (key_len == AES192_KEY_LENGTH)
|
|
|
+ rounds = AES192_ROUNDS;
|
|
|
+ else if (key_len == AES256_KEY_LENGTH)
|
|
|
+ rounds = AES256_ROUNDS;
|
|
|
+
|
|
|
+ return rounds;
|
|
|
+}
|
|
|
+
|
|
|
+static u32 aes_get_keycols(u32 key_len)
|
|
|
+{
|
|
|
+ u32 keycols = AES128_KEYCOLS;
|
|
|
+
|
|
|
+ if (key_len == AES192_KEY_LENGTH)
|
|
|
+ keycols = AES192_KEYCOLS;
|
|
|
+ else if (key_len == AES256_KEY_LENGTH)
|
|
|
+ keycols = AES256_KEYCOLS;
|
|
|
+
|
|
|
+ return keycols;
|
|
|
+}
|
|
|
+
|
|
|
/* produce AES_STATECOLS bytes for each round */
|
|
|
-void aes_expand_key(u8 *key, u8 *expkey)
|
|
|
+void aes_expand_key(u8 *key, u32 key_len, u8 *expkey)
|
|
|
{
|
|
|
u8 tmp0, tmp1, tmp2, tmp3, tmp4;
|
|
|
- u32 idx;
|
|
|
+ u32 idx, aes_rounds, aes_keycols;
|
|
|
|
|
|
- memcpy(expkey, key, AES_KEYCOLS * 4);
|
|
|
+ aes_rounds = aes_get_rounds(key_len);
|
|
|
+ aes_keycols = aes_get_keycols(key_len);
|
|
|
|
|
|
- for (idx = AES_KEYCOLS; idx < AES_STATECOLS * (AES_ROUNDS + 1); idx++) {
|
|
|
+ memcpy(expkey, key, key_len);
|
|
|
+
|
|
|
+ for (idx = aes_keycols; idx < AES_STATECOLS * (aes_rounds + 1); idx++) {
|
|
|
tmp0 = expkey[4*idx - 4];
|
|
|
tmp1 = expkey[4*idx - 3];
|
|
|
tmp2 = expkey[4*idx - 2];
|
|
|
tmp3 = expkey[4*idx - 1];
|
|
|
- if (!(idx % AES_KEYCOLS)) {
|
|
|
+ if (!(idx % aes_keycols)) {
|
|
|
tmp4 = tmp3;
|
|
|
tmp3 = sbox[tmp0];
|
|
|
- tmp0 = sbox[tmp1] ^ rcon[idx / AES_KEYCOLS];
|
|
|
+ tmp0 = sbox[tmp1] ^ rcon[idx / aes_keycols];
|
|
|
tmp1 = sbox[tmp2];
|
|
|
tmp2 = sbox[tmp4];
|
|
|
- } else if ((AES_KEYCOLS > 6) && (idx % AES_KEYCOLS == 4)) {
|
|
|
+ } else if ((aes_keycols > 6) && (idx % aes_keycols == 4)) {
|
|
|
tmp0 = sbox[tmp0];
|
|
|
tmp1 = sbox[tmp1];
|
|
|
tmp2 = sbox[tmp2];
|
|
|
tmp3 = sbox[tmp3];
|
|
|
}
|
|
|
|
|
|
- expkey[4*idx+0] = expkey[4*idx - 4*AES_KEYCOLS + 0] ^ tmp0;
|
|
|
- expkey[4*idx+1] = expkey[4*idx - 4*AES_KEYCOLS + 1] ^ tmp1;
|
|
|
- expkey[4*idx+2] = expkey[4*idx - 4*AES_KEYCOLS + 2] ^ tmp2;
|
|
|
- expkey[4*idx+3] = expkey[4*idx - 4*AES_KEYCOLS + 3] ^ tmp3;
|
|
|
+ expkey[4*idx+0] = expkey[4*idx - 4*aes_keycols + 0] ^ tmp0;
|
|
|
+ expkey[4*idx+1] = expkey[4*idx - 4*aes_keycols + 1] ^ tmp1;
|
|
|
+ expkey[4*idx+2] = expkey[4*idx - 4*aes_keycols + 2] ^ tmp2;
|
|
|
+ expkey[4*idx+3] = expkey[4*idx - 4*aes_keycols + 3] ^ tmp3;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
/* encrypt one 128 bit block */
|
|
|
-void aes_encrypt(u8 *in, u8 *expkey, u8 *out)
|
|
|
+void aes_encrypt(u32 key_len, u8 *in, u8 *expkey, u8 *out)
|
|
|
{
|
|
|
u8 state[AES_STATECOLS * 4];
|
|
|
- u32 round;
|
|
|
+ u32 round, aes_rounds;
|
|
|
+
|
|
|
+ aes_rounds = aes_get_rounds(key_len);
|
|
|
|
|
|
memcpy(state, in, AES_STATECOLS * 4);
|
|
|
add_round_key((u32 *)state, (u32 *)expkey);
|
|
|
|
|
|
- for (round = 1; round < AES_ROUNDS + 1; round++) {
|
|
|
- if (round < AES_ROUNDS)
|
|
|
+ for (round = 1; round < aes_rounds + 1; round++) {
|
|
|
+ if (round < aes_rounds)
|
|
|
mix_sub_columns(state);
|
|
|
else
|
|
|
shift_rows(state);
|
|
@@ -563,18 +592,20 @@ void aes_encrypt(u8 *in, u8 *expkey, u8 *out)
|
|
|
memcpy(out, state, sizeof(state));
|
|
|
}
|
|
|
|
|
|
-void aes_decrypt(u8 *in, u8 *expkey, u8 *out)
|
|
|
+void aes_decrypt(u32 key_len, u8 *in, u8 *expkey, u8 *out)
|
|
|
{
|
|
|
u8 state[AES_STATECOLS * 4];
|
|
|
- int round;
|
|
|
+ int round, aes_rounds;
|
|
|
+
|
|
|
+ aes_rounds = aes_get_rounds(key_len);
|
|
|
|
|
|
memcpy(state, in, sizeof(state));
|
|
|
|
|
|
add_round_key((u32 *)state,
|
|
|
- (u32 *)expkey + AES_ROUNDS * AES_STATECOLS);
|
|
|
+ (u32 *)expkey + aes_rounds * AES_STATECOLS);
|
|
|
inv_shift_rows(state);
|
|
|
|
|
|
- for (round = AES_ROUNDS; round--; ) {
|
|
|
+ for (round = aes_rounds; round--; ) {
|
|
|
add_round_key((u32 *)state,
|
|
|
(u32 *)expkey + round * AES_STATECOLS);
|
|
|
if (round)
|
|
@@ -600,7 +631,7 @@ void aes_apply_cbc_chain_data(u8 *cbc_chain_data, u8 *src, u8 *dst)
|
|
|
*dst++ = *src++ ^ *cbc_chain_data++;
|
|
|
}
|
|
|
|
|
|
-void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
|
|
|
+void aes_cbc_encrypt_blocks(u32 key_len, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
|
|
|
u32 num_aes_blocks)
|
|
|
{
|
|
|
u8 tmp_data[AES_BLOCK_LENGTH];
|
|
@@ -616,7 +647,7 @@ void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
|
|
|
debug_print_vector("AES Xor", AES_BLOCK_LENGTH, tmp_data);
|
|
|
|
|
|
/* Encrypt the AES block */
|
|
|
- aes_encrypt(tmp_data, key_exp, dst);
|
|
|
+ aes_encrypt(key_len, tmp_data, key_exp, dst);
|
|
|
debug_print_vector("AES Dst", AES_BLOCK_LENGTH, dst);
|
|
|
|
|
|
/* Update pointers for next loop. */
|
|
@@ -626,7 +657,7 @@ void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void aes_cbc_decrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
|
|
|
+void aes_cbc_decrypt_blocks(u32 key_len, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
|
|
|
u32 num_aes_blocks)
|
|
|
{
|
|
|
u8 tmp_data[AES_BLOCK_LENGTH], tmp_block[AES_BLOCK_LENGTH];
|
|
@@ -642,7 +673,7 @@ void aes_cbc_decrypt_blocks(u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
|
|
|
memcpy(tmp_block, src, AES_BLOCK_LENGTH);
|
|
|
|
|
|
/* Decrypt the AES block */
|
|
|
- aes_decrypt(src, key_exp, tmp_data);
|
|
|
+ aes_decrypt(key_len, src, key_exp, tmp_data);
|
|
|
debug_print_vector("AES Xor", AES_BLOCK_LENGTH, tmp_data);
|
|
|
|
|
|
/* Apply the chain data */
|