Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic block support #55

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions dncil/cil/block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (C) 2022 Mandiant, Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at: [package root]/LICENSE.txt
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
from dncil.cil.instruction import Instruction


class BasicBlock:
    """group of instructions together with links to neighbouring blocks

    `instructions` holds the block's instructions in order; `preds` and `succs` hold the blocks
    linked before and after this one. Both edge lists start empty and are filled in by the code
    that connects blocks.
    """

    def __init__(self, instructions: Optional[List[Instruction]] = None):
        # a falsy argument (None or an empty list) results in a fresh empty list
        self.instructions = instructions or []
        self.preds: List[BasicBlock] = []
        self.succs: List[BasicBlock] = []

    @property
    def start_offset(self) -> int:
        """offset of the first instruction; raises IndexError when the block has no instructions"""
        return self.instructions[0].offset

    @property
    def end_offset(self) -> int:
        """offset of the last instruction plus its size; raises IndexError when the block is empty"""
        last = self.instructions[-1]
        return last.offset + last.size

    @property
    def size(self) -> int:
        """difference between end offset and start offset"""
        return self.end_offset - self.start_offset

    def get_bytes(self) -> bytes:
        """get the concatenated bytes of every instruction in the block, in order"""
        # join once instead of growing a bytes object inside a loop
        return b"".join(insn.get_bytes() for insn in self.instructions)
74 changes: 73 additions & 1 deletion dncil/cil/body/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Set, Dict, List, Iterator, Optional, cast

if TYPE_CHECKING:
from dncil.cil.instruction import Instruction
from dncil.cil.body.reader import CilMethodBodyReaderBase

from dncil.cil.block import BasicBlock
from dncil.cil.enums import CorILMethod, CorILMethodSect
from dncil.cil.error import MethodBodyFormatError
from dncil.clr.token import Token
Expand All @@ -37,6 +38,7 @@ def __init__(self, reader: CilMethodBodyReaderBase):
self.exception_handlers_size: int

self.instructions: List[Instruction] = []
self.basic_blocks: List[BasicBlock] = []
self.exception_handlers: List[ExceptionHandler] = []

# set method offset
Expand Down Expand Up @@ -75,6 +77,11 @@ def get_exception_handler_bytes(self) -> bytes:
"""get method exception handler bytes"""
return self.raw_bytes[self.header_size + self.code_size :]

def get_basic_blocks(self) -> Iterator[BasicBlock]:
    """get method basic blocks, parsing them first when none are stored yet"""
    if len(self.basic_blocks) == 0:
        # nothing stored so far; build the block list before iterating
        self.parse_basic_blocks()

    for block in self.basic_blocks:
        yield block

def parse_header(self, reader: CilMethodBodyReaderBase):
"""get method body header"""
# header byte gives us the format and, in fat format, implementation flags used at runtime
Expand Down Expand Up @@ -201,3 +208,68 @@ def parse_tiny_exception_handlers(self, reader: CilMethodBodyReaderBase):
_ = reader.read_uint32()[0]

self.exception_handlers.append(eh)

def parse_basic_blocks(self):
    """split the method's instructions into basic blocks and connect them

    The result is stored in `self.basic_blocks`, in instruction order. The list is rebuilt from
    scratch on every call, so calling this more than once does not duplicate blocks or edges.
    """
    # calculate basic block leaders where,
    #   1. The first instruction of the intermediate code is a leader
    #   2. Instructions that are targets of unconditional or conditional jump/goto statements are leaders
    #   3. Instructions that immediately follow unconditional or conditional jump/goto statements are considered leaders
    #   https://www.geeksforgeeks.org/basic-blocks-in-compiler-design/

    insn_count = len(self.instructions)

    leaders: Set[int] = set()
    for idx, insn in enumerate(self.instructions):
        if idx == 0:
            # add #1
            leaders.add(insn.offset)

        if any((insn.is_br(), insn.is_cond_br(), insn.is_leave())):
            # add #2
            leaders.add(cast(int, insn.operand))
            # add #3, unless the branch is the last instruction of the method
            if idx + 1 < insn_count:
                leaders.add(self.instructions[idx + 1].offset)

    # build basic blocks using leaders
    blocks: List[BasicBlock] = []
    for insn in self.instructions:
        if insn.offset in leaders:
            # new leader, new basic block
            blocks.append(BasicBlock(instructions=[insn]))
        else:
            # the first instruction is always a leader, so a current block exists here
            blocks[-1].instructions.append(insn)

    # create mapping of first instruction to basic block
    bb_map: Dict[int, BasicBlock] = {bb.start_offset: bb for bb in blocks}

    # connect basic blocks
    for idx, bb in enumerate(blocks):
        last = bb.instructions[-1]

        # connect branches to other basic blocks
        if any((last.is_br(), last.is_cond_br(), last.is_leave())):
            bb_branch: Optional[BasicBlock] = bb_map.get(cast(int, last.operand), None)
            # a target that does not start a block is an invalid branch, may be seen in
            # obfuscated IL; no edge is added for it
            if bb_branch is not None:
                bb.succs.append(bb_branch)
                bb_branch.preds.append(bb)

            if any((last.is_br(), last.is_leave())):
                # no fallthrough
                continue

        # NOTE(review): only br/leave suppress the fallthrough edge, so a block that ends in any
        # other instruction (e.g. a return) is still linked to the next block - confirm intended

        # connect fallthrough, unless this is the last block of the method
        if idx + 1 < len(blocks):
            bb_next: BasicBlock = blocks[idx + 1]
            bb.succs.append(bb_next)
            bb_next.preds.append(bb)

    self.basic_blocks = blocks
146 changes: 146 additions & 0 deletions tests/test_method_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,102 @@
"""
method_body_tiny = binascii.unhexlify("1E02280C00000A2A")

"""
.method private hidebysig static
void Main (
string[] args
) cil managed
{
// Header Size: 12 bytes
// Code Size: 142 (0x8E) bytes
// LocalVarSig Token: 0x11000001 RID: 1
.maxstack 2
.entrypoint
.locals init (
[0] int32 i,
[1] string text
)

/* 0x0000025C 02 */ IL_0000: ldarg.0
/* 0x0000025D 17 */ IL_0001: ldc.i4.1
/* 0x0000025E 9A */ IL_0002: ldelem.ref
/* 0x0000025F 7201000070 */ IL_0003: ldstr "test"
/* 0x00000264 280F00000A */ IL_0008: call bool [mscorlib]System.String::op_Equality(string, string)
/* 0x00000269 2C0C */ IL_000D: brfalse.s IL_001B

/* 0x0000026B 720B000070 */ IL_000F: ldstr "Hello from test"
/* 0x00000270 281000000A */ IL_0014: call void [mscorlib]System.Console::WriteLine(string)
/* 0x00000275 2B68 */ IL_0019: br.s IL_0083

/* 0x00000277 02 */ IL_001B: ldarg.0
/* 0x00000278 17 */ IL_001C: ldc.i4.1
/* 0x00000279 9A */ IL_001D: ldelem.ref
/* 0x0000027A 722B000070 */ IL_001E: ldstr "testtest"
/* 0x0000027F 280F00000A */ IL_0023: call bool [mscorlib]System.String::op_Equality(string, string)
/* 0x00000284 2C0C */ IL_0028: brfalse.s IL_0036

/* 0x00000286 723D000070 */ IL_002A: ldstr "Hello from testtest"
/* 0x0000028B 281000000A */ IL_002F: call void [mscorlib]System.Console::WriteLine(string)
/* 0x00000290 2B4D */ IL_0034: br.s IL_0083

/* 0x00000292 16 */ IL_0036: ldc.i4.0
/* 0x00000293 0A */ IL_0037: stloc.0
/* 0x00000294 2B0E */ IL_0038: br.s IL_0048

// loop start (head: IL_0048)
/* 0x00000296 7265000070 */ IL_003A: ldstr "Hello from unknown"
/* 0x0000029B 281000000A */ IL_003F: call void [mscorlib]System.Console::WriteLine(string)
/* 0x000002A0 06 */ IL_0044: ldloc.0
/* 0x000002A1 17 */ IL_0045: ldc.i4.1
/* 0x000002A2 58 */ IL_0046: add
/* 0x000002A3 0A */ IL_0047: stloc.0

/* 0x000002A4 06 */ IL_0048: ldloc.0
/* 0x000002A5 1F64 */ IL_0049: ldc.i4.s 100
/* 0x000002A7 32ED */ IL_004B: blt.s IL_003A
// end loop

/* 0x000002A9 00 */ IL_004D: nop
.try
{
/* 0x000002AA 02 */ IL_004E: ldarg.0
/* 0x000002AB 17 */ IL_004F: ldc.i4.1
/* 0x000002AC 9A */ IL_0050: ldelem.ref
/* 0x000002AD 281100000A */ IL_0051: call string [mscorlib]System.IO.File::ReadAllText(string)
/* 0x000002B2 0B */ IL_0056: stloc.1
/* 0x000002B3 07 */ IL_0057: ldloc.1
/* 0x000002B4 728B000070 */ IL_0058: ldstr "exit"
/* 0x000002B9 280F00000A */ IL_005D: call bool [mscorlib]System.String::op_Equality(string, string)
/* 0x000002BE 2C02 */ IL_0062: brfalse.s IL_0066

/* 0x000002C0 DE27 */ IL_0064: leave.s IL_008D

/* 0x000002C2 07 */ IL_0066: ldloc.1
/* 0x000002C3 281000000A */ IL_0067: call void [mscorlib]System.Console::WriteLine(string)
/* 0x000002C8 07 */ IL_006C: ldloc.1
/* 0x000002C9 281000000A */ IL_006D: call void [mscorlib]System.Console::WriteLine(string)
/* 0x000002CE 07 */ IL_0072: ldloc.1
/* 0x000002CF 281000000A */ IL_0073: call void [mscorlib]System.Console::WriteLine(string)
/* 0x000002D4 17 */ IL_0078: ldc.i4.1
/* 0x000002D5 281200000A */ IL_0079: call void [mscorlib]System.Environment::Exit(int32)
/* 0x000002DA DE03 */ IL_007E: leave.s IL_0083
} // end .try
catch [mscorlib]System.Object
{
/* 0x000002DC 26 */ IL_0080: pop
/* 0x000002DD DE0A */ IL_0081: leave.s IL_008D
} // end handler

/* 0x000002DF 7295000070 */ IL_0083: ldstr "Failed"
/* 0x000002E4 281000000A */ IL_0088: call void [mscorlib]System.Console::WriteLine(string)

/* 0x000002E9 2A */ IL_008D: ret
} // end of method Program::Main
"""
# raw method body bytes for the Program::Main listing above; per that listing the header is
# 12 bytes and the code is 0x8E bytes. The bytes after the code are presumably the exception
# handling section for the .try/catch shown in the listing - TODO confirm
method_body_fat_complex = binascii.unhexlify(
"1b3002008e0000000100001102179a7201000070280f00000a2c0c720b000070281000000a2b6802179a722b000070280f00000a2c0c723d000070281000000a2b4d160a2b0e7265000070281000000a0617580a061f6432ed0002179a281100000a0b07728b000070280f00000a2c02de2707281000000a07281000000a07281000000a17281200000ade0326de0a7295000070281000000a2a00000110000000004e003280000310000001"
)


def test_invalid_header_format():
reader = CilMethodBodyReaderBytes(b"\x00")
Expand Down Expand Up @@ -195,3 +291,53 @@ def test_read_fat_header_exception_handlers():
assert body.exception_handlers[0].handler_end == 0x19
assert isinstance(body.exception_handlers[0].catch_type, Token)
assert body.exception_handlers[1].is_finally()


def test_read_tiny_header_blocks():
    """the tiny method body parses into one basic block with no edges"""
    body = CilMethodBody(CilMethodBodyReaderBytes(method_body_tiny))
    blocks = list(body.get_basic_blocks())

    assert len(blocks) == 1

    first = blocks[0]
    assert first.get_bytes() == b"\x02\x28\x0C\x00\x00\x0A\x2a"
    assert first.size == 7
    assert first.instructions[-1].opcode.value == OpCodeValue.Ret

    # the blocks, taken in order, must reproduce the method's instruction bytes
    joined = b"".join(bb.get_bytes() for bb in blocks)
    assert joined == body.get_instruction_bytes()

    assert len(blocks[0].preds) == 0
    assert len(blocks[-1].succs) == 0


def test_read_fat_header_complex_blocks():
    """the complex fat method body parses into the expected blocks and edges"""
    body = CilMethodBody(CilMethodBodyReaderBytes(method_body_fat_complex))
    blocks = list(body.get_basic_blocks())

    assert len(blocks) == 13
    assert blocks[4].get_bytes() == b"\x16\x0a\x2b\x0e"

    bb_11 = blocks[11]
    assert bb_11.size == 10
    assert bb_11.instructions[0].opcode.value == OpCodeValue.Ldstr
    assert bb_11.instructions[-1].opcode.value == OpCodeValue.Call

    # the blocks, taken in order, must reproduce the method's instruction bytes
    joined = b"".join(bb.get_bytes() for bb in blocks)
    assert joined == body.get_instruction_bytes()

    # edge counts
    assert len(blocks[0].preds) == 0
    assert len(blocks[-1].succs) == 0
    assert len(blocks[-1].preds) == 3
    assert len(blocks[9].preds) == 1
    assert len(blocks[9].succs) == 1
    assert len(blocks[1].succs) == 1

    # edge endpoints, compared by block start offset
    preds_of_last = [bb.start_offset for bb in blocks[-1].preds]
    for pred_idx in (8, 10, 11):
        assert blocks[pred_idx].start_offset in preds_of_last

    assert blocks[7].start_offset in [bb.start_offset for bb in blocks[9].preds]
    assert blocks[7].start_offset in [bb.start_offset for bb in blocks[8].preds]

    succs_of_7 = [bb.start_offset for bb in blocks[7].succs]
    assert blocks[9].start_offset in succs_of_7
    assert blocks[8].start_offset in succs_of_7