Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions crypto/aescts/aescts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// Package aescts implements AES-CTS (Ciphertext Stealing) mode as used by
// Kerberos per RFC 3962. The variant used is CBC-CTS where the last two
// ciphertext blocks are swapped before output (Kerberos / CS3 style).
package aescts

import (
"crypto/aes"
"crypto/cipher"
"errors"
)

const blockSize = 16

// Encrypt encrypts plaintext using AES-CTS with the given key and IV.
// plaintext must be >= 16 bytes (one AES block).
// Output length equals input length.
//
// For plaintext of exactly one block (16 bytes), standard AES-CBC is used
// (no swap is possible with a single block). For two or more blocks, the
// last two blocks in the CBC output are swapped and the output is truncated
// to len(plaintext) bytes.
func Encrypt(key, iv, plaintext []byte) ([]byte, error) {
n := len(plaintext)
if n < blockSize {
return nil, errors.New("aescts: plaintext must be at least 16 bytes")
}

// Special case: exactly one block — CBC with no swap
if n == blockSize {
padded := make([]byte, blockSize)
copy(padded, plaintext)
return aesCBCEncrypt(key, iv, padded)
}

r := n % blockSize // remainder bytes in last partial block (0 = exact multiple)

// Pad plaintext to a multiple of blockSize
paddedLen := n
if r != 0 {
paddedLen = n + (blockSize - r)
}
padded := make([]byte, paddedLen)
copy(padded, plaintext)

// AES-CBC encrypt the padded plaintext
cbcOut, err := aesCBCEncrypt(key, iv, padded)
if err != nil {
return nil, err
}

numBlocks := paddedLen / blockSize
result := make([]byte, n)

if r == 0 {
// Exact multiple of blockSize: swap last two complete blocks.
// CBC output: ... C[n-2] C[n-1]
// CTS output: ... C[n-1] C[n-2]
prefixEnd := n - 2*blockSize
if prefixEnd > 0 {
copy(result[:prefixEnd], cbcOut[:prefixEnd])
}
copy(result[prefixEnd:prefixEnd+blockSize], cbcOut[n-blockSize:n])
copy(result[prefixEnd+blockSize:n], cbcOut[n-2*blockSize:n-blockSize])
} else {
// Non-multiple: CBC gives numBlocks full blocks (last one zero-padded).
// CBC blocks: C[0] ... C[numBlocks-2] C[numBlocks-1]
// CTS output: C[0]...C[numBlocks-3] + C[numBlocks-1](full 16B) + C[numBlocks-2][:r]
// Total: (numBlocks-2)*16 + 16 + r = (numBlocks-1)*16 + r = n ✓
prefixEnd := (numBlocks - 2) * blockSize
penultStart := (numBlocks - 2) * blockSize
lastStart := (numBlocks - 1) * blockSize

if prefixEnd > 0 {
copy(result[:prefixEnd], cbcOut[:prefixEnd])
}
// Full last CBC block
copy(result[prefixEnd:prefixEnd+blockSize], cbcOut[lastStart:lastStart+blockSize])
// First r bytes of penultimate CBC block
copy(result[prefixEnd+blockSize:n], cbcOut[penultStart:penultStart+r])
}

return result, nil
}

// Decrypt decrypts ciphertext using AES-CTS with the given key and IV.
// ciphertext must be >= 16 bytes.
// Output length equals input length.
func Decrypt(key, iv, ciphertext []byte) ([]byte, error) {
n := len(ciphertext)
if n < blockSize {
return nil, errors.New("aescts: ciphertext must be at least 16 bytes")
}

// Special case: exactly one block — CBC with no un-swap
if n == blockSize {
return aesCBCDecrypt(key, iv, ciphertext)
}

r := n % blockSize

if r == 0 {
// Un-swap last two full blocks, then normal CBC decrypt.
buf := make([]byte, n)
copy(buf[:n-2*blockSize], ciphertext[:n-2*blockSize])
copy(buf[n-2*blockSize:n-blockSize], ciphertext[n-blockSize:n])
copy(buf[n-blockSize:n], ciphertext[n-2*blockSize:n-blockSize])
return aesCBCDecrypt(key, iv, buf)
}

// Non-multiple case.
// CTS ciphertext layout (from Encrypt):
// prefix: (numBlocks-2)*16 bytes — blocks C[0]..C[numBlocks-3]
// Clast: 16 bytes — last CBC block (C[numBlocks-1])
// Cpen_partial: r bytes — first r bytes of penultimate CBC block (C[numBlocks-2])
//
// To reconstruct the penultimate CBC block (C[numBlocks-2]):
// AES-ECB-decrypt Clast → X (= P[numBlocks-1]_padded XOR C[numBlocks-2])
// Since P[numBlocks-1] was zero-padded, X[r:] = 0 XOR C[numBlocks-2][r:] = C[numBlocks-2][r:]
// So full C[numBlocks-2] = Cpen_partial + X[r:]
//
// Recover last r bytes of plaintext:
// P[numBlocks-1][:r] = X[:r] XOR C[numBlocks-2][:r] = X[:r] XOR Cpen_partial
//
// Decrypt prefix + C[numBlocks-2] with CBC to get P[0]..P[numBlocks-2].

numBlocks := n / blockSize // integer division; excludes the partial block
prefixLen := (numBlocks - 1) * blockSize
clast := ciphertext[prefixLen : prefixLen+blockSize]
cpenPartial := ciphertext[prefixLen+blockSize:]

// AES-ECB decrypt Clast
blockCipher, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
x := make([]byte, blockSize)
blockCipher.Decrypt(x, clast)

// Reconstruct full penultimate CBC block
cpen := make([]byte, blockSize)
copy(cpen[:r], cpenPartial)
copy(cpen[r:], x[r:])

// Recover last r bytes of plaintext
lastPartial := make([]byte, r)
for i := 0; i < r; i++ {
lastPartial[i] = x[i] ^ cpenPartial[i]
}

// CBC-decrypt prefix + cpen to get P[0]..P[numBlocks-2]
cbcInput := make([]byte, prefixLen+blockSize)
copy(cbcInput[:prefixLen], ciphertext[:prefixLen])
copy(cbcInput[prefixLen:], cpen)

mainPlain, err := aesCBCDecrypt(key, iv, cbcInput)
if err != nil {
return nil, err
}

result := append(mainPlain, lastPartial...)
return result, nil
}

// aesCBCEncrypt performs standard AES-CBC encryption.
// plaintext length must be a multiple of blockSize.
func aesCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) {
blockCipher, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
ciphertext := make([]byte, len(plaintext))
mode := cipher.NewCBCEncrypter(blockCipher, iv)
mode.CryptBlocks(ciphertext, plaintext)
return ciphertext, nil
}

// aesCBCDecrypt performs standard AES-CBC decryption.
// ciphertext length must be a multiple of blockSize.
func aesCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) {
blockCipher, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
plaintext := make([]byte, len(ciphertext))
mode := cipher.NewCBCDecrypter(blockCipher, iv)
mode.CryptBlocks(plaintext, ciphertext)
return plaintext, nil
}
115 changes: 115 additions & 0 deletions crypto/aescts/aescts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package aescts

import (
"bytes"
"crypto/rand"
"testing"
)

// TestEncryptDecryptRoundtrip tests that Encrypt followed by Decrypt returns the original plaintext
// for various input lengths.
func TestEncryptDecryptRoundtrip(t *testing.T) {
key := make([]byte, 16) // AES-128
iv := make([]byte, 16)
rand.Read(key)
rand.Read(iv)

lengths := []int{16, 17, 20, 31, 32, 33, 40, 48, 64, 100}

for _, length := range lengths {
plaintext := make([]byte, length)
rand.Read(plaintext)

ciphertext, err := Encrypt(key, iv, plaintext)
if err != nil {
t.Errorf("Encrypt(%d bytes) error: %v", length, err)
continue
}

if len(ciphertext) != len(plaintext) {
t.Errorf("Encrypt(%d bytes) output length = %d, want %d", length, len(ciphertext), len(plaintext))
continue
}

decrypted, err := Decrypt(key, iv, ciphertext)
if err != nil {
t.Errorf("Decrypt(%d bytes) error: %v", length, err)
continue
}

if !bytes.Equal(decrypted, plaintext) {
t.Errorf("Roundtrip(%d bytes) failed: got %x, want %x", length, decrypted, plaintext)
}
}
}

// TestEncryptDecryptAES256Roundtrip tests with a 256-bit key.
func TestEncryptDecryptAES256Roundtrip(t *testing.T) {
key := make([]byte, 32) // AES-256
iv := make([]byte, 16)
rand.Read(key)
rand.Read(iv)

lengths := []int{16, 20, 32, 40}

for _, length := range lengths {
plaintext := make([]byte, length)
rand.Read(plaintext)

ciphertext, err := Encrypt(key, iv, plaintext)
if err != nil {
t.Errorf("AES-256 Encrypt(%d bytes) error: %v", length, err)
continue
}

decrypted, err := Decrypt(key, iv, ciphertext)
if err != nil {
t.Errorf("AES-256 Decrypt(%d bytes) error: %v", length, err)
continue
}

if !bytes.Equal(decrypted, plaintext) {
t.Errorf("AES-256 Roundtrip(%d bytes) failed", length)
}
}
}

// TestEncryptTooShort verifies that Encrypt rejects plaintext shorter than 16 bytes.
func TestEncryptTooShort(t *testing.T) {
key := make([]byte, 16)
iv := make([]byte, 16)
_, err := Encrypt(key, iv, []byte("short"))
if err == nil {
t.Error("Encrypt with < 16 bytes should return an error")
}
}

// TestDecryptTooShort verifies that Decrypt rejects ciphertext shorter than 16 bytes.
func TestDecryptTooShort(t *testing.T) {
key := make([]byte, 16)
iv := make([]byte, 16)
_, err := Decrypt(key, iv, []byte("short"))
if err == nil {
t.Error("Decrypt with < 16 bytes should return an error")
}
}

// TestEncryptDeterministic verifies that Encrypt with the same inputs produces the same output.
func TestEncryptDeterministic(t *testing.T) {
key := make([]byte, 16)
iv := make([]byte, 16)
plaintext := make([]byte, 32)

c1, err := Encrypt(key, iv, plaintext)
if err != nil {
t.Fatal(err)
}
c2, err := Encrypt(key, iv, plaintext)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(c1, c2) {
t.Error("Encrypt is not deterministic")
}
}
92 changes: 92 additions & 0 deletions crypto/nfold/nfold.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Package nfold implements the N-FOLD function from RFC 3961 Section 5.1.
// N-FOLD is used by Kerberos to generate key derivation constants.
package nfold

// gcd computes the greatest common divisor of a and b.
func gcd(a, b int) int {
for b != 0 {
a, b = b, a%b
}
return a
}

// lcm computes the least common multiple of a and b.
func lcm(a, b int) int {
return a / gcd(a, b) * b
}

// getBit returns the value (0 or 1) of bit p in b.
// Bit 0 is the MSB of b[0], bit 7 is the LSB of b[0], bit 8 is MSB of b[1], etc.
func getBit(b []byte, p int) int {
pByte := p / 8
pBit := uint(p % 8)
return int(b[pByte]>>(8-(pBit+1))) & 0x01
}

// setBit sets bit p in b to v (0 or 1).
func setBit(b []byte, p, v int) {
pByte := p / 8
pBit := uint(p % 8)
b[pByte] = byte(v<<(8-(pBit+1))) | b[pByte]
}

// rotateRight performs a bit-level right rotation of b by step positions.
func rotateRight(b []byte, step int) []byte {
out := make([]byte, len(b))
bitLen := len(b) * 8
for i := 0; i < bitLen; i++ {
v := getBit(b, i)
setBit(out, (i+step)%bitLen, v)
}
return out
}

// onesComplementAddition adds two equal-length byte slices using ones' complement
// (end-around carry) arithmetic, processing from LSB to MSB.
func onesComplementAddition(n1, n2 []byte) []byte {
numBits := len(n1) * 8
out := make([]byte, len(n1))
carry := 0
for i := numBits - 1; i >= 0; i-- {
s := getBit(n1, i) + getBit(n2, i) + carry
setBit(out, i, s&1)
carry = s >> 1
}
if carry == 1 {
// End-around carry: add 1 to the result
carryBuf := make([]byte, len(n1))
carryBuf[len(carryBuf)-1] = 1
out = onesComplementAddition(out, carryBuf)
}
return out
}

// NFold folds the input byte string into n bits (n must be a multiple of 8).
//
// The algorithm (RFC 3961 Section 5.1):
// 1. Let k = len(in)*8 and l = lcm(n, k).
// 2. Build a buffer of l/8 bytes by concatenating l/k copies of the input,
// each copy rotated right by 13*i bits relative to the original.
// 3. XOR (with end-around carry / ones' complement addition) all n-bit
// blocks of the buffer together to produce the n-bit output.
func NFold(in []byte, n int) []byte {
k := len(in) * 8
lcmVal := lcm(n, k)
numCopies := lcmVal / k

// Build the concatenated rotated buffer
var buf []byte
for i := 0; i < numCopies; i++ {
buf = append(buf, rotateRight(in, 13*i)...)
}

// Ones' complement addition of all n-bit (n/8 byte) blocks
result := make([]byte, n/8)
block := make([]byte, n/8)
numBlocks := lcmVal / n
for i := 0; i < numBlocks; i++ {
copy(block, buf[i*(n/8):(i+1)*(n/8)])
result = onesComplementAddition(result, block)
}
return result
}
Loading