224 lines
5.6 KiB
Python
224 lines
5.6 KiB
Python
#!/usr/bin/env python3
|
|
import hashlib
|
|
import secrets
|
|
import sys
|
|
|
|
# Byte sizes of SHA-256's internal block (the width HMAC keys are padded to)
# and of its digest (the HKDF output unit and the length of the trailing tag).
sha256_blocksize = hashlib.new('sha256').block_size
sha256_outputsize = hashlib.new('sha256').digest_size


def xor(x, y):
    """Yield the bytewise XOR of two equal-length byte sequences as ints.

    This is a generator, so the length check only runs once iteration
    starts (e.g. when the result is handed to ``bytes()``).
    """
    assert len(x) == len(y)
    yield from (left ^ right for left, right in zip(x, y))


def hmac_sha256(key, message):
    """Return HMAC-SHA256 of *message* under *key* (RFC 2104).

    Args:
        key: Secret key as bytes/bytearray, of any length.
        message: Bytes-like data to authenticate.

    Returns:
        The 32-byte authentication tag as ``bytes``.
    """
    block_size = hashlib.sha256().block_size

    # Keys longer than one block are replaced by their SHA-256 digest.
    # Note: the hash must come from hashlib; there is no bare ``sha256``
    # callable in this module.
    if len(key) > block_size:
        key = hashlib.sha256(key).digest()

    # Shorter keys -- including a digest from the step above, which is
    # smaller than a block -- are right-padded with zero bytes.
    key = bytes(key).ljust(block_size, b'\x00')

    inner_key = bytes(b ^ 0x36 for b in key)  # K xor ipad
    outer_key = bytes(b ^ 0x5c for b in key)  # K xor opad

    # Inner hash: H((K xor ipad) || message)
    inner = hashlib.sha256(inner_key)
    inner.update(message)

    # Outer hash: H((K xor opad) || inner digest)
    return hashlib.sha256(outer_key + inner.digest()).digest()


def ceildiv(p, q):
    """Return ceil(p / q) using integer arithmetic only.

    Requires a non-negative dividend *p* and a positive divisor *q*.
    """
    assert p >= 0
    assert q > 0
    quotient, leftover = divmod(p, q)
    return quotient + 1 if leftover else quotient


def hkdf_sha256(salt, key_material, info, length):
    """Derive *length* bytes of keying material with HKDF-SHA256 (RFC 5869).

    Args:
        salt: Extract-step salt. ``b''`` (or ``None``) selects the RFC
            default of HashLen zero bytes.
        key_material: Input keying material (IKM).
        info: Context/application-specific bytes for the expand step.
        length: Number of output bytes, 0 .. 255 * HashLen (8160).

    Returns:
        ``length`` bytes of output keying material.

    Raises:
        ValueError: if *length* is negative or exceeds 255 * HashLen.
    """
    # RFC 5869 limits L to 255 * HashLen because the block counter is a
    # single byte (bytes([counter]) below cannot exceed 255).
    max_length = 255 * sha256_outputsize
    if not 0 <= length <= max_length:
        raise ValueError('HKDF-SHA256 output length must be in 0..%d' % max_length)

    # Extract: PRK = HMAC(salt, IKM)
    if not salt:
        salt = b'\x00' * sha256_outputsize
    pseudorandom_key = hmac_sha256(salt, key_material)

    # Expand: T(n) = HMAC(PRK, T(n-1) || info || n) for n = 1..N, T(0) = b''
    blocks = []
    previous = b''
    for counter in range(1, ceildiv(length, sha256_outputsize) + 1):
        previous = hmac_sha256(pseudorandom_key, previous + info + bytes([counter]))
        blocks.append(previous)

    # Concatenate T(1)..T(N) and cut to the requested size
    return b''.join(blocks)[:length]


def hmac_sha256_ctr_keystream(nonce, key):
    """Yield an endless keystream from HMAC-SHA256 in counter mode.

    Keystream block i is HMAC-SHA256(key, nonce || counter_i). The 512-bit
    MAC input is the 256-bit nonce followed by a 256-bit big-endian counter
    starting at 0. Bytes are yielded one at a time as ints.

    Args:
        nonce: 32-byte per-message nonce.
        key: 32-byte cipher key.
    """
    # This is a generator: the checks run when iteration starts, not at call.
    assert len(nonce) == 256//8
    assert len(key) == 256//8

    counter = 0
    while True:
        # int.to_bytes raises OverflowError at 2**256 rather than silently
        # wrapping back to 0, which would repeat keystream.
        yield from hmac_sha256(key, nonce + counter.to_bytes(256//8, 'big'))
        counter += 1


def shacrypt_enc(key, plaintext):
    """Encrypt and authenticate *plaintext* under a 256-bit *key*.

    Output layout: hkdf_salt (32) || cipher_nonce (32) || ciphered ||
    HMAC-SHA256 tag (32). The tag covers every byte before it
    (encrypt-then-MAC). A fresh random salt and nonce are drawn per call.
    """
    assert len(key) == 256//8

    # Per-message randomness: an HKDF salt and a keystream nonce
    salt = secrets.token_bytes(256//8)
    nonce = secrets.token_bytes(256//8)

    # 64 bytes of HKDF output: the MAC key is the first half, the cipher key
    # the second. Taking the cipher key second means deriving it requires
    # the full HKDF expansion rather than only its first block; whether this
    # matters against any real attack is unclear, but it costs nothing.
    derived = hkdf_sha256(salt, key, b'', 512//8)
    mac_key, enc_key = derived[:256//8], derived[256//8:]

    # XOR the plaintext against the HMAC-CTR keystream
    body = bytes(
        p ^ k for p, k in zip(plaintext, hmac_sha256_ctr_keystream(nonce, enc_key))
    )

    # Authenticate salt, nonce and ciphered bytes, then append the tag
    authenticated = b''.join((salt, nonce, body))
    return authenticated + hmac_sha256(mac_key, authenticated)


class AuthenticationError(Exception):
    """Raised by shacrypt_dec when a ciphertext fails authentication."""


def shacrypt_dec(key, ciphertext):
    """Verify and decrypt a ciphertext produced by shacrypt_enc.

    Expected layout: hkdf_salt (32) || cipher_nonce (32) || ciphered ||
    HMAC-SHA256 tag (32). The tag is compared with secrets.compare_digest
    before any decryption takes place.

    Args:
        key: The 32-byte key that was given to shacrypt_enc.
        ciphertext: The complete ciphertext bytes.

    Returns:
        The plaintext as a ``bytearray``.

    Raises:
        AuthenticationError: if the ciphertext is too short to contain the
            salt, nonce and tag, or if the tag does not match.
    """
    assert len(key) == 256//8

    salt_len = nonce_len = 256//8
    header_len = salt_len + nonce_len

    # Reject truncated input explicitly instead of slicing it into a short
    # salt/nonce and relying on the tag comparison to fail.
    if len(ciphertext) < header_len + sha256_outputsize:
        raise AuthenticationError

    # Everything before the trailing tag is covered by the HMAC
    hmaced = ciphertext[:-sha256_outputsize]
    expected_hmac = ciphertext[-sha256_outputsize:]

    hkdf_salt = hmaced[:salt_len]
    cipher_nonce = hmaced[salt_len:header_len]

    # Same split as in encryption: HMAC key first, cipher key second
    keys = hkdf_sha256(hkdf_salt, key, b'', 512//8)
    hmac_key = keys[:256//8]
    cipher_key = keys[256//8:]

    # Authenticate before touching the ciphered bytes
    if not secrets.compare_digest(expected_hmac, hmac_sha256(hmac_key, hmaced)):
        raise AuthenticationError

    # XOR the ciphered bytes against the same HMAC-CTR keystream
    keystream = hmac_sha256_ctr_keystream(cipher_nonce, cipher_key)
    return bytearray(c ^ k for c, k in zip(hmaced[header_len:], keystream))


def main():
    """Command-line entry point: ``<prog> enc|dec <hex key>``.

    Reads all of stdin, encrypts or decrypts it with the hex-encoded 256-bit
    key, and writes the result to stdout. Exits with status 1 on a usage
    error, a malformed key, or (for ``dec``) an HMAC mismatch.
    """
    prog = sys.argv[0]
    usage = 'Usage: %s enc|dec key' % prog

    if len(sys.argv) != 3:
        print(usage, file=sys.stderr)
        sys.exit(1)

    mode, hex_key = sys.argv[1], sys.argv[2]

    try:
        key = bytes.fromhex(hex_key)
    except ValueError:
        print('%s: Error: Key must be hex-encoded' % prog, file=sys.stderr)
        sys.exit(1)

    if len(key) != 256//8:
        print('%s: Error: Key must be 256 bits long' % prog, file=sys.stderr)
        sys.exit(1)

    if mode == 'enc':
        plaintext = sys.stdin.buffer.read()
        sys.stdout.buffer.write(shacrypt_enc(key, plaintext))
    elif mode == 'dec':
        ciphertext = sys.stdin.buffer.read()
        try:
            plaintext = shacrypt_dec(key, ciphertext)
        except AuthenticationError:
            print('%s: Error: HMAC mismatch' % prog, file=sys.stderr)
            sys.exit(1)
        sys.stdout.buffer.write(plaintext)
    else:
        # Unknown mode: same usage line (and exit status) as a bad arg count
        print(usage, file=sys.stderr)
        sys.exit(1)

    sys.stdout.buffer.flush()


if __name__ == '__main__':
    main()