import random, math, sys, os
import functools

# pip3 install gmpy2
import gmpy2

# pip3 install hexdump
import hexdump

# Toy RSA key with a 64-bit modulus. For reference, its factors are
# p, q = 3900617629, 2378309513.
BITS = 64
n = gmpy2.mpz(9_276_876_013_626_204_677)
e = gmpy2.mpz(17)
# Ciphertext under attack; per the original note its plaintext bytes are
# 00 48 65 6C 6C 6F 6F 6F  (".Hellooo").
good_c = gmpy2.mpz(3_453_425_720_950_098_108)

def binlog(x):
    """Return the base-2 logarithm of int(x) as a float.

    x must be strictly positive (checked with an assert).
    """
    assert x > 0
    as_int = int(x)
    return math.log(as_int, 2)

def short_hex(x):
    """Format an integer compactly for progress/log output.

    Values below 10 are returned in decimal. Values below 2**64 are returned
    as their full hex string. Wider values are abbreviated to the first 12 and
    last 10 characters of their hex string, followed by their base-2 log.
    """
    if x<10:
        return str(x)
    # Exact integer comparison. The previous float test `binlog(x)<64` rounded
    # values just below 2**64 up to 64.0 and abbreviated them needlessly.
    if x < (1 << 64):
        return hex(x)
    s=hex(x)
    return s[0:12]+"..."+s[-10:]+(" log2=%.1f" % binlog(x))

# Prior range assumed for the plaintext m: bytes 00 01 00 .. 00 through
# 00 FF FF .. FF. The upper bound is also the largest residue oracle()
# accepts (first byte zero). Note that oracle() also accepts residues below
# m_lower_bound, so m_lower_bound is a bound on m, not on accepted residues.
m_lower_bound = gmpy2.mpz(0x0001_0000_0000_0000)
m_upper_bound = gmpy2.mpz(0x00FF_FFFF_FFFF_FFFF)

# Aliases for the bounds, plus n divided by each bound.
m_min, m_max = m_lower_bound, m_upper_bound
s1_min, s1_max = n // m_min, n // m_max

calls_to_oracle = 0  # running count of oracle() queries

def n_to_bytes(n, BITS):
    """Return n as an unsigned big-endian byte string of BITS//8 bytes.

    Raises OverflowError (from int.to_bytes) if n is negative or needs more
    than BITS//8 bytes.
    """
    width = BITS // 8
    value = int(n)
    return value.to_bytes(width, "big", signed=False)

def oracle(c):
    """Padding oracle for the toy key.

    Decrypts c modulo n with the hard-coded exponent. Returns True iff the
    first byte of the result, written as BITS//8 big-endian bytes, is zero.
    c must be a gmpy2.mpz (asserted).
    """
    assert isinstance(c, gmpy2.mpz)
    decryption_exp = gmpy2.mpz(4911287298007382225)
    recovered = int(pow(c, decryption_exp, n))
    leading = recovered.to_bytes(length=BITS//8, byteorder="big", signed=False)[0]
    return leading == 0

def intervals_intersection(i1_s, i1_e, i2_s, i2_e):
    """Intersect the closed intervals [i1_s, i1_e] and [i2_s, i2_e].

    Returns (max of the starts, min of the ends), or None when one interval
    begins after the other ends. Touching intervals yield a single point.
    Approach from:
    https://scicomp.stackexchange.com/questions/26258/the-easiest-way-to-find-intersection-of-two-intervals
    """
    disjoint = i2_s > i1_e or i1_s > i2_e
    if disjoint:
        return None
    start = max(i1_s, i2_s)
    end = min(i1_e, i2_e)
    return start, end

def step_1():
    """Find the first multiplier s whose blinded ciphertext passes the oracle.

    Tries s = n // m_upper_bound, s+1, s+2, ... and queries
    oracle(good_c * s^e mod n), incrementing the global calls_to_oracle for
    every query. At the first accepted s, each wrap count r with
    m*s = r*n + x (x being a residue oracle() accepts) yields the interval
    [(x_lo + r*n)//s, (x_hi + r*n)//s] for m. That interval is clamped to the
    prior range [m_lower_bound, m_upper_bound].

    Returns:
        A list of (lower, upper, s) tuples, all gmpy2.mpz, for that s.

    Raises:
        RuntimeError: if no s below n yields a candidate interval.
    """
    t=[]
    print ("step_1() begin")
    global calls_to_oracle
    # oracle() accepts exactly the residues whose first byte is zero, i.e.
    # [0, m_upper_bound]. m_lower_bound only bounds m itself, so using it
    # here could exclude the interval that actually holds m.
    x_lo=gmpy2.mpz(0)
    x_hi=m_upper_bound
    s=n // m_upper_bound
    # s == n would blind the ciphertext to 0. That is always accepted and
    # carries no information about m.
    while s < n:
        print("step_1. trying s", short_hex(s),"  \r", end="")
        rt=oracle(good_c*pow(s, e, n) % n)
        calls_to_oracle=calls_to_oracle+1
        if rt:
            print ("rt=True")
            print ("s", short_hex(s))
            # Every r for which some m in [m_min, m_max] gives m*s - r*n in
            # [x_lo, x_hi]. Floor division makes this a superset. Extra r
            # values are discarded by the intersection below.
            overlaps_lower=max(0, (m_min*s - x_hi) // n)
            overlaps_higher=(m_max*s - x_lo) // n
            print ("overlaps in range (inclusive): ", overlaps_lower, overlaps_higher)
            for overlaps in range(overlaps_lower, overlaps_higher+1):
                lower=(x_lo+n*overlaps)//s
                upper=(x_hi+n*overlaps)//s
                # Clamp instead of requiring full containment. An interval that
                # only partly overlaps the prior range may still hold m.
                clamped=intervals_intersection(m_lower_bound, m_upper_bound, lower, upper)
                if clamped is None:
                    continue
                lower, upper=clamped
                print ("m is between:", short_hex(lower), short_hex(upper))
                print ("lower:")
                hexdump.hexdump(n_to_bytes(lower, BITS))
                print ("upper:")
                hexdump.hexdump(n_to_bytes(upper, BITS))
                if upper>lower:  # binlog() asserts a positive argument
                    print ("binlog(diff) %.1f" % binlog(upper-lower))
                t.append((lower, upper, s))
            if t:
                return t
            print ("step_1: no candidate interval for this s, continuing")
        s=s+1
    raise RuntimeError("step_1: no usable s found below n")

def _narrow_interval(a, b, s, modulus, x_lo, x_hi):
    """Return the tightest (lower, upper) inside [a, b] containing every m with
    x_lo <= (m*s) % modulus <= x_hi, or None if no such m exists.

    Write m*s = r*modulus + x. Wrap count r gives the candidate interval
    [(x_lo + r*modulus)//s, (x_hi + r*modulus)//s]. That interval meets
    [a, b] exactly when r_first <= r <= r_last, and the hull of those
    intersections is returned. The floor on the lower end only widens the
    result, so the true m is never excluded.
    """
    # Ceiling divisions written as -((-p) // q).
    r_first = -((x_hi - a*s) // modulus)
    r_last = -((x_lo - (b + 1)*s) // modulus) - 1
    if r_first > r_last:
        return None
    lower = max(a, (x_lo + r_first*modulus) // s)
    upper = min(b, (x_hi + r_last*modulus) // s)
    return lower, upper


def step_2(a, b, s_start):
    """Search for a new multiplier s that strictly narrows the interval [a, b] for m.

    For r from r_start to r_start + 0x7FF, every s in
    [(m_lower_bound + r*n)//b, (m_upper_bound + r*n)//a] (skipping s == 0)
    is tried against oracle(). Each query increments the global
    calls_to_oracle. On the first accepted s whose narrowing strictly shrinks
    [a, b], returns (s, m_lower, m_upper), all gmpy2.mpz. Returns None if no
    such s is found in that r range.
    """
    assert isinstance(a, gmpy2.mpz)
    assert isinstance(b, gmpy2.mpz)
    assert isinstance(s_start, gmpy2.mpz)
    print ("step_2")
    print ("a (lower)", short_hex(a))
    print ("b (upper)", short_hex(b))
    print ("s_start", short_hex(s_start))
    global calls_to_oracle
    r_start=2*((b*s_start - m_lower_bound) // n)
    print ("r_start", short_hex(r_start))
    for r in range(r_start, r_start+0x800): # tune this?
        print ("current r=", short_hex(r),"\r", end="")
        # pick s
        s_lower=(m_lower_bound + gmpy2.mpz(r)*n)//b
        s_upper=(m_upper_bound + gmpy2.mpz(r)*n)//a
        print ("s between", hex(s_lower), hex(s_upper))
        # s == 0 blinds the ciphertext to 0. That is always accepted, carries
        # no information, and would divide by zero when narrowing.
        for s in range(max(s_lower, 1), s_upper+1):
            s=gmpy2.mpz(s)
            rt=oracle((good_c*pow(s, e, n)) % n)
            calls_to_oracle=calls_to_oracle+1
            if not rt:
                continue
            print ("Oracle called with s", short_hex(s))
            print ("Valid padding.")
            # Narrow using every wrap count consistent with [a, b], not just
            # the loop's guessed r. oracle() accepts residues in
            # [0, m_upper_bound].
            tmp=_narrow_interval(a, b, s, n, gmpy2.mpz(0), m_upper_bound)
            if tmp is None:
                continue
            m_lower, m_upper=tmp
            assert m_upper >= m_lower
            if (m_lower, m_upper)==(a, b):
                # An accepted s that does not shrink [a, b] (e.g. s == s_start
                # or s == 1) is no progress. Returning it would make the
                # caller repeat the same search forever.
                continue
            print ("new m_lower/m_upper", short_hex(m_lower), short_hex(m_upper))
            print ("m_lower:")
            hexdump.hexdump(n_to_bytes(m_lower, BITS))
            print ("m_upper:")
            hexdump.hexdump(n_to_bytes(m_upper, BITS))
            print ("diff", short_hex(m_upper-m_lower))
            return s, m_lower, m_upper

    print ("Loop finished!")
    return None

# Driver: find candidate intervals for m with step_1(). Then repeatedly narrow
# each one with step_2() until it collapses to at most two values or step_2()
# gives up on it (returns None).
bounds=step_1()
while bounds:
    new_bounds=[]
    for bound in bounds:
        lower, upper, s=bound
        tmp = step_2(lower, upper, s)
        if tmp is not None:
            s, lower, upper = tmp
            if upper-lower<=1:
                # At most two candidates remain, and the interval ends are
                # integer-rounded bounds, so m may be either end. Confirm by
                # re-encrypting rather than assuming it is `upper`.
                for candidate in (lower, upper):
                    if pow(candidate, e, n) == good_c:
                        print ("Found! m:")
                        hexdump.hexdump(n_to_bytes(candidate, BITS))
                        break
                else:
                    print ("Interval collapsed, but neither end encrypts to good_c:",
                           short_hex(lower), short_hex(upper))
                print ("calls_to_oracle", calls_to_oracle)
            else:
                new_bounds.append((lower, upper, s))
        print ("calls_to_oracle", calls_to_oracle)
        print ("current len(bounds)", len(bounds))
    bounds=new_bounds

