reman3/scripts/re3lib/FunctionDatabase.java

248 lines
8.6 KiB
Java

package re3lib;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import ghidra.app.script.GhidraScript;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Function;
/**
 * Persistent index of the functions that have been written out to source files.
 *
 * <p>Each {@link Entry} records a function's address, name, output file, {@link Type} and the
 * functions it depends on. The index is stored with Java serialization as {@code functions.dat}
 * in {@code RecompileConfig.INSTANCE.outputDir}.
 */
public class FunctionDatabase {
  /**
   * Classification of an entry. When several entries record the same name, a {@code Fix} entry
   * takes precedence over {@code Auto} entries; {@code Stub} entries are never regenerated by
   * {@link #applyDefaultFilters(boolean)}.
   */
  public enum Type {
    Auto,
    Fix,
    Stub
  }

  /**
   * A function referenced by an entry, identified by its address and the name it is expected to
   * have at that address.
   *
   * <p>The address is serialized as its string form and resolved again on load through the
   * address factory of {@code RecompileConfig.INSTANCE.script}'s current program.
   *
   * <p>NOTE(review): this is an inner (non-static) class; it is kept that way because callers may
   * construct it through an enclosing instance ({@code db.new Dependency(...)}).
   */
  public class Dependency implements java.io.Serializable {
    private static final long serialVersionUID = 1L;
    public Address address;
    public String name;

    public Dependency(Address address, String name) {
      this.address = address;
      this.name = name;
    }

    private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException {
      out.writeObject(address != null ? address.toString() : null);
      out.writeObject(name);
    }

    private void readObject(java.io.ObjectInputStream in)
        throws java.io.IOException, ClassNotFoundException {
      String addressString = (String) in.readObject();
      if (addressString != null) {
        address =
            RecompileConfig.INSTANCE.script.getCurrentProgram().getAddressFactory()
                .getAddress(addressString);
      }
      name = (String) in.readObject();
    }
  }

  /**
   * One exported function: its address, recorded name, output file, type and dependencies.
   *
   * <p>Serialized field by field; the address and file are written as strings and rebuilt on
   * load. The stream layout (address, name, file, type, dependencies) must stay stable so that
   * existing {@code functions.dat} files remain readable.
   *
   * <p>NOTE(review): inner (non-static) class kept for compatibility with callers that may create
   * entries through an enclosing instance ({@code db.new Entry()}).
   */
  public class Entry implements java.io.Serializable {
    private static final long serialVersionUID = 1L;
    public Address address;
    public String name;
    public File file;
    public Type type;
    public List<Dependency> dependencies = new ArrayList<>();

    private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException {
      out.writeObject(address != null ? address.toString() : null);
      out.writeObject(name);
      out.writeObject(file != null ? file.toString() : null);
      out.writeObject(type);
      out.writeObject(dependencies);
    }

    private void readObject(java.io.ObjectInputStream in)
        throws java.io.IOException, ClassNotFoundException {
      String addressString = (String) in.readObject();
      if (addressString != null) {
        address =
            RecompileConfig.INSTANCE.script.getCurrentProgram().getAddressFactory()
                .getAddress(addressString);
      }
      name = (String) in.readObject();
      String fileString = (String) in.readObject();
      if (fileString != null) {
        file = new File(fileString);
      }
      type = (Type) in.readObject();
      @SuppressWarnings("unchecked") // writeObject above stores a List<Dependency> here
      List<Dependency> loadedDependencies = (List<Dependency>) in.readObject();
      // Field initializers do not run during deserialization, so never leave this null.
      dependencies = loadedDependencies != null ? loadedDependencies : new ArrayList<>();
    }
  }

  public List<Entry> entries = new ArrayList<>();
  private final File file;
  private final transient GhidraScript script;

  /**
   * Creates a database bound to {@code functions.dat} in the configured output directory. Nothing
   * is read until {@link #load()} is called.
   *
   * @param script the running script, used for program queries and console output
   */
  public FunctionDatabase(GhidraScript script) {
    this.script = script;
    file = new File(RecompileConfig.INSTANCE.outputDir, "functions.dat");
  }

  /**
   * Replaces {@link #entries} with the contents of the database file, if it exists.
   *
   * <p>Read errors are reported on the script console and leave {@link #entries} unchanged. Note
   * that a later {@link #save()} will then overwrite the file with that in-memory list.
   */
  public void load() throws Exception {
    if (!file.exists()) {
      return;
    }
    try (java.io.ObjectInputStream ois =
        new java.io.ObjectInputStream(new java.io.FileInputStream(file))) {
      @SuppressWarnings("unchecked") // assumes the file was produced by save()
      List<Entry> loaded = (List<Entry>) ois.readObject();
      entries = loaded != null ? loaded : new ArrayList<>();
      script.println("Loaded " + entries.size() + " function entries from " + file);
    } catch (java.io.IOException | ClassNotFoundException e) {
      script.println("Error loading function database: " + e.getMessage());
    }
  }

  /**
   * Writes {@link #entries} to the database file, replacing its contents. Write errors are
   * reported on the script console rather than thrown.
   */
  public void save() throws Exception {
    try (java.io.ObjectOutputStream oos =
        new java.io.ObjectOutputStream(new java.io.FileOutputStream(file))) {
      oos.writeObject(entries);
      script.println("Saved " + entries.size() + " function entries to " + file);
    } catch (java.io.IOException e) {
      script.println("Error saving function database: " + e.getMessage());
    }
  }

  /** Appends an entry to the in-memory list; call {@link #save()} to persist it. */
  public void add(Entry entry) {
    entries.add(entry);
  }

  /**
   * Reconciles the database with the current program and regenerates stale output.
   *
   * <ul>
   *   <li>Entries whose function no longer exists, or is rejected by
   *       {@code FunctionDumper.isValidFunction}, are removed together with their file.
   *   <li>Entries whose recorded name differs from the current symbol name are removed as
   *       duplicates when another file already exports the recorded name; otherwise they are
   *       renamed and regenerated.
   *   <li>Entries with a missing, invalid or renamed dependency are regenerated.
   * </ul>
   *
   * <p>Stub entries are never regenerated. Afterwards, if anything was marked as changed, the CMake
   * timestamp is touched and globals and types are dumped again.
   *
   * @param rebuildAllGlobals when true, the global manifest is removed first and every non-stub
   *     entry is regenerated
   * @throws IllegalStateException if no symbol name was collected for a function that exists
   * @throws Exception propagated from the dumpers
   */
  public void applyDefaultFilters(boolean rebuildAllGlobals) throws Exception {
    GlobalDumper globalDumper = new GlobalDumper(script);
    FunctionDumper dumper = new FunctionDumper(script, globalDumper);
    boolean madeAnyChanges = false;
    if (rebuildAllGlobals) {
      globalDumper.removeGlobalManifest();
      // The manifest was just removed, so it must be written again even if no entry changes.
      madeAnyChanges = true;
    }

    // Current symbol name per entry address, and which file exports each recorded name.
    Map<Address, String> symbolNames = new HashMap<>();
    Map<String, File> exportedFunctionNames = new HashMap<>();
    collectSymbolNames(symbolNames, exportedFunctionNames);
    script.println("Collected " + symbolNames.size() + " symbol names");

    boolean dryMode = false;
    HashSet<Function> functionsToRegenerate = new HashSet<>();
    Iterator<Entry> iterator = entries.iterator();
    while (iterator.hasNext()) {
      Entry entry = iterator.next();
      Function function = script.getFunctionAt(entry.address);
      boolean pendingDelete = false;
      boolean pendingRegenerate = rebuildAllGlobals;

      if (function == null || !dumper.isValidFunction(function)) {
        // Missing, CRT or otherwise blacklisted function: drop the entry and its file.
        if (entry.file != null && entry.file.exists()) {
          script.println("Removed file: " + entry.file.getAbsolutePath());
        }
        script.println("Removed invalid function entry: " + entry.name + " at " + entry.address);
        pendingDelete = true;
        function = null;
      } else {
        // Check if the symbol name matches the name recorded for the file.
        String actualSymbolName = symbolNames.get(entry.address);
        if (actualSymbolName == null) {
          throw new IllegalStateException(
              "Symbol name not found for function at " + entry.address + " in file "
                  + (entry.file != null ? entry.file.getAbsolutePath() : null));
        }
        if (!actualSymbolName.equals(entry.name)) {
          File fnExportedFile = exportedFunctionNames.get(entry.name);
          if (fnExportedFile != null && !fnExportedFile.equals(entry.file)) {
            // Already exported elsewhere, so remove this file.
            script.println("Removing duplicate function: " + entry.name + " at " + entry.address
                + " overridden by " + fnExportedFile);
            pendingDelete = true;
          } else {
            // Regenerate this function under its current name.
            script.println("Symbol name mismatch for function at " + entry.address + ": "
                + "File name: " + entry.name + ", Actual symbol: " + actualSymbolName);
            entry.name = actualSymbolName;
            pendingRegenerate = true;
          }
        }
        if (hasStaleDependencies(entry, dumper)) {
          pendingRegenerate = true;
        }
        // NOTE(review): preserved from the original -- every entry with a valid function marks the
        // run as changed, so globals and types are re-dumped whenever at least one valid entry
        // exists. Confirm whether only real changes should trigger the dump.
        madeAnyChanges = true;
      }

      if (pendingDelete) {
        iterator.remove();
        if (!dryMode) {
          if (entry.file != null && entry.file.exists() && !entry.file.delete()) {
            script.println("Failed to delete file: " + entry.file.getAbsolutePath());
          }
          madeAnyChanges = true;
        }
      } else if (pendingRegenerate && function != null && entry.type != Type.Stub) {
        if (!dryMode) {
          functionsToRegenerate.add(function);
          madeAnyChanges = true;
        }
      }
    }

    for (Function function : functionsToRegenerate) {
      script.println("Regenerating function: " + function.getName() + " at "
          + function.getEntryPoint());
      dumper.dump(function);
    }

    if (madeAnyChanges) {
      // Update CMake timestamp
      RecompileConfig.INSTANCE.touchCMakeTimestamp();
      globalDumper.dumpGlobals();
      globalDumper.saveGlobalManifest();
      TypeDumper typeDumper = new TypeDumper(script);
      typeDumper.run();
    }
  }

  /**
   * Fills {@code symbolNames} with the current symbol name of every entry that has a function, and
   * {@code exportedFunctionNames} with the file exporting each recorded name. The first Auto
   * entry for a name is kept unless a Fix entry for that name exists, which always wins.
   */
  private void collectSymbolNames(
      Map<Address, String> symbolNames, Map<String, File> exportedFunctionNames) {
    for (Entry entry : entries) {
      Function function = script.getFunctionAt(entry.address);
      if (function == null) {
        continue;
      }
      symbolNames.put(entry.address, function.getName());
      if (entry.type == Type.Fix) {
        // put (not replace): a Fix entry must win even if it precedes the Auto entry.
        exportedFunctionNames.put(entry.name, entry.file);
      } else if (entry.type == Type.Auto && !exportedFunctionNames.containsKey(entry.name)) {
        exportedFunctionNames.put(entry.name, entry.file);
      }
    }
  }

  /**
   * Reports every dependency of {@code entry} that is missing, rejected by the dumper, or renamed,
   * and returns whether any such dependency was found. All dependencies are checked so that each
   * problem is logged.
   */
  private boolean hasStaleDependencies(Entry entry, FunctionDumper dumper) {
    if (entry.dependencies == null) {
      return false;
    }
    boolean stale = false;
    for (Dependency dependency : entry.dependencies) {
      Function depFunction =
          dependency.address != null ? script.getFunctionAt(dependency.address) : null;
      if (depFunction == null) {
        script.println(
            "Dependency not found: " + dependency.name + " at " + dependency.address + " in "
                + entry.file);
        stale = true;
      } else if (!dumper.isValidFunction(depFunction)
          || !depFunction.getName().equals(dependency.name)) {
        script.println("Invalid dependency: " + dependency.name + " at " + dependency.address
            + " in " + entry.file + " should be " + depFunction.getName());
        stale = true;
      }
    }
    return stale;
  }
}