import random, math, sys, os
import functools

# pip3 install gmpy2
import gmpy2

# pip3 install hexdump
import hexdump

# Size in bits of the two-byte header (00 02) that oracle_relaxed and
# oracle_real require at the front of a decrypted block.
HEADER_BITS=16

# Modulus size in bits; plaintexts are handled as BITS//8-byte big-endian
# byte strings.
BITS=256
# NOTE(review): the string below is a no-op expression kept as notes. It
# presumably records the primes p, q behind n and the plaintext m whose
# ciphertext is good_c (m starts with the 00 02 header). Not verified here.
"""
p=0x0dd29dc4b82f5edd65d0831c00fe9a77f
q=0x0cb0b6a4c94a7bf3e2bc6fb0ebbd2b1a9
m=0x0002123412345678abcdffff001122330002123412345678abcdffff00112233
"""
# RSA public exponent.
e=gmpy2.mpz(65537)
# RSA modulus (presumably p*q from the notes above).
n=gmpy2.mpz(79342133121696534918807299511375466604890342423915266425517370442837632246231)
# Target ciphertext. The attack multiplies it by s^e mod n for each oracle
# query and narrows down the plaintext behind it.
good_c=gmpy2.mpz(1195356876873443205698340669211873604233638791829911421871896423265195311312)
# Private exponent; the oracles use it to decrypt queries.
d=gmpy2.mpz(50964581503945237506706606764581427554546768075437342198400081727009258182065)

def binlog(x):
    """Return log base 2 of x as a float.

    x is passed through int() first, so gmpy2.mpz values of any size are
    accepted. Asserts that x is positive.
    """
    assert x > 0
    as_int = int(x)
    return math.log(as_int, 2)

def short_hex(x):
    """Format x compactly for progress output.

    Values below 10 are printed in decimal and values below 2**64 in full hex.
    Anything larger shows the first 12 and last 10 characters of its hex
    form, followed by its log2 to one decimal place.
    """
    if x < 10:
        return str(x)
    small = binlog(x) < 64
    text = hex(x)
    if small:
        return text
    return "%s...%s log2=%.1f" % (text[:12], text[-10:], binlog(x))

# B is the weight of the byte just below the two-byte header, so a
# plaintext starting with 00 02 lies in [2B, 3B-1].
B = 2 ** (BITS - HEADER_BITS)
_2B = gmpy2.mpz(2 * B)
_3B_1 = gmpy2.mpz(3 * B) - 1

# Bounds on any conforming plaintext, and n divided by each of them.
m_min, m_max = _2B, _3B_1
s1_min = n // m_min  # NOTE: numerically the larger of the two quotients
s1_max = n // m_max

# Running count of oracle queries; step_1/step_2 increment it.
calls_to_oracle = 0

def n_to_bytes(n, BITS):
    """Encode integer n as an unsigned big-endian byte string of BITS//8 bytes.

    Raises OverflowError when n is negative or does not fit in that width.
    """
    width = BITS // 8
    return int(n).to_bytes(width, "big", signed=False)

# Relaxed oracle: only the two-byte header is checked.
@functools.cache
def oracle_relaxed(c):
    """Return True iff c decrypts to a block that starts with bytes 00 02.

    c (a gmpy2.mpz) is decrypted with the module-level d and n, then encoded
    as a BITS//8-byte big-endian string. Results are memoized per c.
    """
    assert isinstance(c, gmpy2.mpz)
    decrypted = int(pow(c, d, n)).to_bytes(length=BITS//8, byteorder="big", signed=False)
    return decrypted[:2] == b"\x00\x02"

def search_first_zero_byte(b):
    """Return the index of the first zero element of b, or -1 if none exists.

    This is a plain search. The padding rule that the separator must follow
    at least 8 padding bytes is applied by oracle_real, not here.
    """
    return next((pos for pos, value in enumerate(b) if value == 0), -1)

# Strict oracle: header plus a properly placed zero separator.
@functools.cache
def oracle_real(c):
    """Return True iff c decrypts to a block shaped like 00 02 PS 00 M.

    c (a gmpy2.mpz) is decrypted with the module-level d and n and encoded as
    a BITS//8-byte big-endian string. After the 00 02 header, the first zero
    byte must sit at offset >= 8 (at least 8 padding bytes) and must not be
    the last byte. Results are memoized per c.
    """
    assert isinstance(c, gmpy2.mpz)
    plaintext_b = int(pow(c, d, n)).to_bytes(length=BITS//8, byteorder="big", signed=False)
    if plaintext_b[0:2] != b"\x00\x02":
        return False
    body = plaintext_b[2:]
    separator = search_first_zero_byte(body)
    # -1 (no separator) also fails the >= 8 test.
    return separator >= 8 and separator != len(body) - 1

# Oracle used by the attack; oracle_relaxed (header-only check) is a drop-in
# alternative for experiments.
oracle = oracle_real

def intervals_intersection(i1_s, i1_e, i2_s, i2_e):
    """Intersect the closed intervals [i1_s, i1_e] and [i2_s, i2_e].

    Returns (start, end) of the overlap, or None when the intervals are
    disjoint. Intervals that share only an endpoint still intersect.
    """
    # https://scicomp.stackexchange.com/questions/26258/the-easiest-way-to-find-intersection-of-two-intervals
    return None if (i2_s > i1_e or i1_s > i2_e) else (max(i1_s, i2_s), min(i1_e, i2_e))

def step_1():
    """Find the first conforming multiplier s and the ranges it implies for m.

    Starting at n // (3B-1), query the oracle with good_c * s^e mod n for
    increasing s until it reports valid padding. For that s, every r with
    2B <= m*s - r*n <= 3B-1 (m in [2B, 3B-1]) gives the candidate range
    [ceil((2B + r*n)/s), floor((3B-1 + r*n)/s)]. Each candidate is clipped
    to [2B, 3B-1], and the non-empty ones are returned as a list of
    (lower, upper, s) tuples of gmpy2.mpz.

    Raises RuntimeError if s runs past n without a conforming value.
    """
    global calls_to_oracle
    t = []
    print("step_1() begin")
    s = n // _3B_1
    while True:
        print("step_1. trying s", short_hex(s), "  \r", end="")
        rt = oracle(good_c * pow(s, e, n) % n)
        calls_to_oracle = calls_to_oracle + 1
        if rt:
            print("rt=True")
            print("s", short_hex(s))
            # r is bounded by ceil((2B*s - (3B-1))/n) .. floor(((3B-1)*s - 2B)/n).
            r_min = -((_3B_1 - _2B * s) // n)
            r_max = (_3B_1 * s - _2B) // n
            print("overlaps in range (inclusive): ", r_min, r_max)
            for overlaps in range(r_min, r_max + 1):
                # Clip to [2B, 3B-1] instead of discarding ranges that straddle
                # a boundary (they may hold m). Ceil the lower end: m is integral.
                lower = max(_2B, -((-(_2B + n * overlaps)) // s))
                upper = min(_3B_1, (_3B_1 + n * overlaps) // s)
                if lower > upper:
                    continue
                print("m is between:", short_hex(lower), short_hex(upper))
                print("lower:")
                hexdump.hexdump(n_to_bytes(lower, BITS))
                print("upper:")
                hexdump.hexdump(n_to_bytes(upper, BITS))
                print("binlog(diff) %.1f" % binlog(upper - lower) if upper > lower else "binlog(diff) n/a (single value)")
                t.append((lower, upper, s))
            return t
        s = s + 1
        if s > n:
            # Replaces `assert False`, which disappears under `python -O`.
            raise RuntimeError("step_1: no conforming s found up to n")

def _narrow_interval(a, b, s):
    """Shrink [a, b] given that m*s mod n lies in [2B, 3B-1] for the true m.

    Tries every r with ceil((a*s - (3B-1))/n) <= r <= floor((b*s - 2B)/n).
    Each gives the candidate [ceil((2B + r*n)/s), floor((3B-1 + r*n)/s)],
    clipped to [a, b]. Returns (lower, upper), the hull of the non-empty
    candidates, or None when every candidate is empty.
    """
    r_min = -((_3B_1 - a * s) // n)
    r_max = (b * s - _2B) // n
    lower = upper = None
    for r in range(r_min, r_max + 1):
        cand_lower = max(a, -((-(_2B + r * n)) // s))
        cand_upper = min(b, (_3B_1 + r * n) // s)
        if cand_lower > cand_upper:
            continue
        lower = cand_lower if lower is None else min(lower, cand_lower)
        upper = cand_upper if upper is None else max(upper, cand_upper)
    if lower is None:
        return None
    return lower, upper


def step_2(a, b, s_start):
    """Search for the next conforming s and narrow [a, b] with it.

    For r starting at 2*((b*s_start - 2B) // n) (0x800 values), try every s
    in [(2B + r*n)//b, (3B-1 + r*n)//a]. On the first s the oracle accepts,
    shrink [a, b] via _narrow_interval. That helper considers every r
    consistent with this s, not just the search loop's r, so a mismatched r
    can no longer empty the intersection and drop the bound.

    Returns (s, m_lower, m_upper) as gmpy2.mpz, or None if no conforming s
    was found in the searched range (or no candidate range survived).
    """
    global calls_to_oracle
    assert isinstance(a, gmpy2.mpz)
    assert isinstance(b, gmpy2.mpz)
    assert isinstance(s_start, gmpy2.mpz)
    print("step_2")
    print("a (lower)", short_hex(a))
    print("b (upper)", short_hex(b))
    print("s_start", short_hex(s_start))
    r_start = 2 * ((b * s_start - _2B) // n)
    print("r_start", short_hex(r_start))
    for r in range(r_start, r_start + 0x800):  # tune this?
        print("current r=", short_hex(r), "\r", end="")
        s_lower = (_2B + gmpy2.mpz(r) * n) // b
        s_upper = (_3B_1 + gmpy2.mpz(r) * n) // a
        for s in range(s_lower, s_upper + 1):
            s = gmpy2.mpz(s)
            rt = oracle((good_c * pow(s, e, n)) % n)
            calls_to_oracle = calls_to_oracle + 1
            if not rt:
                continue
            print("Oracle called with s", short_hex(s))
            print("Valid padding.")
            tmp = _narrow_interval(a, b, s)
            if tmp is None:
                return None
            m_lower, m_upper = tmp
            print("new m_lower/m_upper", short_hex(m_lower), short_hex(m_upper))
            print("m_lower:")
            hexdump.hexdump(n_to_bytes(m_lower, BITS))
            print("m_upper:")
            hexdump.hexdump(n_to_bytes(m_upper, BITS))
            assert m_upper >= m_lower
            if m_lower != m_upper:
                print("diff", short_hex(m_upper - m_lower))
            return s, m_lower, m_upper

    print("Loop finished!")
    return None

# Driver: refine every candidate interval until one collapses to at most two
# values, then confirm the plaintext by re-encrypting it.
bounds = step_1()
while bounds:  # an empty list would otherwise spin forever with no output
    new_bounds = []
    for lower, upper, s in bounds:
        tmp = step_2(lower, upper, s)
        if tmp is not None:
            s, lower, upper = tmp
            if upper - lower <= 1:
                # With a width of 1 both ends are candidates; only the true
                # plaintext re-encrypts to good_c.
                for candidate in (lower, upper):
                    if pow(candidate, e, n) == good_c:
                        print("Found! m:")
                        hexdump.hexdump(n_to_bytes(candidate, BITS))
                        print("calls_to_oracle", calls_to_oracle)
                        sys.exit(0)
                print("Interval collapsed but no candidate re-encrypts to good_c; dropping it.")
            else:
                new_bounds.append((lower, upper, s))
        print("calls_to_oracle", calls_to_oracle)
        print("current len(bounds)", len(bounds))
    bounds = new_bounds

print("Main loop finished without any result. Error.")
sys.exit(1)

