reman3/scripts/DecompileC.java

470 lines
15 KiB
Java

/* ###
*/
// Script to export decompiled C code from Ghidra
//
//@category Examples.Demangler
import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Scanner;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.regex.Pattern;

import ghidra.app.cmd.label.AddLabelCmd;
import ghidra.app.decompiler.ClangMarkup;
import ghidra.app.decompiler.ClangNode;
import ghidra.app.decompiler.ClangTokenGroup;
import ghidra.app.decompiler.DecompInterface;
import ghidra.app.decompiler.DecompileResults;
import ghidra.app.decompiler.DecompiledFunction;
import ghidra.app.script.GhidraScript;
import ghidra.docking.settings.Settings;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressFactory;
import ghidra.program.model.data.AbstractStringDataType;
import ghidra.program.model.data.BuiltInDataType;
import ghidra.program.model.data.DataOrganization;
import ghidra.program.model.data.DataOrganizationImpl;
import ghidra.program.model.data.DataType;
import ghidra.program.model.data.PointerDataType;
import ghidra.program.model.data.StringDataInstance;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.Variable;
import ghidra.program.model.listing.VariableStorage;
import ghidra.program.model.mem.MemoryAccessException;
import ghidra.program.model.mem.MemoryBlock;
import ghidra.program.model.mem.MemoryBufferImpl;
import ghidra.program.model.pcode.HighFunction;
import ghidra.program.model.pcode.HighSymbol;
import ghidra.program.model.pcode.PcodeOp;
import ghidra.program.model.pcode.PcodeOpAST;
import ghidra.program.model.pcode.Varnode;
import ghidra.program.model.symbol.SourceType;
import ghidra.program.model.symbol.Symbol;
import ghidra.program.model.symbol.SymbolTable;
import ghidra.program.model.symbol.SymbolType;
public class DecompileC extends GhidraScript {

    /**
     * Breadth-first walker over direct calls. Starting from a root function, it decompiles
     * each reachable function once and records it in {@link #out}. Entry points already in
     * the visited set (seeded through {@link #setBlacklist}) are neither decompiled nor recorded.
     */
    public class PCallTracer {
        /** A function waiting to be visited, with its call depth relative to the root. */
        public class QueueItem {
            final Function function;
            final int depth;

            QueueItem(Function function, int depth) {
                this.function = function;
                this.depth = depth;
            }
        }

        /** Functions visited by {@link #traceCalls}, in visit order (the root comes first). */
        public List<Function> out = new ArrayList<>();
        /** When true, every visited function is logged. */
        public boolean trace = false;
        // FIFO work list; ArrayDeque makes dequeuing O(1) instead of ArrayList.remove(0)'s O(n).
        Deque<QueueItem> queue = new ArrayDeque<>();
        HashSet<Address> visited = new HashSet<>();
        DecompInterface decomp;

        PCallTracer(DecompInterface decomp) {
            this.decomp = decomp;
        }

        /** Seeds the visited set with a copy of {@code blacklist} so those functions are skipped. */
        public void setBlacklist(HashSet<Address> blacklist) {
            this.visited = new HashSet<>(blacklist);
        }

        /** Enqueues every direct CALL target of {@code highFunction} that has not been visited yet. */
        void visit(HighFunction highFunction, int depth) {
            Iterator<PcodeOpAST> opIter = highFunction.getPcodeOps();
            while (opIter.hasNext()) {
                PcodeOpAST op = opIter.next();
                // Only direct CALL ops are followed; indirect calls are not resolved here.
                if (op.getOpcode() != PcodeOp.CALL) {
                    continue;
                }
                Varnode target = op.getInput(0);
                if (!target.isAddress()) {
                    continue;
                }
                Function calledFunction = getFunctionAt(target.getAddress());
                if (calledFunction == null) {
                    println("PCallTracer, called function not found: " + op.toString() + " - "
                            + highFunction.getFunction().getName());
                    continue;
                }
                if (!visited.contains(calledFunction.getEntryPoint())) {
                    queue.add(new QueueItem(calledFunction, depth + 1));
                }
            }
        }

        /** Decompiles {@code function} (once), enqueues its callees and records it in {@link #out}. */
        void visit(Function function, int depth) {
            // Set.add returns false when this entry point was already visited or blacklisted.
            if (!visited.add(function.getEntryPoint())) {
                return;
            }
            if (trace) {
                println("PCallTracer, visiting " + function.getName() + " (depth:" + depth + ")");
            }
            DecompileResults decompRes = decomp.decompileFunction(function, TIMEOUT, monitor);
            HighFunction highFunction = decompRes.getHighFunction();
            if (highFunction == null) {
                // Decompilation failed: the function is still recorded, but its callees stay unknown.
                println("PCallTracer, failed to decompile " + function.getName() + ": "
                        + decompRes.getErrorMessage());
            } else {
                visit(highFunction, depth);
            }
            out.add(function);
        }

        /** Visits {@code inFunction} and everything transitively reachable from it via direct calls. */
        public void traceCalls(Function inFunction) {
            queue.add(new QueueItem(inFunction, 0));
            while (!queue.isEmpty()) {
                QueueItem item = queue.poll();
                visit(item.function, item.depth);
            }
        }
    }

    private static final String OUTPUT_DIR = "game_re";
    // Passed as the timeoutSecs argument of DecompInterface.decompileFunction, i.e. in seconds.
    private static final int TIMEOUT = 10000;

    // Bounds of the static memory block; currently only assigned in run(), not read elsewhere.
    private Address staticMemoryBlockStart;
    private Address staticMemoryBlockEnd;

    // When true, sanitizeGlobalSymbolsPass adds a sanitized label for invalid global names.
    private static final boolean AUTO_RENAME_SYMBOLS = true;
    // Debug filter: when non-null, only functions with these exact names are decompiled.
    private static final HashSet<String> ONLY_SYMBOLS = new HashSet<>(Arrays.asList(
            "r3_main", //
            "_strrchr"));

    // run() modes.
    private static final int MODE_SANITIZE_SYMBOLS = 0;
    private static final int MODE_DECOMPILE = 1;

    // Any character that may not appear in a C identifier.
    private static final Pattern NON_IDENTIFIER_CHARS = Pattern.compile("[^a-zA-Z0-9_]");

    File rootDir;
    File outputDir;
    File dirDecompAuto;
    File dirDecompFix;
    HashSet<Address> functionAddrBlackList = new HashSet<>();

    /**
     * Replaces {@link #functionAddrBlackList} with the entries of {@code <outputDir>/blacklist.txt}.
     * Each line holds one address; anything after {@code //} is a comment. Blank or unparsable
     * lines are skipped. A missing file leaves the blacklist empty.
     */
    void loadFunctionBlacklist() {
        functionAddrBlackList.clear();
        File blacklistFile = new File(outputDir, "blacklist.txt");
        AddressFactory addressFactory = currentProgram.getAddressFactory();
        try (Scanner scanner = new Scanner(blacklistFile)) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                String addrText = line.split("//")[0].trim();
                if (addrText.isEmpty()) {
                    continue;
                }
                // getAddress returns null for text it cannot parse; never store null in the set.
                Address addr = addressFactory.getAddress(addrText);
                if (addr == null) {
                    println("Ignoring unparsable blacklist entry: " + line);
                    continue;
                }
                functionAddrBlackList.add(addr);
            }
            println("Loaded blacklist with " + functionAddrBlackList.size() + " entries");
        } catch (FileNotFoundException e) {
            println("No blacklist found");
        }
    }

    /**
     * Writes {@link #functionAddrBlackList} to {@code <outputDir>/blacklist.txt}, one
     * {@code address // function-name} line per entry, sorted by address so the file is stable.
     */
    void saveFunctionBlacklist() {
        println("Saving blacklist");
        File blacklistFile = new File(outputDir, "blacklist.txt");
        try (PrintWriter writer = new PrintWriter(blacklistFile)) {
            for (Address addr : new TreeSet<>(functionAddrBlackList)) {
                // An address loaded from an older blacklist may no longer start a function.
                Function fn = getFunctionAt(addr);
                String fnName = fn != null ? fn.getName() : "<no function>";
                writer.println(addr.toString() + " // " + fnName);
            }
        } catch (FileNotFoundException e) {
            println("Error saving blacklist: " + e.getMessage());
        }
    }

    /** True when {@code function} passes the ONLY_SYMBOLS filter and is not blacklisted. */
    boolean shouldDecompileFunction(Function function) {
        if (ONLY_SYMBOLS != null && !ONLY_SYMBOLS.contains(function.getName())) {
            return false;
        }
        return !functionAddrBlackList.contains(function.getEntryPoint());
    }

    static final boolean BUILD_BLACKLIST = true;

    /**
     * Loads the blacklist and, when BUILD_BLACKLIST is set, extends it with every function whose
     * comment contains "Library Function" or whose name starts with "crt_", plus everything those
     * functions reach through direct calls. The file is rewritten only if entries were added.
     */
    void buildFunctionBlacklist(DecompInterface decomp) {
        loadFunctionBlacklist();
        if (!BUILD_BLACKLIST) {
            return;
        }
        boolean modified = false;
        for (Function function : currentProgram.getFunctionManager().getFunctions(true)) {
            if (functionAddrBlackList.contains(function.getEntryPoint())) {
                continue;
            }
            String comment = function.getComment();
            boolean isIgnoredFunction = false;
            if (comment != null && comment.contains("Library Function")) {
                println("Adding library function " + function.getName() + " to blacklist");
                isIgnoredFunction = true;
            }
            if (function.getName().startsWith("crt_")) {
                println("Adding crt function " + function.getName() + " to blacklist");
                isIgnoredFunction = true;
            }
            if (!isIgnoredFunction) {
                continue;
            }
            // Decompile and trace: everything only reachable from ignored code is ignored too.
            PCallTracer tracer = new PCallTracer(decomp);
            tracer.setBlacklist(functionAddrBlackList);
            tracer.traceCalls(function);
            for (Function f : tracer.out) {
                println(" Adding " + f.getName() + " to blacklist");
                functionAddrBlackList.add(f.getEntryPoint());
                modified = true;
            }
        }
        if (modified) {
            saveFunctionBlacklist();
        }
    }

    /**
     * Decompiles each function and checks every global symbol it references (each global once)
     * for a name that is not a valid C identifier.
     */
    void sanitizeGlobalSymbolsPass(DecompInterface decomp, List<Function> functions) {
        Hashtable<String, HighSymbol> globalSymbols = new Hashtable<>();
        for (Function function : functions) {
            println("Processing global symbols for " + function.getName());
            DecompileResults decompRes = decomp.decompileFunction(function, TIMEOUT, monitor);
            HighFunction highFunction = decompRes.getHighFunction();
            if (highFunction == null) {
                println("Failed to decompile " + function.getName() + ": " + decompRes.getErrorMessage());
                continue;
            }
            Iterator<HighSymbol> symIt = highFunction.getGlobalSymbolMap().getSymbols();
            while (symIt.hasNext()) {
                HighSymbol gsym = symIt.next();
                // putIfAbsent returns non-null when another function already referenced this global.
                if (globalSymbols.putIfAbsent(gsym.getName(), gsym) != null) {
                    continue;
                }
                println("GLOBAL: " + gsym.getName());
                sanitizeGlobalSymbol(gsym, function);
            }
        }
    }

    /** Adds a sanitized label for (or just reports) one global whose name is not a C identifier. */
    private void sanitizeGlobalSymbol(HighSymbol gsym, Function function) {
        String sanitizedName = sanitizeFunctionName(gsym.getName());
        if (sanitizedName.equals(gsym.getName())) {
            return;
        }
        if (!AUTO_RENAME_SYMBOLS) {
            println("Invalid global symbol name: " + gsym.getName() + " - " + function.getName());
            return;
        }
        Symbol symbol = gsym.getSymbol();
        VariableStorage storage = gsym.getStorage();
        Address addr = storage.getMinAddress();
        if (addr == null && symbol != null) {
            // Same fallback as decompileAll: storage sometimes carries no address.
            addr = symbol.getAddress();
        }
        println("Renaming global symbol: " + gsym.getName() + " (" + addr
                + ") -> " + sanitizedName);
        if (symbol == null) {
            println("Symbol is null: " + gsym.getName() + " - " + function.getName());
            return;
        }
        // NOTE(review): AddLabelCmd adds a label at addr; it does not rename the existing symbol.
        AddLabelCmd cmd = new AddLabelCmd(addr, sanitizedName, symbol.getParentNamespace(),
                SourceType.USER_DEFINED);
        if (cmd.applyTo(currentProgram)) {
            println("Renamed global symbol: " + gsym.getName() + " -> " + sanitizedName);
        } else {
            println("Error renaming symbol: " + cmd.getStatusMsg());
        }
    }

    /**
     * Escapes {@code str} for use inside a C string literal. Backslash, double quote, newline,
     * carriage return and tab get their usual escapes. Other control characters and 0x7F-0xFF
     * become three-digit octal escapes. Chars in 0x80-0xFF are assumed to stand for raw bytes
     * (as produced by readCString), so the generated literal keeps the original byte values.
     */
    String escapeCString(String str) {
        StringBuilder sb = new StringBuilder(str.length());
        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);
            switch (c) {
                case '\\':
                    sb.append("\\\\");
                    break;
                case '"':
                    sb.append("\\\"");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                default:
                    if (c < 0x20 || (c >= 0x7F && c <= 0xFF)) {
                        // Octal escapes stop after three digits; unlike \x they cannot absorb
                        // following hex-digit characters.
                        sb.append(String.format("\\%03o", (int) c));
                    } else {
                        sb.append(c);
                    }
                    break;
            }
        }
        return sb.toString();
    }

    /**
     * Reads a NUL-terminated string from program memory at {@code addr}. Like a C buffer,
     * {@code maxLen} counts the terminator, so at most {@code maxLen - 1} characters are returned.
     * Each byte maps to the char with the same value (ISO-8859-1).
     *
     * @throws Exception if memory at the address cannot be read
     */
    String readCString(Address addr, int maxLen) throws Exception {
        StringBuilder sb = new StringBuilder();
        for (int ofs = 0; ofs < maxLen - 1; ofs++) {
            byte b = currentProgram.getMemory().getByte(addr.add(ofs));
            if (b == 0) {
                break;
            }
            // Mask to avoid sign extension: (char) of a negative byte would yield \uFF80-\uFFFF.
            sb.append((char) (b & 0xFF));
        }
        if (sb.length() > 0) {
            println("STR \"" + sb.toString() + "\"");
        }
        return sb.toString();
    }

    /**
     * Writes the decompiled C of {@code function} to {@code gh_auto/<name>.cxx}, unless a
     * hand-fixed copy exists in {@code gh_fix}. Global symbols it references are added to
     * {@code outGlobalSymbols} (first occurrence per name wins).
     */
    void decompileFunction(Hashtable<String, HighSymbol> outGlobalSymbols, DecompInterface decomp, Function function)
            throws Exception {
        String fileName = sanitizeFunctionName(function.getName()) + ".cxx";
        File fixedFile = new File(dirDecompFix, fileName);
        if (fixedFile.exists()) {
            println("Func " + function.getName() + " skipped (gh_fix)");
            return;
        }
        File autoFile = new File(dirDecompAuto, fileName);
        if (autoFile.exists()) {
            autoFile.delete();
        }
        println("Processing " + function.getName() + " => " + autoFile.toString());
        DecompileResults decompRes = decomp.decompileFunction(function, TIMEOUT, monitor);
        DecompiledFunction decompiled = decompRes.getDecompiledFunction();
        HighFunction highFunction = decompRes.getHighFunction();
        if (decompiled == null || highFunction == null) {
            println("Failed to decompile " + function.getName() + ": " + decompRes.getErrorMessage());
            return;
        }
        try (PrintWriter writer = new PrintWriter(autoFile, StandardCharsets.UTF_8)) {
            writer.println("// AUTO-GENERATED FILE, MOVE TO 'gh_fix' FOLDER PREVENT OVERWRITING!!!!! ");
            writer.println("// " + function.getEntryPoint());
            writer.println();
            writer.println("#include <gh_auto_shared.h>");
            writer.println("#include \"../gh_global.h\"");
            writer.println();
            writer.println(decompiled.getC());
        }
        // Collect referenced global symbols
        Iterator<HighSymbol> symIt = highFunction.getGlobalSymbolMap().getSymbols();
        while (symIt.hasNext()) {
            HighSymbol gsym = symIt.next();
            outGlobalSymbols.putIfAbsent(gsym.getName(), gsym);
        }
    }

    /**
     * Decompiles all {@code functions}, then writes {@code gh_global.h} (extern declarations) and
     * {@code gh_global.cxx} (definitions) for every referenced global, sorted by name.
     */
    void decompileAll(DecompInterface decomp, List<Function> functions) throws Exception {
        Hashtable<String, HighSymbol> globalSymbols = new Hashtable<>();
        for (Function function : functions) {
            decompileFunction(globalSymbols, decomp, function);
        }
        File globalSymbolsListH = new File(outputDir, "gh_global.h");
        File globalSymbolsListC = new File(outputDir, "gh_global.cxx");
        try (PrintWriter hwriter = new PrintWriter(globalSymbolsListH, StandardCharsets.UTF_8);
                PrintWriter cwriter = new PrintWriter(globalSymbolsListC, StandardCharsets.UTF_8)) {
            hwriter.println("// AUTO-GENERATED FILE ");
            hwriter.println("#include <gh_auto_shared.h>");
            cwriter.println("// AUTO-GENERATED FILE ");
            cwriter.println("#include <gh_auto_shared.h>");
            // Sorted by name so regenerated files diff cleanly.
            for (HighSymbol highSym : new TreeMap<>(globalSymbols).values()) {
                writeGlobalSymbol(highSym, hwriter, cwriter);
            }
        }
    }

    /** Emits the definition (to cwriter) and extern declaration (to hwriter) of one global. */
    private void writeGlobalSymbol(HighSymbol highSym, PrintWriter hwriter, PrintWriter cwriter) {
        String name = highSym.getName();
        if (!sanitizeFunctionName(name).equals(name)) {
            println("Invalid global symbol name: " + name + " - "
                    + highSym.getHighFunction().getFunction().getName());
            return;
        }
        DataType dt = highSym.getDataType();
        String dataType = dt.getDisplayName();
        Symbol symbol = highSym.getSymbol();
        VariableStorage storage = highSym.getStorage();
        Address addr = storage.getMinAddress();
        int symSize = highSym.getSize();
        if (addr == null && symbol != null) {
            // Storage sometimes has no address (it then serializes as <undefined>); fall back to
            // the symbol. NOTE(review): Symbol.getAddress() was observed to be wrong (very small,
            // e.g. 00000056) when storage does have an address, so it is only a fallback.
            addr = symbol.getAddress();
        }
        println("Symbol: " + symbol + " Addr: " + addr + " Size:" + symSize + " " + storage.getSerializationString());
        if (addr == null) {
            println("Skipping global symbol without address: " + name + " - "
                    + highSym.getHighFunction().getFunction().getName());
            return;
        }
        try {
            String initBlk = " = ";
            if (dt instanceof AbstractStringDataType) {
                dataType = "const char*";
                initBlk += "\"" + escapeCString(readCString(addr, 2048)) + "\"";
            } else if (dt instanceof PointerDataType) {
                DataType baseType = ((PointerDataType) dt).getDataType();
                // A generic pointer has no referenced type.
                dataType = (baseType != null ? baseType.getDisplayName() : "void") + "*";
                initBlk += "gh_ptr(0x" + addr + ")";
            } else {
                initBlk = " = 0";
            }
            cwriter.println(dataType + " " + name + initBlk + "; // " + addr);
        } catch (Exception e) {
            println("Error processing global symbol: " + e);
            println("Symbol: " + name + " - " + addr + " - "
                    + highSym.getHighFunction().getFunction().getName());
        }
        hwriter.println("extern " + dataType + " " + name + "; // " + addr);
    }

    @Override
    public void run() throws Exception {
        if (currentProgram == null) {
            return;
        }
        AddressFactory addressFactory = currentProgram.getAddressFactory();
        staticMemoryBlockStart = addressFactory.getAddress("005b6400");
        staticMemoryBlockEnd = addressFactory.getAddress("00843fff");

        // The output directory lives next to the directory containing this script.
        rootDir = new File(sourceFile.getAbsolutePath()).getParentFile().getParentFile();
        outputDir = new File(rootDir, OUTPUT_DIR);
        if (!outputDir.exists()) {
            throw new Exception("Output directory does not exist: " + outputDir.getCanonicalPath());
        }
        dirDecompAuto = new File(outputDir, "gh_auto");
        dirDecompFix = new File(outputDir, "gh_fix");
        // Generated files are written here; create it so a first run does not fail.
        if (!dirDecompAuto.isDirectory() && !dirDecompAuto.mkdirs()) {
            throw new IllegalStateException("Could not create " + dirDecompAuto.getCanonicalPath());
        }
        println("Output path: " + outputDir.getCanonicalPath());

        DecompInterface decomp = new DecompInterface();
        try {
            if (!decomp.openProgram(currentProgram)) {
                throw new IllegalStateException("Decompiler failed to open program: " + decomp.getLastMessage());
            }
            buildFunctionBlacklist(decomp);

            List<Function> functions = new ArrayList<>();
            for (Function function : currentProgram.getFunctionManager().getFunctions(true)) {
                if (shouldDecompileFunction(function)) {
                    functions.add(function);
                }
            }

            int mode = MODE_DECOMPILE;
            if (mode == MODE_SANITIZE_SYMBOLS) {
                sanitizeGlobalSymbolsPass(decomp, functions);
            } else if (mode == MODE_DECOMPILE) {
                decompileAll(decomp, functions);
            }
        } finally {
            // Release the decompiler opened above, even when a pass throws.
            decomp.dispose();
        }
    }

    /** Replaces every character that is not valid in a C identifier with '_'. */
    String sanitizeFunctionName(String name) {
        return NON_IDENTIFIER_CHARS.matcher(name).replaceAll("_");
    }
}