]> asedeno.scripts.mit.edu Git - linux.git/blob - arch/arm64/crypto/aes-neonbs-glue.c
Merge branch 'for-5.3' of git://git.kernel.org/pub/scm/linux/kernel/git/dennis/percpu
[linux.git] / arch / arm64 / crypto / aes-neonbs-glue.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Bit sliced AES using NEON instructions
4  *
5  * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>
#include <linux/string.h>

#include "aes-ctr-fallback.h"
17
/* Module metadata. */
18 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
19 MODULE_LICENSE("GPL v2");
20
/*
 * Crypto aliases for the generic algorithm names this module ends up
 * providing (aes_init() publishes SIMD wrappers under these names, with the
 * "__" prefix of the internal algorithms stripped).  Presumably these let a
 * request for e.g. "xts(aes)" load this module on demand.
 */
21 MODULE_ALIAS_CRYPTO("ecb(aes)");
22 MODULE_ALIAS_CRYPTO("cbc(aes)");
23 MODULE_ALIAS_CRYPTO("ctr(aes)");
24 MODULE_ALIAS_CRYPTO("xts(aes)");
25
26 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
27
28 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
29                                   int rounds, int blocks);
30 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
31                                   int rounds, int blocks);
32
33 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
34                                   int rounds, int blocks, u8 iv[]);
35
36 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
37                                   int rounds, int blocks, u8 iv[], u8 final[]);
38
39 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
40                                   int rounds, int blocks, u8 iv[]);
41 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
42                                   int rounds, int blocks, u8 iv[]);
43
44 /* borrowed from aes-neon-blk.ko */
45 asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
46                                      int rounds, int blocks);
47 asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
48                                      int rounds, int blocks, u8 iv[]);
49
/*
 * Common key context.  The setkey routines set @rounds to 6 + key_len / 4
 * and fill @rk with the output of aesbs_convert_key(); @rk is the key
 * argument handed to the aesbs_* assembly routines.
 */
50 struct aesbs_ctx {
        /*
         * 13 * (8 * 16) + 32 bytes; presumably sized for the largest
         * (14-round) converted schedule -- TODO confirm against the assembly.
         */
51         u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
52         int     rounds;
53 } __aligned(AES_BLOCK_SIZE);
54
/*
 * CBC context.  @enc keeps a copy of the standard encryption key schedule
 * because cbc_encrypt() uses neon_aes_cbc_encrypt() (which takes a u32
 * schedule) rather than the bit-sliced code; cbc_decrypt() uses @key.
 */
55 struct aesbs_cbc_ctx {
56         struct aesbs_ctx        key;
57         u32                     enc[AES_MAX_KEYLENGTH_U32];
58 };
59
/*
 * Context of the directly registered "ctr(aes)".  ctr_encrypt() reads the
 * tfm context as a struct aesbs_ctx, which is why @key must stay first.
 * @fallback holds the fully expanded key used by aes_ctr_encrypt_fallback()
 * when crypto_simd_usable() reports that NEON cannot be used.
 */
60 struct aesbs_ctr_ctx {
61         struct aesbs_ctx        key;            /* must be first member */
62         struct crypto_aes_ctx   fallback;
63 };
64
/*
 * XTS context.  @key is the data key (first half of the XTS key);
 * @twkey is the encryption schedule of the second half, used by
 * __xts_crypt() to encrypt the IV into the initial tweak.  aesbs_xts_setkey()
 * passes the tfm to aesbs_setkey(), which relies on @key being first.
 */
65 struct aesbs_xts_ctx {
66         struct aesbs_ctx        key;
67         u32                     twkey[AES_MAX_KEYLENGTH_U32];
68 };
69
70 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
71                         unsigned int key_len)
72 {
73         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
74         struct crypto_aes_ctx rk;
75         int err;
76
77         err = crypto_aes_expand_key(&rk, in_key, key_len);
78         if (err)
79                 return err;
80
81         ctx->rounds = 6 + key_len / 4;
82
83         kernel_neon_begin();
84         aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
85         kernel_neon_end();
86
87         return 0;
88 }
89
90 static int __ecb_crypt(struct skcipher_request *req,
91                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
92                                   int rounds, int blocks))
93 {
94         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
95         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
96         struct skcipher_walk walk;
97         int err;
98
99         err = skcipher_walk_virt(&walk, req, false);
100
101         while (walk.nbytes >= AES_BLOCK_SIZE) {
102                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
103
104                 if (walk.nbytes < walk.total)
105                         blocks = round_down(blocks,
106                                             walk.stride / AES_BLOCK_SIZE);
107
108                 kernel_neon_begin();
109                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
110                    ctx->rounds, blocks);
111                 kernel_neon_end();
112                 err = skcipher_walk_done(&walk,
113                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
114         }
115
116         return err;
117 }
118
119 static int ecb_encrypt(struct skcipher_request *req)
120 {
121         return __ecb_crypt(req, aesbs_ecb_encrypt);
122 }
123
124 static int ecb_decrypt(struct skcipher_request *req)
125 {
126         return __ecb_crypt(req, aesbs_ecb_decrypt);
127 }
128
129 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
130                             unsigned int key_len)
131 {
132         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
133         struct crypto_aes_ctx rk;
134         int err;
135
136         err = crypto_aes_expand_key(&rk, in_key, key_len);
137         if (err)
138                 return err;
139
140         ctx->key.rounds = 6 + key_len / 4;
141
142         memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));
143
144         kernel_neon_begin();
145         aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
146         kernel_neon_end();
147
148         return 0;
149 }
150
151 static int cbc_encrypt(struct skcipher_request *req)
152 {
153         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
154         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
155         struct skcipher_walk walk;
156         int err;
157
158         err = skcipher_walk_virt(&walk, req, false);
159
160         while (walk.nbytes >= AES_BLOCK_SIZE) {
161                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
162
163                 /* fall back to the non-bitsliced NEON implementation */
164                 kernel_neon_begin();
165                 neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
166                                      ctx->enc, ctx->key.rounds, blocks,
167                                      walk.iv);
168                 kernel_neon_end();
169                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
170         }
171         return err;
172 }
173
174 static int cbc_decrypt(struct skcipher_request *req)
175 {
176         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
177         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
178         struct skcipher_walk walk;
179         int err;
180
181         err = skcipher_walk_virt(&walk, req, false);
182
183         while (walk.nbytes >= AES_BLOCK_SIZE) {
184                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
185
186                 if (walk.nbytes < walk.total)
187                         blocks = round_down(blocks,
188                                             walk.stride / AES_BLOCK_SIZE);
189
190                 kernel_neon_begin();
191                 aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
192                                   ctx->key.rk, ctx->key.rounds, blocks,
193                                   walk.iv);
194                 kernel_neon_end();
195                 err = skcipher_walk_done(&walk,
196                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
197         }
198
199         return err;
200 }
201
202 static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
203                                  unsigned int key_len)
204 {
205         struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
206         int err;
207
208         err = crypto_aes_expand_key(&ctx->fallback, in_key, key_len);
209         if (err)
210                 return err;
211
212         ctx->key.rounds = 6 + key_len / 4;
213
214         kernel_neon_begin();
215         aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
216         kernel_neon_end();
217
218         return 0;
219 }
220
221 static int ctr_encrypt(struct skcipher_request *req)
222 {
223         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
224         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
225         struct skcipher_walk walk;
226         u8 buf[AES_BLOCK_SIZE];
227         int err;
228
229         err = skcipher_walk_virt(&walk, req, false);
230
231         while (walk.nbytes > 0) {
232                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
233                 u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
234
235                 if (walk.nbytes < walk.total) {
236                         blocks = round_down(blocks,
237                                             walk.stride / AES_BLOCK_SIZE);
238                         final = NULL;
239                 }
240
241                 kernel_neon_begin();
242                 aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
243                                   ctx->rk, ctx->rounds, blocks, walk.iv, final);
244                 kernel_neon_end();
245
246                 if (final) {
247                         u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
248                         u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
249
250                         crypto_xor_cpy(dst, src, final,
251                                        walk.total % AES_BLOCK_SIZE);
252
253                         err = skcipher_walk_done(&walk, 0);
254                         break;
255                 }
256                 err = skcipher_walk_done(&walk,
257                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
258         }
259         return err;
260 }
261
262 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
263                             unsigned int key_len)
264 {
265         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
266         struct crypto_aes_ctx rk;
267         int err;
268
269         err = xts_verify_key(tfm, in_key, key_len);
270         if (err)
271                 return err;
272
273         key_len /= 2;
274         err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
275         if (err)
276                 return err;
277
278         memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));
279
280         return aesbs_setkey(tfm, in_key, key_len);
281 }
282
283 static int ctr_encrypt_sync(struct skcipher_request *req)
284 {
285         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
286         struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
287
288         if (!crypto_simd_usable())
289                 return aes_ctr_encrypt_fallback(&ctx->fallback, req);
290
291         return ctr_encrypt(req);
292 }
293
294 static int __xts_crypt(struct skcipher_request *req,
295                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
296                                   int rounds, int blocks, u8 iv[]))
297 {
298         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
299         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
300         struct skcipher_walk walk;
301         int err;
302
303         err = skcipher_walk_virt(&walk, req, false);
304         if (err)
305                 return err;
306
307         kernel_neon_begin();
308         neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
309         kernel_neon_end();
310
311         while (walk.nbytes >= AES_BLOCK_SIZE) {
312                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
313
314                 if (walk.nbytes < walk.total)
315                         blocks = round_down(blocks,
316                                             walk.stride / AES_BLOCK_SIZE);
317
318                 kernel_neon_begin();
319                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
320                    ctx->key.rounds, blocks, walk.iv);
321                 kernel_neon_end();
322                 err = skcipher_walk_done(&walk,
323                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
324         }
325         return err;
326 }
327
328 static int xts_encrypt(struct skcipher_request *req)
329 {
330         return __xts_crypt(req, aesbs_xts_encrypt);
331 }
332
333 static int xts_decrypt(struct skcipher_request *req)
334 {
335         return __xts_crypt(req, aesbs_xts_decrypt);
336 }
337
338 static struct skcipher_alg aes_algs[] = { {
339         .base.cra_name          = "__ecb(aes)",
340         .base.cra_driver_name   = "__ecb-aes-neonbs",
341         .base.cra_priority      = 250,
342         .base.cra_blocksize     = AES_BLOCK_SIZE,
343         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
344         .base.cra_module        = THIS_MODULE,
345         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
346
347         .min_keysize            = AES_MIN_KEY_SIZE,
348         .max_keysize            = AES_MAX_KEY_SIZE,
349         .walksize               = 8 * AES_BLOCK_SIZE,
350         .setkey                 = aesbs_setkey,
351         .encrypt                = ecb_encrypt,
352         .decrypt                = ecb_decrypt,
353 }, {
354         .base.cra_name          = "__cbc(aes)",
355         .base.cra_driver_name   = "__cbc-aes-neonbs",
356         .base.cra_priority      = 250,
357         .base.cra_blocksize     = AES_BLOCK_SIZE,
358         .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
359         .base.cra_module        = THIS_MODULE,
360         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
361
362         .min_keysize            = AES_MIN_KEY_SIZE,
363         .max_keysize            = AES_MAX_KEY_SIZE,
364         .walksize               = 8 * AES_BLOCK_SIZE,
365         .ivsize                 = AES_BLOCK_SIZE,
366         .setkey                 = aesbs_cbc_setkey,
367         .encrypt                = cbc_encrypt,
368         .decrypt                = cbc_decrypt,
369 }, {
370         .base.cra_name          = "__ctr(aes)",
371         .base.cra_driver_name   = "__ctr-aes-neonbs",
372         .base.cra_priority      = 250,
373         .base.cra_blocksize     = 1,
374         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
375         .base.cra_module        = THIS_MODULE,
376         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
377
378         .min_keysize            = AES_MIN_KEY_SIZE,
379         .max_keysize            = AES_MAX_KEY_SIZE,
380         .chunksize              = AES_BLOCK_SIZE,
381         .walksize               = 8 * AES_BLOCK_SIZE,
382         .ivsize                 = AES_BLOCK_SIZE,
383         .setkey                 = aesbs_setkey,
384         .encrypt                = ctr_encrypt,
385         .decrypt                = ctr_encrypt,
386 }, {
387         .base.cra_name          = "ctr(aes)",
388         .base.cra_driver_name   = "ctr-aes-neonbs",
389         .base.cra_priority      = 250 - 1,
390         .base.cra_blocksize     = 1,
391         .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
392         .base.cra_module        = THIS_MODULE,
393
394         .min_keysize            = AES_MIN_KEY_SIZE,
395         .max_keysize            = AES_MAX_KEY_SIZE,
396         .chunksize              = AES_BLOCK_SIZE,
397         .walksize               = 8 * AES_BLOCK_SIZE,
398         .ivsize                 = AES_BLOCK_SIZE,
399         .setkey                 = aesbs_ctr_setkey_sync,
400         .encrypt                = ctr_encrypt_sync,
401         .decrypt                = ctr_encrypt_sync,
402 }, {
403         .base.cra_name          = "__xts(aes)",
404         .base.cra_driver_name   = "__xts-aes-neonbs",
405         .base.cra_priority      = 250,
406         .base.cra_blocksize     = AES_BLOCK_SIZE,
407         .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
408         .base.cra_module        = THIS_MODULE,
409         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
410
411         .min_keysize            = 2 * AES_MIN_KEY_SIZE,
412         .max_keysize            = 2 * AES_MAX_KEY_SIZE,
413         .walksize               = 8 * AES_BLOCK_SIZE,
414         .ivsize                 = AES_BLOCK_SIZE,
415         .setkey                 = aesbs_xts_setkey,
416         .encrypt                = xts_encrypt,
417         .decrypt                = xts_decrypt,
418 } };
419
420 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
421
422 static void aes_exit(void)
423 {
424         int i;
425
426         for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
427                 if (aes_simd_algs[i])
428                         simd_skcipher_free(aes_simd_algs[i]);
429
430         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
431 }
432
433 static int __init aes_init(void)
434 {
435         struct simd_skcipher_alg *simd;
436         const char *basename;
437         const char *algname;
438         const char *drvname;
439         int err;
440         int i;
441
442         if (!cpu_have_named_feature(ASIMD))
443                 return -ENODEV;
444
445         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
446         if (err)
447                 return err;
448
449         for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
450                 if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
451                         continue;
452
453                 algname = aes_algs[i].base.cra_name + 2;
454                 drvname = aes_algs[i].base.cra_driver_name + 2;
455                 basename = aes_algs[i].base.cra_driver_name;
456                 simd = simd_skcipher_create_compat(algname, drvname, basename);
457                 err = PTR_ERR(simd);
458                 if (IS_ERR(simd))
459                         goto unregister_simds;
460
461                 aes_simd_algs[i] = simd;
462         }
463         return 0;
464
465 unregister_simds:
466         aes_exit();
467         return err;
468 }
469
/* aes_init() runs at module load and aes_exit() at module unload. */
470 module_init(aes_init);
471 module_exit(aes_exit);