// File: reman3/java/ghidra/re3lib/FunctionDatabase.java
// (source-viewer header: 488 lines, 16 KiB, Java)

package re3lib;
import java.io.File;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.sqlite.JDBC;
import ghidra.app.script.GhidraScript;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Function;
/**
 * SQLite-backed index of the source files generated for functions of the current program.
 *
 * <p>Rows live in two tables, {@code Functions} and {@code Imports}, both keyed by
 * {@code (name, filepath)}. File paths are stored relative to {@code RemanConfig.INSTANCE.outputDir}
 * using forward slashes; addresses are stored as {@link Address#toString()} text.
 *
 * <p>The connection is opened lazily by every public query/update method. No synchronization is
 * performed anywhere in this class (NOTE(review): treat instances as single-threaded).
 */
public class FunctionDatabase {

  /**
   * Category of a database entry; only {@link #getValue()} is persisted (column {@code type}).
   * In {@link FunctionDatabase#applyDefaultFilters}, Fix entries take precedence over Auto entries
   * with the same name, and Stub entries are never regenerated.
   */
  public enum Type {
    Auto(0),
    Fix(1),
    Stub(2),
    Ref(3);

    private final int value;

    Type(int value) {
      this.value = value;
    }

    /** Returns the integer stored in the database for this type. */
    public int getValue() {
      return value;
    }

    /**
     * Maps a persisted integer back to its constant.
     *
     * @throws IllegalArgumentException if no constant has that value
     */
    public static Type fromValue(int value) {
      for (Type type : Type.values()) {
        if (type.value == value) {
          return type;
        }
      }
      throw new IllegalArgumentException("Unknown type value: " + value);
    }
  }

  /**
   * One database row resolved against the current program and the output directory.
   *
   * <p>This is deliberately kept a non-static inner class so existing callers that write
   * {@code database.new Entry(...)} keep compiling.
   */
  public class Entry {
    /** Address of the function in the current program (looked up with getFunctionAt). */
    public Address address;
    /** Function name as recorded in the database. */
    public String name;
    /** Generated source file, resolved against {@code RemanConfig.INSTANCE.outputDir}. */
    public File file;
    /** Entry category. */
    public Type type;

    public Entry(Address address, String name, File file, Type type) {
      this.address = address;
      this.name = name;
      this.file = file;
      this.type = type;
    }
  }

  private final File dbFile;
  private final transient GhidraScript script;
  private Connection connection;

  // Cached prepared statements, (re)created by every successful connect().
  private PreparedStatement findByNameFunctions;
  private PreparedStatement findByNameImports;
  private PreparedStatement findByAddressFunctions;
  private PreparedStatement findByAddressImports;
  private PreparedStatement insertOrReplaceFunctions;
  private PreparedStatement deleteByFilepathFunctions;
  private PreparedStatement deleteByFilepathImports;
  private PreparedStatement loadAllFunctions;
  private PreparedStatement loadAllImports;

  /**
   * Creates a database handle for {@code RemanConfig.INSTANCE.databasePath}; no connection is
   * opened yet.
   *
   * @param script the running script, used for output and for resolving addresses
   */
  public FunctionDatabase(GhidraScript script) {
    this.script = script;
    this.dbFile = RemanConfig.INSTANCE.databasePath;
    try {
      DriverManager.registerDriver(new JDBC());
    } catch (SQLException e) {
      script.printerr("Error registering JDBC driver: " + e.getMessage());
    }
  }

  /**
   * Opens the database (creating parent directories and tables as needed) and prepares the cached
   * statements. Does nothing if already connected. On failure the partially opened connection is
   * closed and discarded, so the next call retries from scratch.
   *
   * @throws Exception wrapping the {@link SQLException} if connecting or initialising fails
   */
  public void connect() throws Exception {
    if (connection != null && !connection.isClosed()) {
      return; // Already connected
    }
    if (!dbFile.exists()) {
      script.println("Database file not found: " + dbFile);
      // Create parent directories if they don't exist; a bare file name has no parent.
      File parentDir = dbFile.getParentFile();
      if (parentDir != null) {
        parentDir.mkdirs();
      }
    }
    try {
      connection = DriverManager.getConnection("jdbc:sqlite:" + dbFile.getAbsolutePath());
      createTablesIfNotExist();
      prepareCachedStatements();
      script.println("Connected to database: " + dbFile);
    } catch (SQLException e) {
      script.println("Error connecting to database: " + e.getMessage());
      // Without this, ensureConnection() would see an open connection with null statements.
      discardConnection(e);
      throw new Exception("Failed to connect to database", e);
    }
  }

  /**
   * Closes the cached statements and the connection, if open. The connection is closed even if
   * closing a statement fails.
   *
   * @throws Exception wrapping the first {@link SQLException} encountered
   */
  public void disconnect() throws Exception {
    if (connection == null || connection.isClosed()) {
      return;
    }
    try {
      try {
        closePreparedStatements();
      } finally {
        connection.close();
      }
      script.println("Disconnected from database");
    } catch (SQLException e) {
      script.println("Error disconnecting from database: " + e.getMessage());
      throw new Exception("Failed to disconnect from database", e);
    }
  }

  /** Connects if there is no open connection. */
  private void ensureConnection() throws Exception {
    if (connection == null || connection.isClosed()) {
      connect();
    }
  }

  /** Best-effort cleanup after a failed connect; secondary failures are attached to {@code cause}. */
  private void discardConnection(SQLException cause) {
    try {
      closePreparedStatements();
    } catch (SQLException e) {
      cause.addSuppressed(e);
    }
    if (connection != null) {
      try {
        connection.close();
      } catch (SQLException e) {
        cause.addSuppressed(e);
      }
      connection = null;
    }
  }

  /** Prepares every cached statement on the current connection. */
  private void prepareCachedStatements() throws SQLException {
    // Find by name
    findByNameFunctions = connection.prepareStatement(
        "SELECT filepath, name, address, type FROM Functions WHERE name = ?");
    findByNameImports = connection.prepareStatement(
        "SELECT filepath, name, address, type FROM Imports WHERE name = ?");
    // Find by address
    findByAddressFunctions = connection.prepareStatement(
        "SELECT filepath, name, address, type FROM Functions WHERE address = ?");
    findByAddressImports = connection.prepareStatement(
        "SELECT filepath, name, address, type FROM Imports WHERE address = ?");
    // Insert or replace
    insertOrReplaceFunctions = connection.prepareStatement(
        "INSERT OR REPLACE INTO Functions (filepath, name, address, type) VALUES (?, ?, ?, ?)");
    // Delete by filepath
    deleteByFilepathFunctions = connection.prepareStatement(
        "DELETE FROM Functions WHERE filepath = ?");
    deleteByFilepathImports = connection.prepareStatement(
        "DELETE FROM Imports WHERE filepath = ?");
    // Load all entries
    loadAllFunctions = connection.prepareStatement(
        "SELECT filepath, name, address, type FROM Functions");
    loadAllImports = connection.prepareStatement(
        "SELECT filepath, name, address, type FROM Imports");
  }

  /**
   * Closes every non-null cached statement, attempting all of them even if some fail.
   *
   * @throws SQLException the first close failure, with later ones attached as suppressed
   */
  private void closePreparedStatements() throws SQLException {
    PreparedStatement[] statements = {
      findByNameFunctions, findByNameImports,
      findByAddressFunctions, findByAddressImports,
      insertOrReplaceFunctions,
      deleteByFilepathFunctions, deleteByFilepathImports,
      loadAllFunctions, loadAllImports
    };
    SQLException firstError = null;
    for (PreparedStatement statement : statements) {
      if (statement == null) {
        continue;
      }
      try {
        statement.close();
      } catch (SQLException e) {
        if (firstError == null) {
          firstError = e;
        } else {
          firstError.addSuppressed(e);
        }
      }
    }
    if (firstError != null) {
      throw firstError;
    }
  }

  /**
   * Loads every row of the {@code Functions} table.
   *
   * @return a new mutable list; rows with an empty or unparseable address are skipped
   * @throws Exception wrapping the {@link SQLException} on database errors
   */
  public List<Entry> loadAllEntries() throws Exception {
    ensureConnection();
    List<Entry> entries = new ArrayList<>();
    try {
      // NOTE(review): only Functions is read; loadAllImports is prepared but unused here.
      collectEntries(loadAllFunctions, entries);
      script.println("Loaded " + entries.size() + " function entries from database");
      return entries;
    } catch (SQLException e) {
      script.println("Error loading entries: " + e.getMessage());
      throw new Exception("Failed to load entries", e);
    }
  }

  /** Executes {@code query} (parameters already bound) and appends every resolvable row to {@code into}. */
  private void collectEntries(PreparedStatement query, List<Entry> into) throws SQLException {
    try (ResultSet rs = query.executeQuery()) {
      while (rs.next()) {
        Entry entry = createEntryFromResultSet(rs);
        if (entry != null) {
          into.add(entry);
        }
      }
    }
  }

  /**
   * Builds an {@link Entry} from the current row.
   *
   * @return the entry, or {@code null} when the address column is empty or cannot be parsed by the
   *     current program's address factory
   * @throws IllegalArgumentException if the {@code type} column holds an unknown value
   */
  private Entry createEntryFromResultSet(ResultSet rs) throws SQLException {
    String filepath = rs.getString("filepath");
    String name = rs.getString("name");
    String addressStr = rs.getString("address");
    int typeValue = rs.getInt("type");
    if (addressStr == null || addressStr.isEmpty()) {
      return null;
    }
    Address address = script.getCurrentProgram().getAddressFactory().getAddress(addressStr);
    if (address == null) {
      // getAddress returns null for text no address space accepts; such an entry would NPE later.
      script.printerr("Skipping entry " + name + ": unparseable address " + addressStr);
      return null;
    }
    File file = new File(RemanConfig.INSTANCE.outputDir, filepath);
    return new Entry(address, name, file, Type.fromValue(typeValue));
  }

  /** Creates the {@code Functions} and {@code Imports} tables if they are missing. */
  private void createTablesIfNotExist() throws SQLException {
    String createFunctions = """
        CREATE TABLE IF NOT EXISTS Functions (
          filepath TEXT,
          name TEXT,
          address TEXT,
          type INTEGER,
          PRIMARY KEY (name, filepath)
        )""";
    String createImports = """
        CREATE TABLE IF NOT EXISTS Imports (
          filepath TEXT,
          name TEXT,
          address TEXT,
          type INTEGER,
          PRIMARY KEY (name, filepath)
        )""";
    for (String ddl : new String[] {createFunctions, createImports}) {
      try (PreparedStatement statement = connection.prepareStatement(ddl)) {
        statement.executeUpdate();
      }
    }
  }

  /**
   * Finds entries with the given name in {@code Functions}, then in {@code Imports}.
   *
   * @throws Exception wrapping the {@link SQLException} on database errors
   */
  public List<Entry> findEntriesByName(String name) throws Exception {
    ensureConnection();
    List<Entry> results = new ArrayList<>();
    try {
      findByNameFunctions.setString(1, name);
      collectEntries(findByNameFunctions, results);
      findByNameImports.setString(1, name);
      collectEntries(findByNameImports, results);
      return results;
    } catch (SQLException e) {
      script.println("Error finding entries by name: " + e.getMessage());
      throw new Exception("Failed to find entries by name", e);
    }
  }

  /**
   * Finds entries whose stored address text equals {@code address.toString()}, searching
   * {@code Functions} then {@code Imports}.
   *
   * @throws Exception wrapping the {@link SQLException} on database errors
   */
  public List<Entry> findEntriesByAddress(Address address) throws Exception {
    ensureConnection();
    List<Entry> results = new ArrayList<>();
    String addressStr = address.toString();
    try {
      findByAddressFunctions.setString(1, addressStr);
      collectEntries(findByAddressFunctions, results);
      findByAddressImports.setString(1, addressStr);
      collectEntries(findByAddressImports, results);
      return results;
    } catch (SQLException e) {
      script.println("Error finding entries by address: " + e.getMessage());
      throw new Exception("Failed to find entries by address", e);
    }
  }

  /**
   * Converts {@code file} to a forward-slash path relative to {@code RemanConfig.INSTANCE.outputDir}.
   * Both paths are made absolute and normalized first, because {@code Path.relativize} throws when
   * one side is absolute and the other relative.
   */
  private String toRelativePath(File file) {
    return new File(RemanConfig.INSTANCE.outputDir).toPath().toAbsolutePath().normalize()
        .relativize(file.toPath().toAbsolutePath().normalize())
        .toString()
        .replace('\\', '/');
  }

  /**
   * Inserts the entry into {@code Functions}, replacing any row with the same
   * {@code (name, filepath)} key.
   *
   * @throws Exception wrapping the {@link SQLException} on database errors
   */
  public void addEntryAt(Entry entry) throws Exception {
    ensureConnection();
    String relativePath = toRelativePath(entry.file);
    try {
      insertOrReplaceFunctions.setString(1, relativePath);
      insertOrReplaceFunctions.setString(2, entry.name);
      insertOrReplaceFunctions.setString(3, entry.address.toString());
      insertOrReplaceFunctions.setInt(4, entry.type.getValue());
      insertOrReplaceFunctions.executeUpdate();
      script.println("Added/updated entry: " + entry.name + " at " + entry.address + " in " + relativePath);
    } catch (SQLException e) {
      script.println("Error adding entry: " + e.getMessage());
      throw new Exception("Failed to add entry", e);
    }
  }

  /**
   * Deletes every {@code Functions} and {@code Imports} row stored for the given file.
   *
   * @param filePath path of the generated file (absolute, or relative to the working directory)
   * @throws Exception wrapping the {@link SQLException} on database errors
   */
  public void removeEntryAt(String filePath) throws Exception {
    ensureConnection();
    String relativePath = toRelativePath(new File(filePath));
    try {
      deleteByFilepathFunctions.setString(1, relativePath);
      int deletedCount = deleteByFilepathFunctions.executeUpdate();
      deleteByFilepathImports.setString(1, relativePath);
      deletedCount += deleteByFilepathImports.executeUpdate();
      script.println("Removed " + deletedCount + " entries for file: " + relativePath);
    } catch (SQLException e) {
      script.println("Error removing entries: " + e.getMessage());
      throw new Exception("Failed to remove entries", e);
    }
  }

  /** Adds the entry directly to the database; same as {@link #addEntryAt(Entry)}. */
  public void add(Entry entry) throws Exception {
    addEntryAt(entry);
  }

  /**
   * Reconciles the database and generated files with the current program.
   *
   * <p>For every entry of the {@code Functions} table:
   * <ul>
   *   <li>If no function exists at its address, or {@code FunctionDumper.isValidFunction} rejects
   *       it, the file is deleted and all database rows for that file are removed.
   *   <li>If the program's symbol name differs from the recorded name and another file already
   *       exports the recorded name, this entry is a duplicate and its file is deleted.
   *   <li>Otherwise, a name mismatch updates the recorded name and queues a regeneration.
   * </ul>
   * Every non-stub entry is also regenerated when {@code rebuildAllGlobals} is set. If anything
   * changed, the CMake timestamp is touched and globals and types are dumped again.
   *
   * @param rebuildAllGlobals remove the global manifest first and regenerate every non-stub entry
   * @throws Exception on database errors, or if a live function's symbol name cannot be found
   */
  public void applyDefaultFilters(boolean rebuildAllGlobals) throws Exception {
    GlobalDumper globalDumper = new GlobalDumper(script);
    FunctionDumper dumper = new FunctionDumper(script, globalDumper);
    if (rebuildAllGlobals) {
      globalDumper.removeGlobalManifest();
    }
    // When true, changes are only reported: no files are deleted or regenerated, no rows written.
    boolean dryMode = false;
    boolean madeAnyChanges = false;

    List<Entry> entries = loadAllEntries();
    Map<Address, String> symbolNames = new HashMap<>();
    Map<String, File> exportedFunctionNames = new HashMap<>();
    collectSymbolNames(entries, symbolNames, exportedFunctionNames);
    script.println("Collected " + symbolNames.size() + " symbol names");

    HashSet<Function> functionsToRegenerate = new HashSet<>();
    for (Entry entry : entries) {
      Function function = script.getFunctionAt(entry.address);
      boolean pendingDelete = false;
      boolean pendingRegenerate = rebuildAllGlobals;

      if (function == null || !dumper.isValidFunction(function)) {
        // Remove CRT and other blacklisted (or vanished) functions. The rows are dropped even when
        // the file is already gone, so a null Function can never be queued for regeneration.
        if (entry.file != null && entry.file.exists()) {
          script.println("Removed file: " + entry.file.getAbsolutePath());
        }
        script.println("Removed invalid function entry: " + entry.name + " at " + entry.address);
        pendingDelete = true;
      } else {
        String actualSymbolName = symbolNames.get(entry.address);
        if (actualSymbolName == null) {
          throw new Exception(
              "Symbol name not found for function at " + entry.address + " in file " + entry.file.getAbsolutePath());
        }
        if (!actualSymbolName.equals(entry.name)) {
          File fnExportedFile = exportedFunctionNames.get(entry.name);
          if (fnExportedFile != null && !fnExportedFile.equals(entry.file)) {
            // Already exported from another file, so this file is a duplicate.
            script.println("Removing duplicate function: " + entry.name + " at " + entry.address + " overridden by "
                + fnExportedFile);
            pendingDelete = true;
          } else {
            // Regenerate this function under its current symbol name.
            script.println("Symbol name mismatch for function at " + entry.address + ": " +
                "File name: " + entry.name + ", Actual symbol: " + actualSymbolName);
            entry.name = actualSymbolName;
            pendingRegenerate = true;
          }
        }
        // NOTE(review): preserved behavior - any valid function marks the run as changed (even in
        // dry mode), so globals and types are re-dumped on every run. Confirm this is intended.
        madeAnyChanges = true;
      }

      if (pendingDelete) {
        if (!dryMode) {
          if (entry.file != null) {
            entry.file.delete();
            removeEntryAt(entry.file.getAbsolutePath());
          }
          madeAnyChanges = true;
        }
      } else if (pendingRegenerate && entry.type != Type.Stub) {
        if (!dryMode) {
          functionsToRegenerate.add(function);
          // Persist the corrected name.
          addEntryAt(entry);
          madeAnyChanges = true;
        }
      }
    }

    for (Function function : functionsToRegenerate) {
      script.println("Regenerating function: " + function.getName() + " at " + function.getEntryPoint());
      dumper.dump(function);
    }
    if (madeAnyChanges) {
      RemanConfig.INSTANCE.touchCMakeTimestamp();
      globalDumper.dumpGlobals();
      globalDumper.saveGlobalManifest();
      TypeDumper typeDumper = new TypeDumper(script);
      typeDumper.run();
    }
  }

  /**
   * Records the current symbol name of every entry that still has a function, and which file owns
   * each exported name: the first Auto entry claims a name, and Fix entries always override it
   * (independent of row order).
   */
  private void collectSymbolNames(List<Entry> entries, Map<Address, String> symbolNames,
      Map<String, File> exportedFunctionNames) {
    for (Entry entry : entries) {
      Function function = script.getFunctionAt(entry.address);
      if (function == null) {
        continue;
      }
      symbolNames.put(entry.address, function.getName());
      if (entry.type == Type.Auto) {
        exportedFunctionNames.putIfAbsent(entry.name, entry.file);
      } else if (entry.type == Type.Fix) {
        exportedFunctionNames.put(entry.name, entry.file);
      }
    }
  }
}