Inshall'hack
Security if God wills it

RScA Writeup (ECSC French Pre-qualifier 2019)

RScA was one of the four misc challenges of the pre-qualifier, and one of the least solved challenges of the competition. Cryptography-wise, it was far from the hardest challenge of the CTF. However, it involved extracting traffic from two sigrok captures, which is quite uncommon in a challenge, I would say. The last time I encountered such a challenge was (iirc) in the SSTIC challenge 2017. Back then, I had just started playing CTFs and getting involved in information security altogether, and I couldn't figure out what to do. I guess the satisfaction I got from solving this challenge today is payback for the frustration I felt two years ago. Without further ado, let's take a look at the challenge!

Challenge description

Note: this is my own translation of the original challenge description.

Following an intervention, our teams retrieved a secure phone belonging to an agent of a criminal organization. Ever since we retrieved it, we have made sure to keep the device on at all times.

This phone is notably used to receive instructions signed by a superior officer of the criminal organization. The device is tasked with verifying the signature of the messages. Through some reverse-engineering, we managed to uncover that the algorithm used for the verification is RSA.

What is very peculiar here is that we have absolutely no knowledge about the public parameters used by the algorithm. Moreover, the verification is implemented in a secure way designed to prevent the retrieval of these parameters.

Our ultimate goal is to be able to impersonate the superior officer by signing messages in his stead in order to set a trap for the other members of the gang. In order to help us do that, you are tasked with finding all the parameters used.

Our reverse-engineering step gave us many important insights.

First, we managed to find the RSA implementation. Here is the pseudocode we reconstructed, where phi(N) denotes the result of applying Euler's totient function to N:

Function RSA(m,e,phi(N),N):
    r  <- random(0,2**32)
    e' <- e + r * phi(N)
    accumulator = 1
    dummy = 1
    for i from len(e')-1 to 0:
        accumulator <- (accumulator * accumulator) mod N
        tmp <- (accumulator * m) mod N
        if (i-th lsb of e') == 1:
            accumulator <- tmp
        else:
            dummy <- tmp
    return accumulator

Second, we found out that the two modular multiplication operations

accumulator <- (accumulator * accumulator) mod N
tmp <- (accumulator * m) mod N

are computed by a hardware accelerator.

When the device boots, the block responsible for checking the signature recovers the parameters e and N from the SIM card. N is then sent to the accelerator and stored in non-readable SRAM. Rebooting the phone implies locking the SIM card, which we are not able to unlock (its owner not being very useful in communicating the required PIN code).

On the plus side, we managed to recover a piece of the documentation of this accelerator, from which we give you an excerpt. We also managed to sniff the communication bus between this accelerator and the signature verification block.

Ever since we recovered the device, we have received two different messages. The data sniffed from the bus during both signature verifications is provided to you. The content of the messages themselves is irrelevant.

Your objective is to recover the parameters (N, p, q, e, d) of the RSA, with:

  • N: the "public" modulus;
  • p, q: the prime factors of N (p * q = N, p < q);
  • e: the "public" exponent stored on the device (0 < e < phi(N));
  • d: the "private" exponent used to sign the messages (0 < d < phi(N)).

The flag is ECSC{N + p + q + e + d} with N + p + q + e + d written in hex.
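The RSA pseudocode from the description translates almost line for line into Python, which makes it easy to experiment with offline. The sketch below is my own translation. The function name `rsa_blinded`, the assumption that `random(0, 2**32)` excludes its upper bound, and the toy key are mine, not the device's. The sketch also shows why the exponent blinding is transparent to the caller. Since \(e' \equiv e \pmod{\phi(N)}\), Euler's theorem gives \(m^{e'} \equiv m^e \pmod N\) whenever \(\gcd(m, N) = 1\).

```python
import random


def rsa_blinded(m, e, phi, N, r=None):
    """Square-and-multiply with exponent blinding, after the challenge pseudocode."""
    if r is None:
        # Assumption: random(0, 2**32) excludes its upper bound.
        r = random.randrange(0, 2**32)
    e_prime = e + r * phi
    accumulator = 1
    dummy = 1
    # Scan e' from its most significant bit down to bit 0.
    for i in range(e_prime.bit_length() - 1, -1, -1):
        accumulator = (accumulator * accumulator) % N  # squaring (accelerator)
        tmp = (accumulator * m) % N                    # multiplication (accelerator)
        if (e_prime >> i) & 1:
            accumulator = tmp
        else:
            dummy = tmp  # mirrors the pseudocode's dummy write
    return accumulator


# Toy key built from two Mersenne primes (illustration only, far too small).
p, q = 2**31 - 1, 2**61 - 1
N, phi = p * q, (p - 1) * (q - 1)
e = 65537
m = 123456789

# Whatever the blinding factor r, the result is m^e mod N.
for r in (0, 1, 2**32 - 1):
    assert rsa_blinded(m, e, phi, N, r) == pow(m, e, N)
```

Because of the blinding, the number of loop iterations (and thus of accelerator calls) reveals the bit length of \(e'\), not of \(e\). This becomes relevant later.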

Along with this description came a zip file containing the description of the protocol as well as two files containing the intercepted traffic for the verification of the signature of the two received messages.

For convenience's sake, I will also write a translation of the description of the protocol here:

Communication protocol

The frames sent from the RSA block to the modular accelerator use the following format:

| senderId | receiverId | opcode | operand1(opt) | operand2(opt) |

The frames sent from the modular accelerator to the RSA block use the following format:

| senderId | receiverId | operand1 |

with:

  • senderId: 1 byte representing the id of the sender block;
  • receiverId: 1 byte representing the id of the receiver block;
  • opcode (optional): 1 byte encoding the operation (see below);
  • operand1: a fixed amount of bytes representing the first operand;
  • operand2: a fixed amount of bytes representing the second operand.

Operations

Loading of the modulus N

Requests the loading of the modulus N in the hardware accelerator.

  • opcode: 0x11
  • operand1: 2 bytes representing the size t of N in bytes
  • operand2: t bytes representing N

Modular addition

Requests the modular addition of two operands of size t. t is inferred by the accelerator based on its knowledge of N.

  • opcode: 0x22
  • operand1: t bytes representing the first operand
  • operand2: t bytes representing the second operand

Modular subtraction

Requests the modular subtraction of two operands of size t. t is inferred by the accelerator based on its knowledge of N.

  • opcode: 0x33
  • operand1: t bytes representing the first operand
  • operand2: t bytes representing the second operand

Modular inversion

Requests the modular inverse of one operand of size t. Returns 0 if the element can't be inverted. t is inferred by the accelerator based on its knowledge of N.

  • opcode: 0x44
  • operand1: t bytes representing the value to invert
  • operand2: absent

Modular multiplication

Requests the modular multiplication of two operands of size t. t is inferred by the accelerator based on its knowledge of N.

  • opcode: 0x55
  • operand1: t bytes representing the first operand
  • operand2: t bytes representing the second operand

Modular squaring

Requests the modular square of one operand of size t. t is inferred by the accelerator based on its knowledge of N.

  • opcode: 0x66
  • operand1: t bytes representing the value to square
  • operand2: absent

Response request

Requests a response following an operation request.

  • opcode: 0x77
  • operand1: absent
  • operand2: absent

The response of size t follows the frame format described at the beginning of the document.
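To make the frame layout concrete, here is a small Python sketch that serializes frames according to the specification above. The helper names, the example ids and the big-endian operand encoding are my assumptions. Big-endian is what the decoding code later in this writeup assumes, since it reads the first byte as the most significant one. The specification itself only fixes the field order and sizes.

```python
# Opcodes from the accelerator documentation.
LOAD_N, ADD_MOD, SUB_MOD, INV_MOD, MUL_MOD, SQU_MOD, GET_RESULT = (
    0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77)


def request_frame(sender, receiver, opcode, *operands, t):
    """| senderId | receiverId | opcode | operand1(opt) | operand2(opt) |"""
    frame = bytes([sender, receiver, opcode])
    if opcode == LOAD_N:
        # operand1: 2-byte size of N, operand2: N itself on t bytes.
        size, n = operands
        return frame + size.to_bytes(2, 'big') + n.to_bytes(t, 'big')
    for operand in operands:
        frame += operand.to_bytes(t, 'big')  # assumed big-endian
    return frame


def response_frame(sender, receiver, value, t):
    """| senderId | receiverId | operand1 |"""
    return bytes([sender, receiver]) + value.to_bytes(t, 'big')


# Hypothetical ids and a toy operand size of t = 4 bytes.
BLOCK, ACCEL, t = 0x01, 0x02, 4
print(request_frame(BLOCK, ACCEL, MUL_MOD, 3, 5, t=t).hex())  # 0102550000000300000005
print(request_frame(BLOCK, ACCEL, GET_RESULT, t=t).hex())     # 010277
print(response_frame(ACCEL, BLOCK, 15, t).hex())              # 02010000000f
```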

Analysis

Even though I had already faced a sigrok capture more than two years earlier, I started this challenge exactly where I had started back then: I had no idea what to do with the files!

But this time, I was confident I could figure it out. Each capture has 8 channels, which means that the throughput is equal to 1 byte/tick, where tick represents the time it takes each channel to transmit one bit.

Now, and this is a bit tricky, we have to figure out how to recompose the output of the channels. There are 8! = 40320 possible orderings of the channels. In some cases we might need to check all of them. Here, though, I felt that the channels were likely to be ordered either 0-7 or 7-0. There are several reasons for that (ranked from least to most convincing):

  • my gut feeling says that it should be the way a protocol is implemented most of the time, unless the developers explicitly try to obfuscate it;
  • the protocol specification is fairly straightforward (it is not supposed to be public, so I think the developers thought security by obscurity was good enough);
  • I couldn't think of a way to decide which ordering is most likely correct that would be fast enough to handle that many checks, which is quite a limiting factor.

But how can we check if we got the right stream? Well, first, we have to recompose the stream. I wrote the following function to do so:

def transposeChannels(lines, rev=False):
    res = []
    # Lines come in groups of 8, one line per channel.
    for channels in [lines[i:i+8] for i in range(0, len(lines), 8)]:
        for i in range(len(channels[0])):
            # Skip characters that are not bit samples (e.g. separators).
            if channels[0][i] != '1' and channels[0][i] != '0':
                continue
            byte = ''
            for l in range(len(channels)):
                if i >= len(channels[l]):
                    break
                byte += channels[l][i]
                if len(byte) == 8:
                    if rev:
                        res.append(byte[::-1])
                    else:
                        res.append(byte)
                    byte = ''
    return res

in which rev indicates whether to reverse the order of the channels to recompose each byte.

From this, we should be able to extract the likely ids used in the communication by the block and the modular accelerator. Doing so is fairly easy. The protocol indicates that frames sent from the RSA block to the modular accelerator use the format | senderId | receiverId | opcode | operand1(opt) | operand2(opt) |. So, for each of the two sniff_[01] files, we compute which pairs of bytes often occur right before the byte representing the modular multiplication opcode 0x55. We know from the challenge description that this opcode occurs many times. Then, we take the intersection of the two resulting sets, and hopefully we will get a single likely pair. If not, it probably shouldn't be that hard to disambiguate.

It is very important to use the modular multiplication opcode 0x55 and not the modular squaring opcode 0x66. That way, "being able to find a likely pair of ids" can serve as a metric to decide whether we're reading the channels in the right order. The reason is that 0x66 is a palindrome in its 8-bit representation (01100110), so it would show up with the same pattern regardless of the bit order we choose. We may still use it to disambiguate the results if we get several possible pairs.
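The palindrome argument is easy to check mechanically. This short sketch reverses the bit order of every opcode in the specification and keeps those that are unchanged. The helper `reverse_bits` and the table of names are my own.

```python
OPCODES = {0x11: 'LOAD_N', 0x22: 'ADD_MOD', 0x33: 'SUB_MOD', 0x44: 'INV_MOD',
           0x55: 'MUL_MOD', 0x66: 'SQU_MOD', 0x77: 'GET_RESULT'}


def reverse_bits(byte):
    """Reverse the bit order of an 8-bit value (channel order 0-7 vs 7-0)."""
    return int(format(byte, '08b')[::-1], 2)


for opcode, name in OPCODES.items():
    print(f'{name:10} {opcode:#04x} -> {reverse_bits(opcode):#04x}')

palindromes = [name for opcode, name in OPCODES.items()
               if reverse_bits(opcode) == opcode]
print(palindromes)  # ['SQU_MOD']
```

The printout also shows that 0x22 and 0x44 swap places under reversal. With the wrong channel order, an addition would therefore look like an inversion.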

In any case, we can now extract the ids (or the ids written backwards) using the following code:

# Can take a while on large captures
def find_likely_ids(stream):
    # Binary representation of the modular multiplication opcode, which is sure to occur.
    opcode = bin(0x55)[2:].zfill(8)
    id_pairs = {}
    pos = 0

    while True:
        try:
            # Absolute index of the next occurrence of the opcode.
            pos = stream.index(opcode, pos)
        except ValueError:
            # Our work here is done!
            break
        if pos >= 2:
            # The two bytes preceding the opcode are (senderId, receiverId).
            key = (stream[pos - 2], stream[pos - 1])
            id_pairs[key] = id_pairs.get(key, 0) + 1
        pos += 1

    # 100 is an arbitrary value, but it's probably not too much of a stretch
    # to assume it occurs at least 101 times (i.e. that e' >= 2**100).
    return {k: v for k, v in id_pairs.items() if v > 100}

which we call like this:

def main():
    channels_0 = []
    channels_1 = []

    with open('sniff_0') as sniff:
        channels_0 = [line[3:] for line in sniff.read().split('\n')[2:]]

    with open('sniff_1') as sniff:
        channels_1 = [line[3:] for line in sniff.read().split('\n')[2:]]

    print('[+] Transposing channels (sniff_0)')
    stream_0 = transposeChannels(channels_0, rev=True)

    print('[+] Transposing channels (sniff_1)')
    stream_1 = transposeChannels(channels_1, rev=True)

    print('[+] Finding possible ids (1/2)')
    print('--- This may take a while...')
    likely_ids_0 = find_likely_ids(stream_0)

    print('[+] Finding possible ids (2/2)')
    print('--- This may take a while...')
    likely_ids_1 = find_likely_ids(stream_1)

    likely_ids = list(set(likely_ids_0.keys()) & set(likely_ids_1.keys()))

    assert len(likely_ids) == 1

    id_block, id_accel = likely_ids[0]

    print('*** Found ids: {} {}'.format(id_block, id_accel))

    exit(0)

if __name__ == '__main__':
    main()

We run the code, and here it is:

~$ ./solve.py 
[+] Transposing channels (sniff_0)
[+] Transposing channels (sniff_1)
[+] Finding possible ids (1/2)
--- This may take a while...
[+] Finding possible ids (2/2)
--- This may take a while...
*** Found ids: 10001000 10101010

I cheated a bit here: I removed the code that also looked for possible ids with the "normal" (0-7) ordering, because no likely pair of ids can be found in that case. Checking this only requires looking at the output of the function a little bit, which I think is quite trivial to do. Since this writeup is a bit time-constrained, I will not explain it, but feel free to reach out if you can't figure it out yourself. This means that we now know the right order (7-0) and the ids!

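To proceed further, we need to compute \(t\), the size of \(N\) in bytes. Along with the ids, we recovered the number of times they occur in front of the modular multiplication opcode, and this lets us estimate it. Each iteration of the loop performs exactly one squaring and one multiplication, so this count is also the number of modular squaring operations.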

For each signature verification, there should be exactly as many modular squaring operations as there are bits in the corresponding \(e'\). Now, we know that

$$ e' = e + r \times \phi(N) $$

where \(\phi(N) = (p - 1) \times (q - 1)\).

From that second formula, it is easy to see that \(\phi(N)\) has the same number of bits as \(N\) [2]. Since \(r\) is generated randomly and \(r \in \{0, 1, \ldots, 2^{32} - 1\}\), it will be 32 bits long 50% of the time.

The exact size of \(e\) hardly matters here; since we have \(e < N\), it is a negligible term in the equation of \(e'\).

With this, we can deduce the following equation:

$$\begin{align} \log_2(e') &\approx \log_2(r) + \log_2(\phi(N)) \\ &\approx \log_2(r) + \log_2(N)\\ &\approx \log_2(r) + 8t\\ \iff t &\approx \frac{\log_2(e') - \log_2(r)}{8}\\ \iff t &\approx \frac{n_{sq} - \log_2(r)}{8} \end{align}$$

with \(n_{sq}\) the number of squaring operations during one of the signature verifications. \(t\) will probably not be an integer at that point, so we will have to round it up. We thus add the following lines to our main function:

    # Max number of occurrences across files to get better odds of log_2(r) = 32.
    max_rounds = max(likely_ids_0[likely_ids[0]], likely_ids_1[likely_ids[0]])
    likely_t = int(ceil((max_rounds - 32) / 8))
    print('*** Likely value for t: {}'.format(likely_t))

We get:

*** Likely value for t: 128

which looks nice. That would make \(N\) roughly 1024 bits long, which is a common RSA modulus length.
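The estimate can be sanity-checked with a quick simulation. In this sketch, a random integer of the right size stands in for \(\phi(N)\), since only bit lengths matter. The name `estimate_t` is mine. The simulation assumes a full 32-bit \(r\) and a \(\phi(N)\) of \(8t - 6\) to \(8t\) bits. Under those assumptions, rounding up always lands on \(t\). A shorter \(r\) would push the estimate down, which is why taking the larger count of the two captures helps.

```python
import random
from math import ceil


def estimate_t(n_sq, r_bits=32):
    """Estimate the byte size t of N from n_sq, the bit length of e'."""
    return ceil((n_sq - r_bits) / 8)


random.seed(0)
t = 128
for _ in range(1000):
    # Stand-in for phi(N): a random integer with 8t-6 to 8t bits.
    bits = random.randint(8 * t - 6, 8 * t)
    phi = random.getrandbits(bits) | (1 << (bits - 1))
    r = random.getrandbits(32) | (1 << 31)  # a full 32-bit r
    e = random.randrange(1, phi)
    e_prime = e + r * phi
    assert estimate_t(e_prime.bit_length()) == t
```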

Using this, we should now be able to decode the operations embedded in the captures. We write the following fairly straightforward function based on the specification that was provided:

def generate_opcodes(code, t, id_block, id_accel):
    # Opcode -> (mnemonic, number of operand bytes following the header).
    # 0x77 (response request) has no operand and is not logged.
    requests = {
        0x11: ('LOAD_N', t + 2),
        0x22: ('ADD_MOD', 2 * t),
        0x33: ('SUB_MOD', 2 * t),
        0x44: ('INV_MOD', t),
        0x55: ('MUL_MOD', 2 * t),
        0x66: ('SQU_MOD', t),
        0x77: (None, 0),
    }
    headers = {(id_block, id_accel), (id_accel, id_block)}

    def to_dec(chunk):
        # Big-endian bytes (as bit strings) -> decimal string
        return str(int(''.join(chunk), 2)) if chunk else ''

    offset = 0
    opcodes = []
    while offset < len(code) - 2:
        if (code[offset], code[offset + 1]) == (id_block, id_accel):
            opcode = int(code[offset + 2], 2)
            offset += 3
            if opcode not in requests:
                opcodes.append(('ERROR: READING', '', ''))
                continue
            op, size = requests[opcode]
        elif (code[offset], code[offset + 1]) == (id_accel, id_block):
            op, size = 'RESULT', t
            offset += 2
        else:
            # Not a frame header: skip this byte.
            offset += 1
            continue

        # Operand bytes.
        result = code[offset:offset + size]
        offset += size

        # Only idle (zero) bytes are expected until the next frame header.
        while offset < len(code) - 1 and (code[offset], code[offset + 1]) not in headers:
            if code[offset] != '00000000':
                opcodes.append(('ERROR: FAILURE TO COMPLETE ARGUMENT PARSING', '', ''))
            offset += 1

        if op is None:
            continue
        if op == 'LOAD_N':
            opcodes.append((op, to_dec(result[:2]), to_dec(result[2:])))
        elif size == 2 * t:
            opcodes.append((op, to_dec(result[:t]), to_dec(result[t:])))
        else:
            opcodes.append((op, to_dec(result), ''))

    return opcodes

Then, we add the following lines to the main function:

    gen_code_0 = generate_opcodes(stream_0, likely_t, id_block, id_accel)
    # Make sure no frame failed to parse.
    assert all('ERROR' not in op[0] for op in gen_code_0)
    print(gen_code_0)

    gen_code_1 = generate_opcodes(stream_1, likely_t, id_block, id_accel)
    assert all('ERROR' not in op[0] for op in gen_code_1)
    print(gen_code_1)

Since the assertions pass, it looks like the streams were parsed correctly. We can double-check by looking at the printed code. Essentially, the generated code is a sequence of iterations of the following block:

SQU_MOD operand
RESULT  result
MUL_MOD operand1 operand2
RESULT  result

So now, for each block, we have the operand of a modular squaring, the operands of a modular multiplication, and the result of both these operations.

Saying that \(x \equiv y\;[N]\) is the same as saying \(x - y \equiv 0\;[N]\). This means that, for each operation, \(N\) divides the exact (non-modular) result of the operation minus the result returned by the modular accelerator.

Thus, we can just compute these values for all the operations we recovered. Since \(N\) divides all of them, their \(gcd\) is very likely to be \(N\) itself. We write the following:

def retrieve_N(gen_code):
    values = []
    # Each iteration is SQU_MOD, RESULT, MUL_MOD, RESULT.
    for i in range(0, len(gen_code), 4):
        # x^2 - (x^2 mod N) is a multiple of N.
        values.append(int(gen_code[i][1]) ** 2 - int(gen_code[i + 1][1]))
        # x*y - (x*y mod N) is a multiple of N.
        values.append(int(gen_code[i + 2][1]) * int(gen_code[i + 2][2]) - int(gen_code[i + 3][1]))
    return reduce(gcd, values)

and add these lines to the main function:

    N = retrieve_N(gen_code_0)
    print('*** Found N: {}'.format(N))

We get:

*** Found N: 39532903641186369417531241952778798573569820048621097137963974782357845447014672434516903618675081216538742438887388080169354529805688887489468107774520659661114620111475883029401531763167346057985041757563018480022267471866276386314296399361514313879994207760483372229182433725669005488177446439016020179361

We can quickly check that it does not have any small divisor, which convinces us that it is indeed the right modulus.
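To see why a handful of operations is enough, here is a toy sketch of the same gcd trick against a simulated accelerator. Each difference is \(N\) times some essentially unrelated cofactor, and the gcd of a few such cofactors is almost always 1. The modulus, the operands and the name `recover_modulus` are made up for illustration. I use the standard library's `math.gcd` instead of gmpy2's.

```python
import random
from functools import reduce
from math import gcd


def recover_modulus(observations):
    """observations: (exact result, reduced result returned by the accelerator)."""
    return reduce(gcd, (exact - reduced for exact, reduced in observations))


random.seed(1)
N = (2**89 - 1) * (2**107 - 1)  # the "secret" modulus (toy value)
observations = []
for _ in range(20):
    x, y = random.randrange(N), random.randrange(N)
    observations.append((x * y, x * y % N))  # a MUL_MOD exchange
    observations.append((x * x, x * x % N))  # a SQU_MOD exchange

print(recover_modulus(observations) == N)  # True with overwhelming probability
```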

Now, using these blocks, we can also recover both \(e'\) values (called \(e'_0\) and \(e'_1\) in the following). Recall from the pseudocode of the signature verification function that the operand squared in round \(i + 1\) depends on the bit of \(e'\) processed in round \(i\). If that bit is 1, the operand is the previous product tmp; otherwise, it is the previous accumulator. With that in mind, we can write the following function:

def retrieve_eprime(gen_code):
    bits = ''
    # Product computed in the first iteration.
    tmp = gen_code[3][1]
    for i in range(4, len(gen_code), 4):
        # If the next squaring operates on the previous product, the bit was 1.
        bits += '1' if gen_code[i][1] == tmp else '0'
        tmp = gen_code[i + 3][1]
    # The bit handled in the last iteration is never observed through a later
    # squaring; it is 1 because e' is odd (phi(N) is even and e is odd).
    return int(bits + '1', 2)

and we add the following lines to the main function:

    eprime_0 = retrieve_eprime(gen_code_0)
    print("*** Found e'0: {}".format(eprime_0))

    eprime_1 = retrieve_eprime(gen_code_1)
    print("*** Found e'1: {}".format(eprime_1))

and there we go:

*** Found e'0: 129965543810659734688770529521858137674334091569073825487229970942558060325544844978933246149171394151779717554485110194419323212394396758046324617920283553807698690726394496071281895993741380960799754332246777239716916776520701682098652181969923994409035291727064305676927525922613768036944604558966724699647777366459
*** Found e'1: 151254433565461084178504508273750043658742643904549264256571382093081116556889793132226043656402175132526030145049170142935738670828566243963121619260615043832829272410592930734054500208769567196020510647952649238369072759736515848463672631369633294812752571110203947161503366628386167840126244206390801734284139629359
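The bit-recovery rule can be checked end to end on a toy key. The sketch simulates the bus traffic of the pseudocode, then applies the same comparison as `retrieve_eprime`. The trace layout mimics the four-entry blocks of `gen_code`, with integers instead of decimal strings. The helper names and toy parameters are made up.

```python
def trace_ladder(m, e_prime, N):
    """Simulate the bus: per iteration, the squaring operand, its result,
    the multiplication operands and its result (same layout as gen_code)."""
    trace = []
    accumulator = 1
    for bit in bin(e_prime)[2:]:  # MSB first
        trace.append(('SQU_MOD', accumulator, None))
        accumulator = accumulator * accumulator % N
        trace.append(('RESULT', accumulator, None))
        tmp = accumulator * m % N
        trace.append(('MUL_MOD', accumulator, m))
        trace.append(('RESULT', tmp, None))
        if bit == '1':
            accumulator = tmp
    return trace


def recover_eprime(trace):
    """Same rule as retrieve_eprime: a bit is 1 iff the next squaring
    operates on the previous product; the last bit is taken to be 1."""
    bits = ''
    tmp = trace[3][1]
    for i in range(4, len(trace), 4):
        bits += '1' if trace[i][1] == tmp else '0'
        tmp = trace[i + 3][1]
    return int(bits + '1', 2)


# Toy parameters (illustration only).
p, q = 2**31 - 1, 2**61 - 1
N, phi = p * q, (p - 1) * (q - 1)
e, r = 65537, 0xDEADBEEF
e_prime = e + r * phi  # odd, since e is odd and phi is even

assert recover_eprime(trace_ladder(42, e_prime, N)) == e_prime
```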

Now, to make sure we didn't make any mistake, we can reimplement the verification function whose pseudocode was in the challenge description, along with a function that retrieves the results at each step to check that our values are correct every step of the way:

def get_results(gen_code):
    results = []
    for i in range(0, len(gen_code), 4):
        results.append((int(gen_code[i + 1][1]), int(gen_code[i + 3][1])))
    return results

def rsa_test(m, eprime, N, results=None):
    accumulator = 1
    dummy = 1
    count = 0
    for i in bin(eprime)[2:]:
        accumulator = (accumulator * accumulator) % N
        tmp = (accumulator * m) % N
        if results:
            assert( accumulator == results[count][0] )
            assert( tmp == results[count][1] )
            count += 1
        if int(i) == 1:
            accumulator = tmp
        else:
            dummy = tmp
    return accumulator

We can simply read the input value m (the second operand of the first modular multiplication) and the final output (the last result) from the code generated for each capture.

Then, we can add the following to the main function:

    gen_res_0 = get_results(gen_code_0)
    gen_res_1 = get_results(gen_code_1)

    # Test 1
    m = 10259772985491196142881149901633027701028043458646400390933712730333726596752159375714275345825043444507057516263060761558024690460540549870135724519397576851019728499317267670692619749356314832956631301640316367939452416680868911516271871510429085408074911806254
    mres = 939483704564581016655829553143855821373675412534738730204007962895358973752804393694756669586085591866301743965981206731369100994188871497614592773654444492547012361456203634454894351566013093317556843788922988409877258582088643265972618746918693831785987776780300079330647446534735056833076001311125397486
    assert rsa_test(m, eprime_0, N, gen_res_0) == mres

    # Test 2
    m = 11062666244680075991404761081932977311286296883672449327998617931649913662482584231900090154057791291194378340219352352813919133510297463540896941749476614681660132543468215880995792601154170272426906500742928980877609104950245856832311277984553409975953644307554816225840977501894895918
    mres = 37335712485806667929706116072030001023276084273990124699977034802100192388540111894501624963046848064082160361240058338665910425565647764707245623812960711880453651599979708868023662982160030771388212741074706127917689446300979533957648718744757074601933242896702225112607173976663058381302681134259813261582
    assert rsa_test(m, eprime_1, N, gen_res_1) == mres

Note: I started out with the assumption that the verification algorithm read \(e'\) starting from its LSB. Unfortunately, that was not the case, and this caused many problems down the line. I had a hard time pinpointing them, especially because my verification function still worked on the given examples: I had implemented it with the same mistake. The mistake should however be noticeable, mainly because the result of the verification then depends on \(r\), which is a behaviour unlikely to be desired. With \(e'\) in the right order, however, we are basically done with this challenge.
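In hindsight, there is a cheap way to catch this kind of mistake without knowing \(\phi(N)\). Both exponents are congruent to \(e\) modulo \(\phi(N)\), so \(m^{e'_0} \equiv m^{e'_1} \pmod N\) must hold for any \(m\) coprime to \(N\). With the bits read in the wrong order, the two sides will generally differ. A toy sketch follows; the key, blinding values and helper names are made up.

```python
def bit_reverse(x):
    """Reverse the binary representation of x (i.e. read it from the other end)."""
    return int(bin(x)[2:][::-1], 2)


def exponents_consistent(eprime_0, eprime_1, N, m=3):
    """Both exponents are congruent to e mod phi(N), so m^e'_0 = m^e'_1 (mod N)."""
    return pow(m, eprime_0, N) == pow(m, eprime_1, N)


# Toy key and blinding values (illustration only).
p, q = 2**31 - 1, 2**61 - 1
N, phi = p * q, (p - 1) * (q - 1)
e = 65537
eprime_0, eprime_1 = e + 12345 * phi, e + 67890 * phi

print(exponents_consistent(eprime_0, eprime_1, N))  # True
print(exponents_consistent(bit_reverse(eprime_0), bit_reverse(eprime_1), N))
```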

Retrieving the private key

Now, we have recovered all the data we can from the files. How can we get the other values we need? Actually, there is only one more step.

The goal here is to find \(\phi(N)\), which is the last element we're missing to retrieve the secret key. It is fairly trivial to show why that is:

$$\tag{1} N = p \times q $$
$$\begin{align}\tag{2} \phi(N) &= (p - 1) \times (q - 1)\\ &= p \times q - q - p + 1\\ &= N - (p + q) + 1\\ \iff p + q &= N - \phi(N) + 1 \end{align}$$
$$\begin{align}\tag{3} (p + q)^2 &= p^2 + 2 \times p \times q + q^2\\ &= p^2 + 2 \times N + q^2 \end{align}$$
$$\begin{align}\tag{4} (p - q)^2 &= p^2 - 2 \times p \times q + q^2\\ &= p^2 - 2 \times N + q^2\\ &= (p + q)^2 - 4 \times N \end{align}$$
$$\tag{5} (p - q) = \sqrt{(p - q)^2} $$
$$\begin{align}\tag{6} \frac{(p + q) + (p - q)}{2} &= \frac{2 \times p}{2}\\ &= p \end{align}$$

I will not detail this any further as I assume the reader is familiar enough with RSA to understand that we are pretty much done here.
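For completeness, equations (1) to (6) translate into a few lines of Python. This sketch uses the standard library's `math.isqrt` (Python 3.8+) rather than gmpy2. The toy key uses made-up Mersenne primes, and the name `factor_from_phi` is mine.

```python
from math import isqrt


def factor_from_phi(N, phi):
    """Recover (p, q) with p <= q from N = p*q and phi = (p-1)*(q-1)."""
    s = N - phi + 1       # p + q, equation (2)
    disc = s * s - 4 * N  # (p - q)^2, equation (4)
    diff = isqrt(disc)    # |p - q|, equation (5)
    if diff * diff != disc:
        raise ValueError('phi is not (p-1)(q-1) for this N')
    return (s - diff) // 2, (s + diff) // 2  # equation (6)


p, q = 2**89 - 1, 2**107 - 1
print(factor_from_phi(p * q, (p - 1) * (q - 1)) == (p, q))  # True
```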

In the following, we will assume \(e'_0 \geq e'_1\). We can do so without loss of generality, and it avoids having to account for negative numbers.

We have the following two equations:

$$\begin{align} e'_0 &= e + r_0 \times \phi(N)\\ e'_1 &= e + r_1 \times \phi(N) \end{align}$$

We can therefore write:

$$\begin{align} e'_0 - e'_1 &= e + r_0 \times \phi(N) - (e + r_1 \times \phi(N))\\ &= (r_0 - r_1) \times \phi(N) \end{align}$$

This means that \(e'_0 - e'_1\) is a multiple of \(\phi(N)\).

There are several ways to get the flag from here. From discussing with the challenge author, I gathered that the one presented here was not the intended solution.

The intended solution was to run an algorithm not unlike the Miller-Rabin test [1]. In this particular case, I would argue that the technique outlined here is better (though less generic). The other algorithm is randomized, and there is no deterministic bound on how many attempts it needs:

Since we know that \(0 \leq r_0, r_1 < 2^{32}\), we can deduce that \(0 \leq r_0 - r_1 < 2^{32}\). We could probably bruteforce this whole range of potential divisors, to be perfectly honest, but there is a very simple way to drastically reduce the range of integers to check.

First, we have to realize that \(N\) and \(\phi(N)\) have the same number of bits in their binary representation [2].

Thus, we can deduce that \(2^{\lceil \log_2(e'_0 - e'_1) - \log_2(N) \rceil + 1}\) is an upper bound of \(r_0 - r_1\).

Second, since \(\phi(N) < N\), for any positive integer \(x\) we have \(\frac{x}{N} < \frac{x}{\phi(N)}\). Thus \(\frac{e'_0 - e'_1}{N}\) is a lower bound for \(\frac{e'_0 - e'_1}{\phi(N)} = r_0 - r_1\).

One could argue that we could just use \(2^{32}\) as the upper bound. It actually doesn't matter much, because, if \(p\) and \(q\) are roughly the same size, \(N\) and \(\phi(N)\) share roughly the first half of their bits [2].

This is fairly easy to see from how \(\phi(N)\) is calculated:

\(\phi(N) = N - (p + q) + 1\)

With \(p\) and \(q\) roughly of length \(\frac{log_2(N)}{2}\) bits, \(p + q\) is roughly of length \(\frac{log_2(N)}{2}\) as well. Thus, roughly the first half of the bits of \(N\) are untouched in the computation of \(\phi(N)\).

It's unlikely that \(p\) and \(q\) differ by more than a few bits in length, mostly because this would make one of the factors much smaller and thus compromise the security of the key. Even if that were the case, though, \(N\) would still share some fraction of its most significant bits with \(\phi(N)\). This means that \(\frac{e'_0 - e'_1}{N}\) is not only a lower bound for \(\frac{e'_0 - e'_1}{\phi(N)} = r_0 - r_1\), but also very close to \(r_0 - r_1\).

Spoiler alert: in the end, we actually find out that \(r_0 - r_1 = \left\lfloor \frac{e'_0 - e'_1}{N} \right\rfloor + 1\).
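This is not a coincidence. Write \(k = r_0 - r_1\); then \(\frac{k\,\phi(N)}{N} = k - \frac{k(p + q - 1)}{N}\). For balanced primes, the subtracted term is far below 1 when \(k < 2^{32}\), so the floor lands exactly on \(k - 1\). Here is a toy check; the Mersenne primes and blinding values are made up.

```python
# Toy key (made-up Mersenne primes) and hypothetical blinding values.
p, q = 2**89 - 1, 2**107 - 1
N, phi = p * q, (p - 1) * (q - 1)
e = 65537
r_0, r_1 = 3_000_000_000, 123_456_789

# e'_0 - e'_1 = (r_0 - r_1) * phi(N)
mult_phi = (e + r_0 * phi) - (e + r_1 * phi)

lower_bound = mult_phi // N
k = r_0 - r_1
print(k - lower_bound)                       # 1
print(mult_phi // (lower_bound + 1) == phi)  # True
```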

Let's just write the code:

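def solve(N, eprime_0, eprime_1):
    # |e'_0 - e'_1| = (r_0 - r_1) * phi(N)
    mult_phi = abs(eprime_0 - eprime_1)

    # Bounds on r_0 - r_1 derived above.
    lower_bound = mult_phi // N
    upper_bound = 2**int(ceil(log(mult_phi, 2) - log(N, 2)) + 1)

    # A divisor of 0 makes no sense, hence the max().
    for divisor in range(max(lower_bound, 1), upper_bound):
        if mult_phi % divisor != 0:
            continue

        phi = mult_phi // divisor

        # phi(N) < N, so larger candidates are wrong.
        if phi > N:
            continue

        # Equations (2) to (6).
        p_plus_q = N - phi + 1
        discriminant = p_plus_q**2 - 4 * N  # (p - q)^2 if phi is right
        if discriminant < 0:
            continue
        p_minus_q = isqrt(discriminant)

        # This is the larger factor; the flag only uses p + q, so the
        # order of p and q does not matter here.
        p = (p_plus_q + p_minus_q) // 2

        if N % p == 0:
            q = N // p

            assert p * q == N
            assert (p - 1) * (q - 1) == phi

            # e' = e + r * phi(N) with 0 < e < phi(N)
            e = eprime_0 % phi
            d = invert(e, phi)

            assert e * d % phi == 1
            # The printed value is (r_0 - r_1) - lower_bound.
            print('[+] phi - lower_bound: ' + str(divisor - lower_bound))
            flag = hex(N + p + q + d + e)[2:]
            print('[+] phi: ' + str(phi))
            print('[+] p: ' + str(p))
            print('[+] q: ' + str(q))
            print('[+] e: ' + str(e))
            print('[+] d: ' + str(d))
            print('[+] Flag: ECSC{' + flag + '}')

            return True

    return False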

and we add the following lines to the main function:

    print("[+] Attempting to solve for N, e'0, e'1")

    if not solve(N, eprime_0, eprime_1):
        print("[-] Could not find a solution for the extracted values.")
        exit(-1337)

    exit(0)

Final output:

[+] Transposing channels (sniff_0)
[+] Transposing channels (sniff_1)
[+] Finding possible ids (1/2)
--- This may take a while...
[+] Finding possible ids (2/2)
--- This may take a while...
*** Found ids: 10001000 10101010
*** Likely value for t: 128
*** Found N: 39532903641186369417531241952778798573569820048621097137963974782357845447014672434516903618675081216538742438887388080169354529805688887489468107774520659661114620111475883029401531763167346057985041757563018480022267471866276386314296399361514313879994207760483372229182433725669005488177446439016020179361
*** Found e'0: 129965543810659734688770529521858137674334091569073825487229970942558060325544844978933246149171394151779717554485110194419323212394396758046324617920283553807698690726394496071281895993741380960799754332246777239716916776520701682098652181969923994409035291727064305676927525922613768036944604558966724699647777366459
*** Found e'1: 151254433565461084178504508273750043658742643904549264256571382093081116556889793132226043656402175132526030145049170142935738670828566243963121619260615043832829272410592930734054500208769567196020510647952649238369072759736515848463672631369633294812752571110203947161503366628386167840126244206390801734284139629359
[+] Attempting to solve for N, e'0, e'1
[+] phi - lower_bound: 1
[+] phi: 39532903641186369417531241952778798573569820048621097137963974782357845447014672434516903618675081216538742438887388080169354529805688887489468107774520646344819072101436096683444077038580304993930651759466440859160043363554061699063584633301041308097794082790669772478160874388073432974744957640464871239300
[+] p: 8848587944222328531083032308855816294720468610664598838702705332974037285300246480672785711167321565791768077382735171706419739033232740852638634921307991
[+] q: 4467707603787711255262925145868770746343585779333497738918156891134274929387004231093274761838460634333201736217015849852917856539280691636159916227632071
[+] e: 2765067788957045200870834631614214207369424811426997051414452805110629733312062161712026321137508875529054600664465357290609093525312073162592063981227903998526388812766624039265368164097115018233109906341164808639088309002616903244755015494320639428517407319299541866354313176786436852533921990092759
[+] d: 28019824448617635433530688844277876772120444498558603440664224054634795077183625220074993682328982647428864762191714345782213841700748456866866216067340906845312633983358504610398413238442629960048583085368740277822457766985619853974201362422613274345625884164713394784569986087049660125844460017564787209439
[+] Flag: ECSC{6032c29a8ea673552db65e8e680daa4c35350310039b735dd4aae92029c169f3119521617bdcf5726ee1aa61b491e00c0d2b3db663bd8c356c644b31a53db7a3a58011a8f7fda931dd074c915b25808506db444a8d650a9ef02ded4469f7ae83e853932184ba687cab0e78d70304daa54bbcb06f802d0f12c8e648221360e175}

And flagged!
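For reference, the Miller-Rabin-like route mentioned earlier presumably corresponds to the textbook randomized method for factoring \(N\) from any multiple of \(\phi(N)\). Here is a sketch of that textbook method. It is my own reconstruction, not the challenge author's code, and the function name is made up. Write the multiple as \(2^s d\) with \(d\) odd and pick a random base \(a\). Then look for a nontrivial square root of 1 modulo \(N\) in the sequence \(a^d, a^{2d}, \ldots, a^{2^s d}\).

```python
import random
from math import gcd


def factor_with_phi_multiple(N, k_phi, tries=100):
    """Factor N = p*q given any nonzero multiple k_phi of phi(N).

    Each random base succeeds with probability at least 1/2, so a few tries
    normally suffice, but there is no deterministic bound on the number."""
    # Write k_phi = 2^s * d with d odd.
    s, d = 0, k_phi
    while d % 2 == 0:
        s, d = s + 1, d // 2
    for _ in range(tries):
        a = random.randrange(2, N - 1)
        g = gcd(a, N)
        if g > 1:  # a already shares a factor with N
            return g, N // g
        x = pow(a, d, N)
        for _ in range(s):
            y = x * x % N
            if y == 1 and x not in (1, N - 1):
                # x is a nontrivial square root of 1 modulo N.
                g = gcd(x - 1, N)
                return g, N // g
            x = y
    return None


# Toy check with made-up Mersenne primes and a made-up multiplier.
p, q = 2**89 - 1, 2**107 - 1
result = factor_with_phi_multiple(p * q, 2876543211 * (p - 1) * (q - 1))
print(sorted(result) == [p, q])
```

With the values recovered above, the call would hypothetically be `factor_with_phi_multiple(N, abs(eprime_0 - eprime_1))`.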

Conclusion

This challenge was a lot of fun. I figured out the whole solution quite fast, but I spent several days confused after that because I read \(e'\) in reverse order.

I am a bit frustrated that I didn't get to spend that much time on the pre-qualifier (what with work and all). The challenges were really great, and there are still some I haven't solved and would really like to.

Thanks to everyone involved for organizing such a nice competition!

Writeup by SIben (Casimir).

Appendix: my complete solution code

#!/usr/bin/python3
from functools import reduce
from gmpy2 import gcd, invert, isqrt
from math import ceil, log

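def transposeChannels(lines, rev=False):
    # Lines come in groups of 8, one per channel; sample i of each channel in
    # a group contributes one bit to the same byte.
    res = []
    for channels in [lines[i:i+8] for i in range(0, len(lines), 8)]:
        for i in range(len(channels[0])):
            # Skip characters that are not samples (such as spacing).
            if channels[0][i] != '1' and channels[0][i] != '0':
                continue
            byte = ''
            for l in range(len(channels)):
                if i >= len(channels[l]):
                    break
                byte += channels[l][i]
                if len(byte) == 8:
                    # rev=True: the first channel is the least significant bit.
                    if rev:
                        res.append(byte[::-1])
                    else:
                        res.append(byte)
                    byte = ''
    return res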

# Takes some time
def find_likely_ids(stream):
    # Binary representation of the modular multiplication opcode (0x55, MUL_MOD
    # in generate_opcodes), which is issued once per bit of e', so it occurs
    # for sure.
    opcode = bin(0x55)[2:].zfill(8)
    id_pairs = {}
    slice_start = 0

    while slice_start < len(stream):
        try:
            next_cmd = stream.index(opcode, slice_start)
        except ValueError:
            # Our work here is done!
            break
        # The two bytes right before the opcode are a candidate
        # (id_block, id_accel) pair.
        key = (stream[next_cmd - 2], stream[next_cmd - 1])
        id_pairs[key] = id_pairs.get(key, 0) + 1
        slice_start = next_cmd + 1

    # 100 is an arbitrary value, but it's probably not too much of a stretch
    # to assume it occurs at least 101 times (i.e. that e' >= 2**100).
    return {k: v for k, v in id_pairs.items() if v > 100}


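def rsa_test(m, eprime, N, results=None):
    # Replays the device's left-to-right square-and-multiply-always
    # computation of m^eprime mod N. If results is given, every intermediate
    # (square, multiply) pair is checked against the captured values.
    accumulator = 1
    dummy = 1
    count = 0
    for i in bin(eprime)[2:]:
        accumulator = (accumulator * accumulator) % N
        tmp = (accumulator * m) % N
        if results:
            assert accumulator == results[count][0]
            assert tmp == results[count][1]
            count += 1
        # The multiplication always happens; its result is only kept when
        # the current bit is 1.
        if i == '1':
            accumulator = tmp
        else:
            dummy = tmp
    return accumulator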

def get_results(gen_code):
    results = []
    for i in range(0, len(gen_code), 4):
        results.append((int(gen_code[i + 1][1]), int(gen_code[i + 3][1])))
    return results

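def retrieve_N(gen_code):
    # x^2 - (x^2 mod N) and a * b - (a * b mod N) are multiples of N for every
    # squaring and multiplication in the trace; in practice, their gcd is N.
    values = []
    for i in range(0, len(gen_code), 4):
        values.append(int(gen_code[i][1]) ** 2 - int(gen_code[i + 1][1]))
        values.append(int(gen_code[i + 2][1]) * int(gen_code[i + 2][2]) - int(gen_code[i + 3][1]))
    return reduce(gcd, values)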


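def retrieve_eprime(gen_code):
    # The trace has 4 entries per bit of e', most significant bit first:
    # square op, its result, multiply op, its result. If the value squared in
    # round k+1 is round k's multiplication result, the accumulator was
    # updated, i.e. bit k is 1.
    bits = ''
    tmp = gen_code[3][1]
    for i in range(4, len(gen_code), 4):
        bits += '1' if gen_code[i][1] == tmp else '0'
        tmp = gen_code[i + 3][1]
    # The last bit cannot be observed this way, but e' = e + r * phi(N) is
    # odd (e is odd, phi(N) is even), so it is 1.
    return int(bits + '1', 2)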

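def generate_opcodes(code, t, id_block, id_accel):
    # Parses the byte stream into (op, arg1, arg2) tuples. A command is
    # id_block, id_accel, opcode, then its arguments; a result is id_accel,
    # id_block, then t bytes. min_t is the expected argument length in bytes.
    min_t = 0
    offset = 0
    opcodes = []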
    while offset < len(code) - 2:
        if code[offset] == id_block and code[offset + 1] == id_accel:
            opcode = int(code[offset + 2], 2)
            if   opcode == 0x11:
                op = 'LOAD_N'
                min_t = t + 2

            elif opcode == 0x22:
                op = 'ADD_MOD'
                min_t = 2 * t

            elif opcode == 0x33:
                op = 'SUB_MOD'
                min_t = 2 * t

            elif opcode == 0x44:
                op = 'INV_MOD'
                min_t = t

            elif opcode == 0x55:
                op = 'MUL_MOD'
                min_t = 2 * t

            elif opcode == 0x66:
                op = 'SQU_MOD'
                min_t = t

            elif opcode == 0x77:
                # NO LOGGING NEEDED
                #opcodes.append('SND_WIZ')

                min_t = 0
            else:
                op = 'ERROR: READING'
                min_t = 0
            offset += 3
        elif code[offset] == id_accel and code[offset + 1] == id_block:
            op = 'RESULT'
            min_t = t
            offset += 2
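        else:
            # Not a command or result header: move on to the next byte
            # (without this, the loop would never advance).
            offset += 1
            continue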

        result = []

        min_t_iter = min_t

        while offset < len(code) - 2:
            if min_t_iter <= 0:
                if code[offset] == id_block and code[offset + 1] == id_accel or \
                   code[offset] == id_accel and code[offset + 1] == id_block:
                    break
                elif not code[offset] == '00000000':
                    # Logged as a tuple so that the 'ERROR' in op[0] checks
                    # in main() actually see it.
                    opcodes.append(('ERROR: FAILURE TO COMPLETE ARGUMENT PARSING', '', ''))
            else:
                result.append(code[offset])
            min_t_iter -= 1
            offset += 1

        if min_t == t + 2:
            opcodes.append((op, str(int(''.join(result[:2]), 2)), str(int(''.join(result[2:]), 2))))
        elif min_t == t:
            opcodes.append((op, str(int(''.join(result), 2)), ''))
        elif min_t == 2 * t:
            opcodes.append((op, str(int(''.join(result[:t]), 2)), str(int(''.join(result[t:]), 2))))
        elif min_t == 0:
            pass
        else:
            print('unknown case')

    return opcodes


def solve(N, eprime_0, eprime_1):
    """
    The goal here is to find phi(N), which is the last element we're missing
    to retrieve the secret key.

    It is fairly trivial to show why that is:

    (1) N = p * q

    (2)     phi(N) = (p - 1) * (q - 1)
                   = p * q - q - p + 1
                   = N - (p + q) + 1
        <=> p + q  = N - phi(N)  + 1

    (3) (p + q)^2 = p^2 + 2 * p * q + q^2
                  = p^2 + 2 * N + q^2

    (4) (p - q)^2 = p^2 - 2 * p * q + q^2
                  = p^2 - 2 * N + q^2
                  = (p + q)^2 - 4 * N

    (5) (p - q) = sqrt((p - q)^2)

    (6) ((p + q) + (p - q)) / 2 = 2 * p / 2
                                = p

    I will not detail this any further as I assume the reader is familiar
    enough with RSA to know that we're pretty much done here.

    In the following, we will assume e'0 >= e'1. We can do so without loss of
    generality, and it avoids having to account for negative numbers.

    We have the following two equations:
    e'0 = e + r0 * phi(N)
    e'1 = e + r1 * phi(N)

    We can therefore write:
    e'0 - e'1 = e + r0 * phi(N) - (e + r1 * phi(N))
              = (r0 - r1) * phi(N)

    This means that e'0 - e'1 is a multiple of phi(N).

    There are several ways to get the flag from here. From discussing with
    the challenge author, I gathered that the one presented here was not an
    intended solution.

    The intended solution was to run an algorithm not unlike the Miller-Rabin
    test [1]. In this particular case, I would argue that the technique
    outlined here is better (though less generic), as the other algorithm is
    randomized and offers no hard bound on how many attempts it needs:

    Since we know that 0 <= r0, r1 < 2**32, we can deduce that
    0 <= r0 - r1 < 2**32. We could probably bruteforce this whole range of
    potential divisors, to be perfectly honest, but there is a very simple way
    to drastically reduce the range of integers to check.

    First, we have to realize that N and phi(N) have the same number of bits
    in their binary representation [2]. In particular, phi(N) > N / 2, so
    (e'0 - e'1) / phi(N) < 2 * (e'0 - e'1) / N.

    Thus, we can deduce that 2^(ceil(log2(e'0 - e'1) - log2(N)) + 1) is an upper
    bound for r0 - r1.

    Second, since phi(N) < N, for any positive x, we have x/N < x/phi(N),
    and thus, (e'0 - e'1) / N is a lower bound for (e'0 - e'1) / phi(N) =
    (r0 - r1).

    One could argue that we could just use 2^32 as the upper bound. It
    actually doesn't matter much, because, if p and q are roughly the same size,
    N and phi(N) share roughly the first half of their bits [2].

    This is fairly easy to see from how phi(N) is calculated:

    phi(N) = N - (p + q) + 1

    With p and q roughly of length log2(N) / 2 bits, p + q is roughly of length
    log2(N) / 2 as well. Thus, roughly the first half of the bits of N are
    untouched in the computation of phi(N).

    It's unlikely that p and q differ by more than a few bits in length (mostly
    because this would make one of the factors much smaller and thus compromise
    the security of the key). Even if that were the case, though, N would still
    share some percentage of its most significant bits with phi(N).

    This means that (e'0 - e'1) / N is not only a lower bound for
    (e'0 - e'1) / phi(N) = (r0 - r1), but also very close to (r0 - r1).

    In fact, running this script shows that for the challenge's values,
    (r0 - r1) = floor((e'0 - e'1) / N) + 1

    [1]: https://groups.google.com/forum/#!topic/sci.crypt/wDV49EsAZQ0
    [2]: https://eprint.iacr.org/2012/666.pdf
    """
    mult_phi = abs(eprime_0 - eprime_1)

    lower_bound = mult_phi // N
    upper_bound = 2**int(ceil(log(mult_phi, 2) - log(N, 2)) + 1)

    for divisor in range(lower_bound, upper_bound):
        if mult_phi % divisor != 0:
            continue

        phi = mult_phi // divisor

        if phi > N:
            continue

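        p_plus_q  = N - phi + 1
        # (p - q)^2 = (p + q)^2 - 4 * N; a wrong candidate can make it
        # negative, which gmpy2's isqrt would reject with an error.
        discriminant = p_plus_q**2 - 4 * N
        if discriminant < 0:
            continue
        p_minus_q = isqrt(discriminant)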

        p = (p_plus_q + p_minus_q) // 2

        if N % p == 0:
            q = N // p

            assert( p * q == N )
            assert( (p - 1) * (q - 1) == phi )

            e = eprime_0 % phi
            d = invert(e, phi)

            assert( e * d % phi == 1 )
            print('[+] phi - lower_bound: ' + str(divisor - lower_bound))
            flag = hex(N + p + q + d + e)[2:].replace('L', '')
            print('[+] phi: ' + str(phi))
            print('[+] p: ' + str(p))
            print('[+] q: ' + str(q))
            print('[+] e: ' + str(e))
            print('[+] d: ' + str(d))
            print('[+] Flag: ECSC{' + flag + '}')

            return True

    return False

def main():
    channels_0 = []
    channels_1 = []

    with open('sniff_0') as sniff:
        channels_0 = [line[3:] for line in sniff.read().split('\n')[2:]]

    with open('sniff_1') as sniff:
        channels_1 = [line[3:] for line in sniff.read().split('\n')[2:]]

    print('[+] Transposing channels (sniff_0)')
    stream_0 = transposeChannels(channels_0, rev=True)

    print('[+] Transposing channels (sniff_1)')
    stream_1 = transposeChannels(channels_1, rev=True)

    print('[+] Finding possible ids (1/2)')
    print('--- This may take a while...')

    # Presetting likely_ids_0 because find_likely_ids takes some time to run.
    likely_ids_0 = {('10001000', '10101010'): 1054}
    #likely_ids_0 = find_likely_ids(stream_0)

    print('[+] Finding possible ids (2/2)')
    print('--- This may take a while...')

    # Presetting likely_ids_1 because find_likely_ids takes some time to run.
    likely_ids_1 = {('10001000', '10101010'): 1054}
    #likely_ids_1 = find_likely_ids(stream_1)

    likely_ids = list(set(likely_ids_0.keys()) & set(likely_ids_1.keys()))

    assert len(likely_ids) == 1

    id_block, id_accel = likely_ids[0]

    print('*** Found ids: {} {}'.format(id_block, id_accel))

    max_rounds = min(likely_ids_0[likely_ids[0]], likely_ids_1[likely_ids[0]])

    """
    Knowing the likely ids, we can estimate the size t of N in bytes.
    For each signature verification, there should be exactly as many modular
    squaring operations as there are bits in the corresponding e'. Now, we
    know that e' = e + r * phi(N), where phi(N) = (p - 1) * (q - 1). From that
    second formula, it is easy to see that phi(N) has the same number of bits
    as N.
    Since r is generated randomly in the range 0 .. 2**32, 50% of the time, it
    is 32 bits long.
    The exact size of e hardly matters here; since we have e < N, it is a
    negligible term in the equation of e'.
    With this, we can deduce the following approximate equation:
        log(e', 2) ~ log(r, 2) + log(phi(N), 2)
                   ~ log(r, 2) + log(N, 2)
                   ~ log(r, 2) + (8 * t)
    <=>          t ~ (log(e', 2) - log(r, 2)) / 8
    <=>          t ~ (nb_mod_sq  - log(r, 2)) / 8

    with nb_mod_sq the number of squaring operations during one of the
    signature verifications.
    """
    likely_t = int(ceil((max_rounds - 32) / 8))

    print('*** Likely value for t: {}'.format(likely_t))

    gen_code_0 = generate_opcodes(stream_0, likely_t, id_block, id_accel)
    assert all('ERROR' not in op[0] for op in gen_code_0)

    gen_code_1 = generate_opcodes(stream_1, likely_t, id_block, id_accel)
    assert all('ERROR' not in op[0] for op in gen_code_1)

    N = retrieve_N(gen_code_0)
    print('*** Found N: {}'.format(N))

    eprime_0 = retrieve_eprime(gen_code_0)
    print("*** Found e'0: {}".format(eprime_0))

    eprime_1 = retrieve_eprime(gen_code_1)
    print("*** Found e'1: {}".format(eprime_1))

    gen_res_0 = get_results(gen_code_0)
    gen_res_1 = get_results(gen_code_1)

    # Test 1
    m = 10259772985491196142881149901633027701028043458646400390933712730333726596752159375714275345825043444507057516263060761558024690460540549870135724519397576851019728499317267670692619749356314832956631301640316367939452416680868911516271871510429085408074911806254
    mres = 939483704564581016655829553143855821373675412534738730204007962895358973752804393694756669586085591866301743965981206731369100994188871497614592773654444492547012361456203634454894351566013093317556843788922988409877258582088643265972618746918693831785987776780300079330647446534735056833076001311125397486
    assert rsa_test(m, eprime_0, N, gen_res_0) == mres

    # Test 2
    m = 11062666244680075991404761081932977311286296883672449327998617931649913662482584231900090154057791291194378340219352352813919133510297463540896941749476614681660132543468215880995792601154170272426906500742928980877609104950245856832311277984553409975953644307554816225840977501894895918
    mres = 37335712485806667929706116072030001023276084273990124699977034802100192388540111894501624963046848064082160361240058338665910425565647764707245623812960711880453651599979708868023662982160030771388212741074706127917689446300979533957648718744757074601933242896702225112607173976663058381302681134259813261582
    assert rsa_test(m, eprime_1, N, gen_res_1) == mres

    print("[+] Attempting to solve for N, e'0, e'1")

    if not solve(N, eprime_0, eprime_1):
        print("[-] Could not find a solution for the extracted values.")
        exit(-1337)

    exit(0)

if __name__ == '__main__':
    main()
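The docstring of `solve()` mentions that the intended solution used an algorithm "not unlike the Miller-Rabin test" [1]. For reference, here is a minimal sketch of the classic version of that idea; it is not the challenge author's code, `factor_from_phi_multiple` and `max_tries` are hypothetical names, and bases are tried in order rather than at random so that the toy run is reproducible. Any multiple M of \(\varphi(N)\), such as \(e'_0 - e'_1\), works as input.

```python
from math import gcd


def factor_from_phi_multiple(N, M, max_tries=100):
    """Factor N = p * q given any multiple M of phi(N).

    Write M = 2**s * t with t odd. For a base a coprime to N, a**M = 1
    (mod N); squaring a**t repeatedly, a square root of 1 other than +-1
    reveals a factor through gcd(x - 1, N).
    """
    s, t = 0, M
    while t % 2 == 0:
        s, t = s + 1, t // 2
    for a in range(2, 2 + max_tries):     # deterministic bases for the demo
        g = gcd(a, N)
        if g > 1:
            return g                      # a already shares a factor with N
        x = pow(a, t, N)
        if x in (1, N - 1):
            continue                      # no non-trivial root on this chain
        for _ in range(s):
            y = x * x % N
            if y == 1:
                return gcd(x - 1, N)      # x is a non-trivial square root of 1
            if y == N - 1:
                break                     # chain hits -1: try another base
            x = y
    return None


# Toy values: N = 61 * 53 = 3233, M = 997 * phi(N) = 997 * 3120.
print(factor_from_phi_multiple(3233, 3110640))   # 61
```

Base 2 fails on these toy values, because its chain reaches \(-1\) modulo both primes at the same step, and base 3 then reveals the factor 61. With random bases, each attempt succeeds with probability at least 1/2, which is why a handful of tries is usually enough even though no single attempt is guaranteed to work.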

  1. https://groups.google.com/forum/#!topic/sci.crypt/wDV49EsAZQ0 

  2. https://eprint.iacr.org/2012/666.pdf 

