deayzl's blog

[SECCON CTF 13 Quals] F is for Flag 본문

CTF writeup/SECCON CTF

[SECCON CTF 13 Quals] F is for Flag

deayzl 2024. 11. 24. 18:49

solved

This challenge was easy, but I made some mistakes, so I couldn't solve it quickly.

I think I need to practice getting the details of my solve scripts right, and doing it fast.

Recently, I've even been thinking about grinding LeetCode because of my poor script-writing skills :(

 

This challenge is written in C++ and uses std::variant, std::_Function_base::_Base_manager, and so on.

The tricky part of this challenge is that it implements the important logic inside lambdas.

 

Some of them are easy to read, so naming them was the first step.

 

The following is part of the main function; it appears to build a structure out of raw int values using std::variant.

 

The following part constructs a structure that recursively calls the functions initialized by main, depending on the values inside the variants.

 

Initializing the function base structure with the function handlers:

 

Calling a function handler depending on the value inside the variant:

 

After spending some time debugging, I found where the input bytes are processed.

Then I wrote a script that prints out every arithmetic and bitwise operation.

 

 

 

 

As usual, a gdb script is the GOAT for this kind of thing.

import gdb
import re
import struct

# Short aliases for the two gdb Python API calls used throughout the script.
ge, gp = gdb.execute, gdb.parse_and_eval
# Load the challenge binary into the debugger.
gdb.execute('file F')

def rotl(a1, a2):
    """Rotate the 32-bit value ``a1`` left by ``a2`` bits.

    Only the low 32 bits of ``a1`` and the low 5 bits of ``a2`` are used,
    matching x86 ``rol`` on a 32-bit operand. The rotl hook passes raw 64-bit
    register values (r12, rcx). Without masking, a count above 32 raised
    ``ValueError: negative shift count``, and a count of 0 let high register
    bits leak into the result.
    """
    a1 &= 0xffffffff
    a2 &= 31
    # With a2 == 0, a1 >> 32 is 0 because a1 is already 32-bit.
    return ((a1 << a2) | (a1 >> (32 - a2))) & 0xffffffff


identifier = "solve"  # trace file written by out() is ./inst_<identifier>.txt

def get_parent_func_addr():
    """Return the second hex address in the current backtrace, minus 5.

    Presumably the second address is frame #1's return address, so
    subtracting 5 (the size of a ``call rel32``) gives the calling
    instruction. NOTE(review): this assumes frame #0 also prints an address.
    """
    backtrace = ge('bt', to_string=True)
    addresses = re.findall(r'0x[0-9a-f]+', backtrace)
    return_address = int(addresses[1], 16)
    return return_address - 5

# Truncate the trace file so every gdb run starts with an empty log.
with open(f'./inst_{identifier}.txt', 'wt') as f:
    pass
def out(string):
    """Append ``string`` and a newline to ./inst_<identifier>.txt."""
    with open(f'./inst_{identifier}.txt', 'at') as f:
        f.writelines((string, '\n'))
def u32(a1):
    """Decode exactly four bytes as a little-endian unsigned 32-bit integer."""
    (value,) = struct.unpack('<I', a1)
    return value

# Probe input: 0x40 distinct bytes starting at 'A', so every 32-bit word of
# the input is unique and can be located again in the trace.
payload = bytes(range(0x41, 0x41 + 0x40))
# Every value seen during tracing; slots 0..15 are the input words.
everyvalues = [u32(payload[off:off + 4]) for off in range(0, len(payload), 4)]
# One list per traced phase, holding the everyvalues indices it consumed.
idx_history = [[]]
result_values = []
result_idxes = []
def rfind_list(a1: list, a2: int) -> int:
    """Return the index of the last element of ``a1`` equal to ``a2``.

    Prints the searched value in hex as a trace line.

    Raises:
        ValueError: if ``a2`` does not occur in ``a1``. The previous
        ``assert False`` disappeared under ``python -O``, and the function
        then returned None silently.
    """
    print('rfind_list:', hex(a2))
    for i in reversed(range(len(a1))):
        if a1[i] == a2:
            return i
    raise ValueError(f'{a2:#x} not found in list')
# 4-bit S-box: mybp replays it by replacing every nibble n with box[n].
box = [3, 14, 1, 10, 4, 9, 5, 6, 8, 11, 15, 2, 13, 12, 0, 7]
his = {}  # not referenced elsewhere in this script
# Mutable tracer state, updated by mybp.stop() through `global`.
addr = get_constant_arg = prev_res = None           # "get constant" tracking
count = drop_length = skip_count = 0                 # per-phase counters
bStartXor = bStartNewRound = bDoneCalculating = False  # phase flags
class mybp(gdb.Breakpoint):
    """Tracing breakpoint placed on one instruction of the challenge binary.

    stop() reads the registers and picks the operation by the hit address's
    offset from 0x555555554000. It appends a line describing the operation
    to the trace file via out(). It also updates the module-level
    bookkeeping (everyvalues, idx_history, result_values, result_idxes),
    which records which earlier traced value each operation consumed.

    Every return path returns False, so gdb resumes immediately. The session
    ends from inside stop() via exit() once 16 not-equal checks have been
    seen.
    """
    def stop(self):
        # gdb calls this on every hit. Returning True would stop the
        # inferior; returning False lets it continue.
        global addr, get_constant_arg, count, everyvalues, prev_res, bStartXor, drop_length, bStartNewRound, skip_count, bDoneCalculating, result_values
        # Offset relative to the assumed PIE base 0x555555554000 (NOTE(review):
        # relies on gdb loading the binary there, i.e. ASLR disabled).
        rip = int(gp('$rip')) - 0x555555554000
        rax = int(gp('$rax'))
        rbx = int(gp('$rbx'))
        rdi = int(gp('$rdi'))
        rsi = int(gp('$rsi'))
        r12 = int(gp('$r12'))
        rcx = int(gp('$rcx'))
        # Offset of the calling instruction (see get_parent_func_addr).
        parent = get_parent_func_addr() - 0x555555554000
        if rip == 0x37A2:
            # and: operands in rax / rbx
            res = rax & rbx
            out(f'{hex(rax)} & {hex(rbx)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x39B9:
            # The comparison hooks receive pointers in rdi / rsi, so the
            # unsigned int operands are dereferenced first.
            rdi = int(gp(f'*(unsigned int*){rdi}'))
            rsi = int(gp(f'*(unsigned int*){rsi}'))
            res = rdi < rsi
            out(f'{hex(rdi)} < {hex(rsi)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x3A2B:
            rdi = int(gp(f'*(unsigned int*){rdi}'))
            rsi = int(gp(f'*(unsigned int*){rsi}'))
            res = rdi >= rsi
            out(f'{hex(rdi)} >= {hex(rsi)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x3A9D:
            rdi = int(gp(f'*(unsigned int*){rdi}'))
            rsi = int(gp(f'*(unsigned int*){rsi}'))
            res = rdi == rsi
            out(f'{hex(rdi)} == {hex(rsi)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x3B07:
            # not-equal: log the comparison, then overwrite *rdi with *rsi
            # so the two values compare equal and execution continues past
            # the check. This collects every check in a single run.
            rdi = int(gp(f'*(unsigned int*){rdi}'))
            rsi = int(gp(f'*(unsigned int*){rsi}'))
            res = rdi != rsi
            out(f'{hex(rdi)} != {hex(rsi)} ({hex(res)}) from {hex(parent)}')
            org_value = rdi
            ge(f'set *(unsigned int*)$rdi={rsi}')
            rdi = int(gp(f'*(unsigned int*)$rdi'))
            rsi = int(gp(f'*(unsigned int*)$rsi'))
            res = rdi != rsi
            out(f'{hex(rdi)} != {hex(rsi)} ({hex(res)}) // force equal from {hex(parent)}')
            # Record which traced value was checked (its everyvalues index)
            # and the value it was compared against.
            tmp = rfind_list(everyvalues, org_value)
            print('validate result:', tmp)
            result_values.append(rsi)
            result_idxes.append(tmp)
            if len(result_values) == 16:
                # All 16 checks seen: dump the recorded data flow for the
                # solver script and end the session.
                with open('./idx_history.txt', 'wt') as f:
                    f.write(str(idx_history))
                with open('./result_values.txt', 'wt') as f:
                    f.write(str(result_values))
                with open('./result_idxes.txt', 'wt') as f:
                    f.write(str(result_idxes))
                print(idx_history)
                print(result_values)
                print(result_idxes)
                print('done')
                exit()
        elif rip == 0x362F:
            # mod, modelled as rbx % rsi (ZeroDivisionError if rsi is 0)
            res = rbx % rsi
            out(f'{hex(rbx)} % {hex(rsi)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x35AC:
            # mul, truncated to 32 bits
            res = (rax * rbx) & 0xffffffff
            out(f'{hex(rax)} * {hex(rbx)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x3728:
            # or
            res = rax | rbx
            out(f'{hex(rax)} | {hex(rbx)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x3532:
            # add, truncated to 32 bits
            res = (rax + rbx) & 0xffffffff
            out(f'{hex(rax)} + {hex(rbx)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x392B:
            # rotl: value in r12, count in rcx (the full 64-bit registers are
            # passed straight to rotl()).
            res = rotl(r12, rcx)
            out(f'rotl {hex(r12)} {hex(rcx)} ({hex(res)}) from {hex(parent)}')
            if not bStartXor:
                # The first rotate after a phase marks the start of the
                # xor-combine phase. Keep only the last collected index.
                # count still includes the dropped entries, so drop_length
                # is added to the phase-end threshold below.
                drop_length = len(idx_history[-1])-1
                print('bStartXor')
                if drop_length > 0:
                    print('drop_length:', hex(drop_length))
                    print('idx_history[-1]:', idx_history[-1])
                    idx_history[-1] = [idx_history[-1][-1]]
                else:
                    print('no drop')
                    drop_length = 0
                bStartXor = True
        elif rip == 0x381E:
            # shl, truncated to 32 bits
            res = (rbx << rcx) & 0xffffffff
            out(f'{hex(rbx)} << {hex(rcx)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x389C:
            # shr
            res = rbx >> rcx
            out(f'{hex(rbx)} >> {hex(rcx)} ({hex(res)}) from {hex(parent)}')
            if not bStartNewRound:
                # The first right shift marks the start of a new round (its
                # S-box phase). Keep only the last collected index and
                # restart the counter at 1 to account for it.
                drop_length = len(idx_history[-1])-1
                print('bStartNewRound')
                if drop_length > 0:
                    print('drop_length:', hex(drop_length))
                    print('idx_history[-1]:', idx_history[-1])
                    idx_history[-1] = [idx_history[-1][-1]]
                else:
                    print('no drop')
                    drop_length = 0
                count = 1
                bStartNewRound = True
        elif rip == 0x36AE:
            # xor
            res = rax ^ rbx
            out(f'{hex(rax)} ^ {hex(rbx)} ({hex(res)}) from {hex(parent)}')
        elif rip == 0x76FC:
            # Entry of "get constant". Save the result pointer (rdi) and the
            # dereferenced 32-bit arguments; they are logged at 0x7970.
            addr = rdi
            rsi = int(gp(f'*(unsigned int*)$rsi'))
            rdx = int(gp(f'*(unsigned int*)$rdx'))
            rcx = int(gp(f'*(unsigned int*)$rcx'))
            r8 = int(gp(f'*(unsigned int*)$r8'))
            r9 = int(gp(f'*(unsigned int*)$r9'))
            get_constant_arg = (rsi, rdx, rcx, r8, r9)
        elif rip == 0x7970:
            # Return of "get constant": read the value written through the
            # pointer saved at entry.
            rsi, rdx, rcx, r8, r9 = get_constant_arg
            res = int(gp(f'*(unsigned int*){addr}'))
            out(f'get constant({hex(rsi)}, {hex(rdx)}, {hex(rcx)}, {hex(r8)} (input), {hex(r9)}) res: {hex(res)}')
            # A result equal to the previous one is ignored.
            if res == prev_res:
                return False
            prev_res = res
            # Results below 0x10 are not tracked (NOTE(review): presumably
            # selectors/indices rather than data words).
            if res >= 0x10:
                # From this value on, all results are ignored.
                if res == 0x11793013:
                    bDoneCalculating = True
                if bDoneCalculating:
                    return False
                if skip_count != 0:
                    print('skip_count:', skip_count, '->', skip_count - 1)
                    skip_count -= 1
                    return False
                print('count:', hex(count))
                # Record which earlier traced value this result equals (the
                # last occurrence wins).
                idx_history[-1].append(rfind_list(everyvalues, res))
                count += 1
                # The three results following any of these four values are
                # skipped. (The log message names only the first value.)
                if res == 0x4b2133f2 or res == 0x85319541 or res == 0x227a463e or res == 0xea7a6c80:
                    print('skip_count activated after 0x4b2133f2')
                    skip_count = 3
                if bStartNewRound and count == 0x10:
                    # S-box phase complete. Replay the nibble-wise box[]
                    # substitution on the 16 recorded inputs and append the
                    # outputs as new traced values.
                    tmp_res = []
                    for i in idx_history[-1]:
                        res = 0
                        for j in range(0, 0x1c+4, 4):
                            tmp = ((everyvalues[i] >> j) & 0xf)
                            tmp = box[tmp]
                            res = (tmp << j) | res
                            res &= 0xffffffff
                        tmp_res.append(res)
                        print(hex(everyvalues[i]), '->', hex(res))
                    everyvalues += tmp_res
                    idx_history.append([])
                elif count == 0x20:
                    # Multiply phase complete: replay x * 0x4e6a44b9 mod 2**32.
                    tmp_res = []
                    for i in idx_history[-1]:
                        print(hex(everyvalues[i]), '->', end=' ')
                        tmp_res.append((everyvalues[i] * 0x4e6a44b9) & 0xffffffff)
                        print(hex(tmp_res[-1]))
                    everyvalues += tmp_res
                    idx_history.append([])
                elif bStartXor and count == 0x20+(0x10-3)*4+drop_length:
                    # Xor phase complete: (0x10-3)*4 = 52 inputs, i.e. 13
                    # groups of 4. Replay rotl29(a) ^ rotl17(b) ^ rotl7(c) ^ d
                    # per group, then reset the round state.
                    tmp_res = []
                    assert len(idx_history[-1]) % 4 == 0
                    for i in range(0, len(idx_history[-1]), 4):
                        idx1 = idx_history[-1][i]
                        idx2 = idx_history[-1][i+1]
                        idx3 = idx_history[-1][i+2]
                        idx4 = idx_history[-1][i+3]
                        print(hex(everyvalues[idx1]), hex(everyvalues[idx2]), hex(everyvalues[idx3]), hex(everyvalues[idx4]), '->')
                        print(f'{hex(rotl(everyvalues[idx1], 0x1d))} ^ {hex(rotl(everyvalues[idx2], 0x11))} ^ {hex(rotl(everyvalues[idx3], 0x7))} ^ {hex(everyvalues[idx4])} = ', end=' ')
                        tmp_res.append(rotl(everyvalues[idx1], 0x1d) ^ rotl(everyvalues[idx2], 0x11) ^ rotl(everyvalues[idx3], 0x7) ^ everyvalues[idx4])
                        print(hex(tmp_res[-1]))
                    everyvalues += tmp_res
                    idx_history.append([])
                    count = 0
                    bStartNewRound = False
                    bStartXor = False
        return False


# Offsets (relative to the 0x555555554000 PIE base) of the instructions that
# mybp.stop() decodes, each tagged with the operation traced there.
# (A lighter trace only needs 0x3B07, 0x392B, 0x389C, 0x76FC and 0x7970.)
for i in (
    0x37A2,  # and
    0x39B9,  # below
    0x3A2B,  # above equal
    0x3A9D,  # equal
    0x3B07,  # not equal
    0x362F,  # mod
    0x35AC,  # mul
    0x3728,  # or
    0x3532,  # add
    0x392B,  # rotl
    0x381E,  # shl
    0x389C,  # shr
    0x36AE,  # xor
    0x76FC,  # get constant
    0x7970,  # get constant ret
):
    mybp('*0x555555554000+%d' % i)

# Feed the probe payload on stdin and start the program; the breakpoints
# above do the tracing. (A flag-shaped probe such as b'AECCON{' padded with
# b'A' to 0x40 bytes can be substituted here.)
with open('./input.txt', 'wb') as f:
    f.write(payload)
ge('r < input.txt')

 

The tricky part of the script was that the indexes referring to previous calculation results were different every time.

So I had to handle special cases based on the values returned by "get constant", put every value it returned into the "everyvalues" list, and record all the indexes.

It looks kind of dumb, but it's actually easier and faster than analyzing the inputs to "get constant" and how the constants are placed inside the variant structure.

 

instructions

 

I analyzed the logged instructions and wrote a z3 solve script.

Then I got the flag.

 

 

solve.py

from z3 import *
import struct
# 4-bit S-box from the challenge; the solver undoes it per nibble with box.index().
box = [3, 14, 1, 10, 4, 9, 5, 6, 8, 11, 15, 2, 13, 12, 0, 7]
# Data flow recorded by the gdb tracer (its idx_history.txt dump): phase lists
# of indices into everyvalues, three per round (S-box inputs, multiply inputs,
# xor-combine inputs). The trailing [] is dropped before solving.
idx_history = [[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], [32, 33, 34, 35, 33, 34, 35, 36, 34, 35, 36, 37, 35, 36, 37, 38, 36, 37, 38, 39, 37, 38, 39, 40, 38, 39, 40, 41, 39, 40, 41, 42, 40, 41, 42, 43, 41, 42, 43, 44, 42, 43, 44, 45, 43, 44, 45, 46, 44, 45, 46, 47], [32, 33, 34, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60], [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76], [92, 77, 78, 79, 77, 78, 79, 80, 78, 79, 80, 81, 79, 80, 81, 82, 80, 81, 82, 83, 81, 82, 83, 84, 82, 83, 84, 85, 83, 84, 85, 86, 84, 85, 86, 87, 85, 86, 87, 88, 86, 87, 88, 89, 87, 88, 89, 90, 88, 89, 90, 91], [77, 78, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 92], [106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121], [136, 137, 122, 123, 137, 122, 123, 124, 122, 123, 124, 125, 123, 124, 125, 126, 124, 125, 126, 127, 125, 126, 127, 128, 126, 127, 128, 129, 127, 128, 129, 130, 128, 129, 130, 131, 129, 130, 131, 132, 130, 131, 132, 133, 131, 132, 133, 134, 132, 133, 134, 135], [122, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 136, 137], [151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166], [180, 181, 182, 167, 181, 182, 167, 168, 182, 167, 168, 169, 167, 168, 169, 170, 168, 169, 170, 171, 169, 170, 171, 172, 170, 171, 172, 173, 171, 172, 173, 174, 172, 173, 174, 175, 173, 174, 175, 176, 174, 175, 176, 177, 175, 176, 177, 178, 176, 177, 178, 179], [183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 180, 181, 182], [196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211], [225, 226, 227, 212, 226, 227, 212, 213, 227, 212, 213, 214, 212, 213, 214, 215, 213, 214, 215, 216, 214, 215, 216, 217, 215, 216, 217, 218, 216, 217, 218, 219, 217, 218, 219, 220, 218, 219, 220, 221, 219, 220, 221, 222, 220, 221, 222, 223, 224, 225, 226, 227], [228, 229, 230, 231, 232, 233, 234, 235, 
236, 237, 238, 239, 224, 225, 226, 240], [241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256], [270, 271, 272, 257, 271, 272, 257, 258, 272, 257, 258, 259, 257, 258, 259, 260, 258, 259, 260, 261, 259, 260, 261, 262, 260, 261, 262, 263, 261, 262, 263, 264, 262, 263, 264, 265, 263, 264, 265, 266, 264, 265, 266, 267, 268, 269, 270, 271, 269, 270, 271, 272], [273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 268, 269, 270, 284, 285], [286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301], [315, 316, 317, 302, 316, 317, 302, 303, 317, 302, 303, 304, 302, 303, 304, 305, 303, 304, 305, 306, 304, 305, 306, 307, 305, 306, 307, 308, 306, 307, 308, 309, 307, 308, 309, 310, 308, 309, 310, 311, 312, 313, 314, 315, 313, 314, 315, 316, 314, 315, 316, 317], [318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 312, 313, 314, 328, 329, 330], [331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346], [360, 361, 362, 347, 361, 362, 347, 348, 362, 347, 348, 349, 347, 348, 349, 350, 348, 349, 350, 351, 349, 350, 351, 352, 350, 351, 352, 353, 351, 352, 353, 354, 352, 353, 354, 355, 356, 357, 358, 359, 357, 358, 359, 360, 358, 359, 360, 361, 359, 360, 361, 362], []]
# Values the 16 not-equal checks compared against (result_values.txt).
result_values = [293154835, 2295524507, 739036381, 2544826473, 884767840, 4167860951, 2457557977, 3580402316, 1159664922, 255368266, 3535918511, 2189696244, 3172497242, 3196780109, 419743314, 3085542052]
# Alternative set of target values, left commented out (origin not stated).
# result_values = [631555752, 1065290957, 3784651718, 3396205352, 2708555000, 2151138649, 1080378095, 587160633, 1373283846, 973376585, 3369093675, 1107219797, 174269838, 3199073120, 3224127144, 942122456]
# everyvalues slot that was checked against each entry of result_values.
result_idxes = [375, 374, 373, 372, 358, 357, 356, 371, 370, 369, 368, 367, 366, 365, 364, 363]

def rotl(a1, a2):
    """Rotate the 32-bit value ``a1`` left by ``a2`` bits.

    ``a1`` may be a Python int or a z3 bit-vector expression; ``a2`` must be
    an int. Only the low 32 bits of ``a1`` and the low 5 bits of ``a2`` are
    used. Masking ``a1`` first also matters for 64-bit z3 BitVecs: z3's
    ``>>`` is an arithmetic shift, so it only behaves like the logical shift
    a rotate needs when the bits above 31 are zero.
    """
    a1 = a1 & 0xffffffff
    a2 &= 31
    return ((a1 << a2) | (a1 >> (32 - a2))) & 0xffffffff
def p32(a1):
    """Pack ``a1`` as 4 bytes, little-endian unsigned 32-bit."""
    packer = struct.Struct('<I')
    return packer.pack(a1)
# Drop the trailing empty phase list recorded after the last round.
idx_history = idx_history[:-1]
# One slot per traced value: 16 input words plus 8 rounds x (16 S-box
# outputs + 16 products + 13 xor results) = 0x178. A 0 means "not known
# yet" (NOTE(review): a genuine value of 0 would be mistaken for unknown).
everyvalues = [0 for i in range(0x178)]
# Negative index: the first multiply output of the last round
# (0x178 - 0x10 - 0xd). It moves back one round per loop iteration.
idx = -0x10-0xd
# Undo the rounds from last to first. Each round is three consecutive phase
# lists: S-box inputs, multiply inputs, xor-combine inputs.
for i in range(len(idx_history) - 3, -1, -3):
    tmp_everyvalues = [j for j in everyvalues]
    first_idxes = idx_history[i]
    second_idxes = idx_history[i+1]
    third_idxes = idx_history[i+2]
    
    s = Solver()

    # Unknown multiply inputs (in the recorded data, this round's S-box
    # outputs) become z3 symbols. NOTE(review): they are 64-bit, and only
    # the low 32 bits affect the constraints, so the model may choose
    # arbitrary high bits.
    inputs_z3 = []
    for j in range(len(second_idxes)):
        if everyvalues[second_idxes[j]] == 0:
            tmp = BitVec('input[%d]'%j, 64)
            tmp_everyvalues[second_idxes[j]] = tmp
            inputs_z3.append(tmp)

    # Replay the multiply phase symbolically, pinning any output already
    # recovered from the following round.
    for j in second_idxes:
        res = (tmp_everyvalues[j] * 0x4e6a44b9) & 0xffffffff
        # assert tmp_everyvalues[idx] == 0
        tmp_everyvalues[idx] = res
        if everyvalues[idx] != 0:
            s.add(tmp_everyvalues[idx] == everyvalues[idx])
        idx += 1

    # Replay the xor-combine phase: rotl29(a) ^ rotl17(b) ^ rotl7(c) ^ d
    # for each group of four inputs.
    for j in range(0, len(third_idxes), 4):
        idx1 = third_idxes[j]
        idx2 = third_idxes[j+1]
        idx3 = third_idxes[j+2]
        idx4 = third_idxes[j+3]
        res = rotl(tmp_everyvalues[idx1], 0x1d) ^ rotl(tmp_everyvalues[idx2], 0x11) ^ rotl(tmp_everyvalues[idx3], 0x7) ^ tmp_everyvalues[idx4]
        tmp_everyvalues[idx] = res
        if everyvalues[idx] != 0:
            s.add(tmp_everyvalues[idx] == everyvalues[idx])
        idx += 1

    if i == len(idx_history) - 3:
        # Last round: the checked slots must equal the expected values.
        for j in range(len(result_idxes)):
            s.add(tmp_everyvalues[result_idxes[j]] == result_values[j])
    else:
        # Earlier rounds: re-add the constraints on the 13 xor outputs
        # (the same ones the xor loop above already added).
        idx -= 0xd
        for j in range(0xd):
            if everyvalues[idx] != 0:
                s.add(tmp_everyvalues[idx] == everyvalues[idx])
            idx += 1

    print(s.check())
    # NOTE(review): s.model() fails if the check above was not sat.
    m = s.model()
    res = []
    # NOTE(review): assumes all 16 multiply inputs were unknown, so that
    # inputs_z3[j] corresponds to second_idxes[j].
    for j in range(0x10):
        res.append(m[inputs_z3[j]].as_long())
        print(hex(res[-1]))
    # if i == len(idx_history) - 3 -3:
    #     exit()
    for j in range(len(second_idxes)):
        tmp_everyvalues[second_idxes[j]] = res[j]

    # Rewind to this round's first multiply output, then replay both phases
    # concretely to validate the model.
    idx -= 0x10
    idx -= 0xd

    for j in second_idxes:
        res = (tmp_everyvalues[j] * 0x4e6a44b9) & 0xffffffff
        tmp_everyvalues[idx] = res
        idx += 1

    for j in range(0, len(third_idxes), 4):
        idx1 = third_idxes[j]
        idx2 = third_idxes[j+1]
        idx3 = third_idxes[j+2]
        idx4 = third_idxes[j+3]
        res = rotl(tmp_everyvalues[idx1], 0x1d) ^ rotl(tmp_everyvalues[idx2], 0x11) ^ rotl(tmp_everyvalues[idx3], 0x7) ^ tmp_everyvalues[idx4]
        tmp_everyvalues[idx] = res
        idx += 1
    if i == len(idx_history) - 3:
        for j in range(len(result_idxes)):
            if tmp_everyvalues[result_idxes[j]] != result_values[j]:
                assert False
    else:
        idx -= 0xd
        for j in range(0xd):
            if tmp_everyvalues[idx] != everyvalues[idx]:
                assert False
            idx += 1

    # Commit every value recovered in this round.
    for j in range(len(tmp_everyvalues)):
        if tmp_everyvalues[j] != 0 and everyvalues[j] == 0:
            everyvalues[j] = tmp_everyvalues[j]

    # Rewind to this round's first S-box output, invert the S-box nibble by
    # nibble (low 8 nibbles only), and store the round's inputs.
    idx -= 0x10
    idx -= 0xd
    idx -= 0x10
    tmp_org = []
    for j in first_idxes:
        org = 0
        res = everyvalues[idx]
        idx += 1
        for k in range(0, 0x1c+4, 4):
            org = ((box.index((res >> k) & 0xf)) << k) | org
        tmp_org.append(org)
        print(hex(org))
    for j in range(len(first_idxes)):
        everyvalues[first_idxes[j]] = tmp_org[j]
    
    # Step back to the previous round's first multiply output.
    idx -= 0x10
    idx -= 0x10
    idx -= 0xd
print('done')
# The recovered input is the 16 words in slots 0..15, each stored little-endian.
flag = b''.join(p32(word) for word in everyvalues[:0x10])
print(flag.decode())

 

 

Every time I participate in a CTF, I intend to focus on pwn challenges, but because I'm slow at writing solve scripts, I end up with little time left for them.

Well, I hope I get better every time I write a script.

Comments