[Crypto][Python] D.Bleichenbacher attack on RSA PKCS#1, part II

Click for the previous part.

Full attack

Since I'm not a very good mathematician, I misunderstood Bleichenbacher's paper and implemented the second step in a simpler way: it just narrows the bounds, again and again. This works, but I admit it may be less efficient than Bleichenbacher's original idea.
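
Here is how I read the narrowing that step_2() below performs, written as formulas. The notation is mine: [a, b] is the current interval for m, s_prev is the last s that produced valid padding (s_start in the code), and B = 2^(BITS-HEADER_BITS). Treat it as a sketch of this simplified approach, not a transcription of Bleichenbacher's original steps:

```latex
% c' = c * s^e mod n passes the relaxed oracle iff  2B <= (m s mod n) <= 3B - 1.
% Writing  m s = (m s mod n) + r n  for some integer r:
\[
  2B + r n \;\le\; m s \;\le\; 3B - 1 + r n
  \quad\Longrightarrow\quad
  \left\lceil \frac{2B + r n}{s} \right\rceil \le m \le \left\lfloor \frac{3B - 1 + r n}{s} \right\rfloor
\]
% New bounds: intersect with the old interval
\[
  [a', b'] = [a, b] \cap \left[ \left\lceil \tfrac{2B + r n}{s} \right\rceil,\ \left\lfloor \tfrac{3B - 1 + r n}{s} \right\rfloor \right]
\]
% Candidate s for a given r (so that the new interval can overlap [a, b]), r starting at r_start:
\[
  \frac{2B + r n}{b} \le s \le \frac{3B - 1 + r n}{a},
  \qquad
  r_{\text{start}} = 2 \left\lfloor \frac{b\, s_{\text{prev}} - 2B}{n} \right\rfloor
\]
```

The code rounds the lower bounds down with //, which only makes them slightly looser; the true m always stays inside.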


import random, math, sys, os
import functools

HEADER_BITS=16

BITS=64
"""
p=3900617629
q=2378309513
m=0x0002112233445566
"""
e=17
n=9276876013626204677
good_c=7799333987906023269

def binlog(x):
    assert (x>0)
    return math.log(int(x), 2)

B=2**(BITS-HEADER_BITS)
_2B=2*B
_3B_1=(3*B)-1

m_min=_2B
m_max=_3B_1
s1_min=n//m_min
s1_max=n//m_max

calls_to_oracle=0

# Only header is checked
@functools.cache
def oracle(c):
    d=4911287298007382225
    plaintext=pow(c, d, n)
    return (plaintext>>(BITS-HEADER_BITS))==2

def intervals_intersection(i1_s, i1_e, i2_s, i2_e):
    # https://scicomp.stackexchange.com/questions/26258/the-easiest-way-to-find-intersection-of-two-intervals
    # check for non-intersecting intervals:
    if (i2_s > i1_e) or (i1_s > i2_e):
        # non-intersects
        return None
    return max(i1_s, i2_s), min(i1_e, i2_e)

def step_1():
    t=[]
    print ("step_1() begin")
    global calls_to_oracle
    s=n // _3B_1
    #s=2
    while True:
        print("step_1. trying s", hex(s),"  \r", end="")
        rt=oracle(good_c*pow(s, e, n) % n)
        calls_to_oracle=calls_to_oracle+1
        if rt:
            print ("rt=True")
            print ("s", hex(s))
            print ("overlaps in range (inclusive): ", s//s1_min, (s//s1_max))
            for overlaps in range(s//s1_min, (s//s1_max)+1):
                lower=(_2B+n*overlaps)//s
                upper=(_3B_1+n*overlaps)//s
                if lower>=_2B and upper<=_3B_1:
                    print ("m is between:", hex(lower), hex(upper))
                    print ("lower:", hex(lower))
                    print ("upper:", hex(upper))
                    print ("binlog(diff) %.1f" % binlog(upper-lower))
                    t.append((lower, upper, s))
            return t
        s=s+1
        if s-1 == n:
            assert False

def step_2(a, b, s_start):
    print ("step_2")
    print ("a (lower)", hex(a))
    print ("b (upper)", hex(b))
    print ("s_start", hex(s_start))
    global calls_to_oracle
    r_start=2*((b*s_start - _2B) // n)
    print ("r_start", hex(r_start))
    for r in range(r_start, r_start+0x800): # tune this?
        print ("current r=", hex(r),"\r", end="")
        # pick s
        s_lower=(_2B + r*n)//b
        s_upper=(_3B_1 + r*n)//a
        for s in range(s_lower, s_upper+1):
            rt=oracle((good_c*pow(s, e, n)) % n)
            calls_to_oracle=calls_to_oracle+1
            if rt:
                print ("Oracle called with s", hex(s))
                print ("Valid padding.")
                # input interval - a,b
                # new interval:
                new_a=(_2B + r*n)//s   # lower
                new_b=(_3B_1 + r*n)//s # upper
                tmp=intervals_intersection(a, b, new_a, new_b)
                if tmp is None:
                    return None
                m_lower, m_upper=tmp
                print ("new m_lower/m_upper", hex(m_lower), hex(m_upper))
                print ("m_lower:", hex(m_lower))
                print ("m_upper:", hex(m_upper))
                assert m_upper >= m_lower
                if m_lower!=m_upper:
                    print ("diff", hex(m_upper-m_lower))
                # The caller decides whether the interval is narrow enough to stop
                return s, m_lower, m_upper

    print ("Loop finished!")
    return None

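bounds=step_1()
# Narrow every candidate interval until one of them collapses to (almost) a single value.
# Stop when no candidates are left, instead of spinning forever on an empty list.
while bounds:
    new_bounds=[]
    for bound in bounds:
        lower, upper, s=bound
        tmp = step_2(lower, upper, s)
        if tmp is not None:
            s, lower, upper = tmp
            # Lower bounds are rounded down (floor division), so 'lower' may
            # undershoot by one; when the width is <=1, 'upper' is taken as the answer.
            if upper-lower<=1:
                print ("Found! upper:", hex(upper))
                print ("calls_to_oracle", calls_to_oracle)
                sys.exit(0)
            new_bounds.append((lower, upper, s))
        print ("calls_to_oracle", calls_to_oracle)
        print ("current len(bounds)", len(bounds))
    bounds=new_bounds

print ("Main loop finished without any result. Error.")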

The variable r used in all these expressions means, roughly, 'how many times we wrapped around n': m*s = (m*s mod n) + r*n.

TL;DR: we can decrypt any message using only calls to an 'oracle' that holds the private RSA key and can check the padding.
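
The whole trick rests on RSA being multiplicative: if c = m^e mod n, then c*s^e mod n decrypts to m*s mod n. So the attacker can 'multiply the hidden plaintext by s' without knowing it, and the padding oracle then leaks whether m*s mod n starts with 00 02. A quick check, using the classic textbook key p=61, q=53 (my choice, only because it is easy to verify by hand; it is not the key used above):

```python
# Textbook toy RSA key (an assumption for illustration only; NOT the key used above)
p, q = 61, 53
n = p * q                          # 3233
e = 17
d = pow(e, -1, (p - 1) * (q - 1))  # modular inverse, Python 3.8+

def encrypt(m):
    return pow(m, e, n)

def decrypt(c):
    return pow(c, d, n)

m = 1234                           # the "unknown" plaintext
c = encrypt(m)

# Attacker side: blind the ciphertext with s^e, never touching d or m
for s in (2, 3, 100, 3000):
    c_blinded = c * pow(s, e, n) % n
    # Whoever decrypts c_blinded actually sees m*s mod n
    assert decrypt(c_blinded) == m * s % n

print("d =", d)  # d = 2753
```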

Here we can see how the bounds are narrowed down to a single solution:

...

step_2
a (lower) 0x2112233445562
b (upper) 0x2112233445566
s_start 0x464c7467eadc
r_start 0x241db6006
Oracle called with s 0x8c98e8d01400
Valid padding.
new m_lower/m_upper 0x2112233445564 0x2112233445566
m_lower: 0x2112233445564
m_upper: 0x2112233445566
diff 0x2
calls_to_oracle 21007
current len(bounds) 1
step_2
a (lower) 0x2112233445564
b (upper) 0x2112233445566
s_start 0x8c98e8d01400
r_start 0x483b6c00e
Oracle called with s 0x11931d1a06648
Valid padding.
new m_lower/m_upper 0x2112233445565 0x2112233445566
m_lower: 0x2112233445565
m_upper: 0x2112233445566
diff 0x1
Found! upper: 0x2112233445566
calls_to_oracle 21011

Only ~21k oracle calls for a message with ~48 unknown bits (the 64-bit block minus its 16-bit 00 02 header).

Raising the bar: a less 'relaxed' oracle

But my oracle is 'relaxed'. OpenSSL actually checks more: it also looks for the zero byte that follows the random padding. That zero byte must sit between the random padding and the message, and the (random) padding string must be at least 8 bytes long.
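
For reference, this is the PKCS#1 v1.5 encryption block layout those checks are about: 00 02, then a padding string PS of at least 8 random nonzero bytes, then a 00 separator, then the message. A minimal sketch of building such a block; the function name pkcs1_v15_pad is hypothetical, and I use the secrets module because real padding needs a cryptographic RNG:

```python
import secrets

def pkcs1_v15_pad(msg: bytes, k: int) -> bytes:
    """Build a k-byte PKCS#1 v1.5 encryption block: 00 02 || PS || 00 || msg."""
    ps_len = k - 3 - len(msg)
    if ps_len < 8:
        raise ValueError("message too long: padding string must be at least 8 bytes")
    # PS consists of random *nonzero* bytes, so the first 00 after the header ends PS
    ps = bytes(secrets.randbelow(255) + 1 for _ in range(ps_len))
    return b"\x00\x02" + ps + b"\x00" + msg

block = pkcs1_v15_pad(b"\x11\x22\x33", 32)   # 32 bytes = 256-bit toy modulus
print(block.hex())
```

Because PS never contains 00, the first zero byte after the header is exactly where the strict oracle expects the separator.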

Let's make our toy oracle stricter:


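import functools
import gmpy2

# n, d and BITS are the globals from the full source (256-bit toy RSA key).

def search_first_zero_byte(buf):
    # Index of the first 0x00 byte in buf, or -1 if there is none
    # (a minimal assumed version; the real helper is in the full source)
    return buf.find(b"\x00")

# Header checked, plus a zero byte required somewhere after >=8 bytes of random padding
@functools.cache
def oracle_real(c):
    assert isinstance(c, gmpy2.mpz)
    plaintext=pow(c, d, n)
    plaintext_b=int(plaintext).to_bytes(length=BITS//8, byteorder="big", signed=False)
    if plaintext_b[0:2]!=b"\x00\x02":
        return False
    idx=search_first_zero_byte(plaintext_b[2:])
    # The padding string must be at least 8 bytes long, and the first zero byte
    # must not be the last byte (i.e. the message must not be empty)
    return idx>=8 and idx!=len(plaintext_b[2:])-1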
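
The padding rules in oracle_real() can be tested on their own, without RSA. Below is a bytes-only sketch of the same rules (is_conforming is a hypothetical name of mine), tried on the recovered 32-byte plaintext from the hexdump in the log further down and on a few blocks that must be rejected:

```python
def is_conforming(block: bytes) -> bool:
    """The checks of oracle_real(), applied to an already-decrypted block."""
    if block[:2] != b"\x00\x02":
        return False
    rest = block[2:]
    idx = rest.find(b"\x00")  # first zero byte after the header, -1 if none
    # At least 8 bytes of padding before it, and it must not be the last byte
    return idx >= 8 and idx != len(rest) - 1

# The recovered 256-bit plaintext from the log: two identical 16-byte rows
row = bytes.fromhex("00 02 12 34 12 34 56 78 AB CD FF FF 00 11 22 33")
good = row * 2

bad_header = b"\x00\x01" + good[2:]                            # wrong block type
short_ps = b"\x00\x02" + b"\xff" * 4 + b"\x00" + b"\x11" * 25  # only 4 padding bytes
no_zero = b"\x00\x02" + b"\xff" * 30                           # no separator at all
empty_msg = b"\x00\x02" + b"\xff" * 29 + b"\x00"               # separator is the last byte

for name, blk in [("good", good), ("bad_header", bad_header), ("short_ps", short_ps),
                  ("no_zero", no_zero), ("empty_msg", empty_msg)]:
    print(name, len(blk), is_conforming(blk))
```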

For the full source code, click here. It already uses 256-bit RSA.

Here I use the GMP bignum library via its Python bindings (gmpy2). Not that Python's own bignum support is bad, but it is too slow compared to GMP, which is written in C, and most of the work here is bignum arithmetic.
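
Not part of the original source: a small pattern for trying GMP yourself. It uses gmpy2.powmod() when gmpy2 is installed and falls back to the built-in three-argument pow() otherwise, then times both on random 1024-bit numbers. The fallback and the timing harness are my assumptions; the speedup you see (if any) depends on your machine and gmpy2 version.

```python
import random
import timeit

try:
    import gmpy2

    def powmod(base, exp, mod):
        # GMP-backed modular exponentiation, converted back to a plain Python int
        return int(gmpy2.powmod(base, exp, mod))
except ImportError:
    gmpy2 = None

    def powmod(base, exp, mod):
        # Fallback: Python's built-in bignum pow()
        return pow(base, exp, mod)

# Random 1024-bit odd modulus and operands (just numbers for timing, not an RSA key)
rng = random.Random(1)
mod = rng.getrandbits(1024) | (1 << 1023) | 1
base = rng.randrange(mod)
exp = rng.getrandbits(1024)

assert powmod(base, exp, mod) == pow(base, exp, mod)
t_builtin = timeit.timeit(lambda: pow(base, exp, mod), number=50)
t_powmod = timeit.timeit(lambda: powmod(base, exp, mod), number=50)
print("gmpy2 available:", gmpy2 is not None)
print(f"built-in pow: {t_builtin:.3f}s, powmod: {t_powmod:.3f}s (50 runs each)")
```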

Also, as a side effect, the 'gmpy2' prefixes you see in my source code are good for exploration: they clearly show the main data path, where the most important work happens.

Now everything is slower, because our 'oracle' is 'tougher': besides the 00 02 header, the decrypted block must also contain the zero byte after at least 8 bytes of padding, so far fewer values of s pass the check.
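
A rough way to see how much 'tougher' the strict oracle is: assume that, once the 00 02 header matches, the remaining 30 bytes of the 256-bit block look uniformly random (a simplifying assumption). Then the extra rules pass with probability q^8*(1-q^21), q = 255/256: the first 8 bytes must be nonzero and the first zero must land at index 8..28 of those 30 bytes. The sketch computes this and cross-checks it by simulation; the larger 256-bit modulus also plays a role in the total number of calls, which this does not model.

```python
import random

K = 32                 # block size in bytes (256-bit toy modulus)
REST = K - 2           # bytes after the 00 02 header
q = 255 / 256          # probability that a random byte is nonzero

# First zero must be at index 8..REST-2 of the rest: indices 0..7 nonzero, then
# sum_{i=8}^{REST-2} q^i * (1/256) = q^8 * (1 - q^(REST-9))
exact = q ** 8 * (1 - q ** (REST - 9))

def strict_extra_check(rest: bytes) -> bool:
    # The part of oracle_real() that goes beyond the header test
    idx = rest.find(b"\x00")
    return idx >= 8 and idx != len(rest) - 1

rng = random.Random(2023)
trials = 200_000
hits = sum(strict_extra_check(rng.randbytes(REST)) for _ in range(trials))
estimate = hits / trials

print(f"exact {exact:.4f}, simulated {estimate:.4f}, ~{1 / exact:.0f}x fewer s values accepted")
```

Under this assumption only about one in thirteen header-matching s values survives the strict check.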

...

step_2
a (lower) 0x2123412345...ff0011222d log2=241.1
b (upper) 0x2123412345...ff00112233 log2=241.1
s_start 0x1a089fc6e1...ff9222ae58 log2=236.7
r_start 0x9d60b261b8...f641984834 log2=223.3
Oracle called with s 0x34113f8dc3...ff2448aba4 log2=237.7
Valid padding.
new m_lower/m_upper 0x2123412345...ff00112232 log2=241.1 0x2123412345...ff00112233 log2=241.1
m_lower:
00000000: 00 02 12 34 12 34 56 78  AB CD FF FF 00 11 22 33  ...4.4Vx......"3
00000010: 00 02 12 34 12 34 56 78  AB CD FF FF 00 11 22 32  ...4.4Vx......"2
m_upper:
00000000: 00 02 12 34 12 34 56 78  AB CD FF FF 00 11 22 33  ...4.4Vx......"3
00000010: 00 02 12 34 12 34 56 78  AB CD FF FF 00 11 22 33  ...4.4Vx......"3
diff 1
Found! upper:
00000000: 00 02 12 34 12 34 56 78  AB CD FF FF 00 11 22 33  ...4.4Vx......"3
00000010: 00 02 12 34 12 34 56 78  AB CD FF FF 00 11 22 33  ...4.4Vx......"3
calls_to_oracle 1647198

About 1.6 million oracle calls: not that fast, but we got our 256-bit message decrypted.


Click for the next part.

(This post was first published on 20230123.)

