diff --git a/tooling/tool.cpp b/tooling/tool.cpp index 51f60df2..72accceb 100644 --- a/tooling/tool.cpp +++ b/tooling/tool.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include extern "C" TSLanguage *tree_sitter_cpp(); @@ -20,6 +21,12 @@ struct FunctionInfo { bool is_import; }; +struct GlobalInfo { + std::string name; + std::string address; + std::string filepath; +}; + class PreparedStatements { private: sqlite3 *db; @@ -27,6 +34,8 @@ private: sqlite3_stmt *delete_imports_stmt; sqlite3_stmt *insert_functions_stmt; sqlite3_stmt *insert_imports_stmt; + sqlite3_stmt *delete_globals_stmt; + sqlite3_stmt *insert_globals_stmt; void prepareStatement(const char *sql, sqlite3_stmt **stmt, const std::string &error_msg) { @@ -51,6 +60,13 @@ public: "VALUES (?, ?, ?)", &insert_imports_stmt, "Failed to prepare insert imports statement"); + prepareStatement("DELETE FROM Globals WHERE filepath = ?", + &delete_globals_stmt, + "Failed to prepare delete globals statement"); + prepareStatement("INSERT OR REPLACE INTO Globals (filepath, name, address) " + "VALUES (?, ?, ?)", + &insert_globals_stmt, + "Failed to prepare insert globals statement"); } ~PreparedStatements() { @@ -58,6 +74,8 @@ public: sqlite3_finalize(delete_imports_stmt); sqlite3_finalize(insert_functions_stmt); sqlite3_finalize(insert_imports_stmt); + sqlite3_finalize(delete_globals_stmt); + sqlite3_finalize(insert_globals_stmt); } void clearEntriesForFile(const std::string &filepath) { @@ -68,6 +86,12 @@ public: } } + void clearGlobalsForFile(const std::string &filepath) { + sqlite3_reset(delete_globals_stmt); + sqlite3_bind_text(delete_globals_stmt, 1, filepath.c_str(), -1, SQLITE_STATIC); + sqlite3_step(delete_globals_stmt); + } + void insertFunction(const FunctionInfo &func) { sqlite3_stmt *stmt = func.is_import ? insert_imports_stmt : insert_functions_stmt; @@ -77,6 +101,14 @@ public: sqlite3_bind_text(stmt, 3, func.address.c_str(), -1, SQLITE_STATIC); sqlite3_step(stmt); } + + void insertGlobal(const GlobalInfo &global) { + sqlite3_reset(insert_globals_stmt); + sqlite3_bind_text(insert_globals_stmt, 1, global.filepath.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(insert_globals_stmt, 2, global.name.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(insert_globals_stmt, 3, global.address.c_str(), -1, SQLITE_STATIC); + sqlite3_step(insert_globals_stmt); + } }; class DatabaseManager { @@ -95,6 +127,7 @@ public: const char *create_tables = R"( CREATE TABLE IF NOT EXISTS Functions (filepath TEXT, name TEXT, address TEXT, PRIMARY KEY (name, filepath)); CREATE TABLE IF NOT EXISTS Imports (filepath TEXT, name TEXT, address TEXT, PRIMARY KEY (name, filepath)); + CREATE TABLE IF NOT EXISTS Globals (filepath TEXT, name TEXT, address TEXT, PRIMARY KEY (name, filepath)); )"; sqlite3_exec(db, create_tables, nullptr, nullptr, nullptr); @@ -112,6 +145,12 @@ public: void insertFunction(const FunctionInfo &func) { prepared_stmts->insertFunction(func); } + void clearGlobalsForFile(const std::string &filepath) { + prepared_stmts->clearGlobalsForFile(filepath); + } + void insertGlobal(const GlobalInfo &global) { + prepared_stmts->insertGlobal(global); + } void beginTransaction() { sqlite3_exec(db, "BEGIN TRANSACTION", nullptr, nullptr, nullptr); } @@ -322,13 +361,181 @@ bool processFile(const std::string &filepath, DatabaseManager &db) { return true; } +std::string getGlobalName(TSNode node, const char *source_code) { + uint32_t child_count = ts_node_child_count(node); + + for (uint32_t i = 0; i < child_count; i++) { + TSNode child = ts_node_child(node, i); + const char *type = ts_node_type(child); + + // Handle reference declarators like "undefined& DAT_00000004" (direct child) + if (strcmp(type, "reference_declarator") == 0) { + uint32_t ref_children = ts_node_child_count(child); + for (uint32_t k = 0; k < ref_children; k++) { + TSNode ref_child = ts_node_child(child, k); + if (strcmp(ts_node_type(ref_child), "identifier") == 0) { + uint32_t start = ts_node_start_byte(ref_child); + uint32_t end = ts_node_end_byte(ref_child); + return std::string(source_code + start, end - start); + } + } + } + // Look for declarator in the declaration + else if (strcmp(type, "init_declarator") == 0 || strcmp(type, "declarator") == 0) { + uint32_t declarator_children = ts_node_child_count(child); + for (uint32_t j = 0; j < declarator_children; j++) { + TSNode declarator_child = ts_node_child(child, j); + const char *child_type = ts_node_type(declarator_child); + + // Handle reference declarators like "undefined& DAT_00000004" + if (strcmp(child_type, "reference_declarator") == 0) { + uint32_t ref_children = ts_node_child_count(declarator_child); + for (uint32_t k = 0; k < ref_children; k++) { + TSNode ref_child = ts_node_child(declarator_child, k); + if (strcmp(ts_node_type(ref_child), "identifier") == 0) { + uint32_t start = ts_node_start_byte(ref_child); + uint32_t end = ts_node_end_byte(ref_child); + return std::string(source_code + start, end - start); + } + } + } + // Handle array declarators like "char(&s_or_press_ESC_to_quit_Rayman_3__005b662c)[32]" + else if (strcmp(child_type, "parenthesized_declarator") == 0) { + uint32_t paren_children = ts_node_child_count(declarator_child); + for (uint32_t k = 0; k < paren_children; k++) { + TSNode paren_child = ts_node_child(declarator_child, k); + if (strcmp(ts_node_type(paren_child), "reference_declarator") == 0) { + uint32_t ref_children = ts_node_child_count(paren_child); + for (uint32_t l = 0; l < ref_children; l++) { + TSNode ref_child = ts_node_child(paren_child, l); + if (strcmp(ts_node_type(ref_child), "identifier") == 0) { + uint32_t start = ts_node_start_byte(ref_child); + uint32_t end = ts_node_end_byte(ref_child); + return std::string(source_code + start, end - start); + } + } + } + } + } + // Handle simple identifiers + else if (strcmp(child_type, "identifier") == 0) { + uint32_t start = ts_node_start_byte(declarator_child); + uint32_t end = ts_node_end_byte(declarator_child); + return std::string(source_code + start, end - start); + } + } + } + // Direct identifier child + else if (strcmp(type, "identifier") == 0) { + uint32_t start = ts_node_start_byte(child); + uint32_t end = ts_node_end_byte(child); + return std::string(source_code + start, end - start); + } + } + return ""; +} + +void findGlobals(TSNode node, const char *source_code, uint32_t source_length, + std::vector &globals) { + const char *type = ts_node_type(node); + + // Look for extern declarations + if (strcmp(type, "declaration") == 0) { + // Check if this is an extern declaration + uint32_t child_count = ts_node_child_count(node); + bool is_extern = false; + + for (uint32_t i = 0; i < child_count; i++) { + TSNode child = ts_node_child(node, i); + if (strcmp(ts_node_type(child), "storage_class_specifier") == 0) { + uint32_t start = ts_node_start_byte(child); + uint32_t end = ts_node_end_byte(child); + std::string storage_class(source_code + start, end - start); + if (storage_class == "extern") { + is_extern = true; + break; + } + } + } + + if (is_extern) { + std::string global_name = getGlobalName(node, source_code); + if (!global_name.empty()) { + // Look for address comment after the declaration + std::string address = extractAddress(getComment(node, source_code, source_length, false)); + + if (!address.empty()) { + GlobalInfo global{global_name, address, ""}; + globals.push_back(global); + } + } else { + size_t start = ts_node_start_byte(node); + size_t end = ts_node_end_byte(node); + std::string_view src = std::string_view(source_code + start, end - start); + SPDLOG_ERROR("Failed to get global name for {}", src); + } + } + } + + // Recursively search child nodes + uint32_t child_count = ts_node_child_count(node); + for (uint32_t i = 0; i < child_count; i++) { + findGlobals(ts_node_child(node, i), source_code, source_length, globals); + } +} + +bool processGlobalsFile(const std::string &filepath, DatabaseManager &db) { + std::ifstream file(filepath); + if (!file.is_open()) { + std::cerr << "Error: Could not open file " << filepath << std::endl; + return false; + } + + std::string file_content((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + + TSParser *parser = ts_parser_new(); + ts_parser_set_language(parser, tree_sitter_cpp()); + + TSTree *tree = ts_parser_parse_string(parser, nullptr, file_content.c_str(), + file_content.length()); + TSNode root_node = ts_tree_root_node(tree); + + if (ts_node_is_null(root_node)) { + std::cerr << "Error: Failed to parse file " << filepath << std::endl; + ts_tree_delete(tree); + ts_parser_delete(parser); + return false; + } + + db.clearGlobalsForFile(filepath); + + std::vector globals; + findGlobals(root_node, file_content.c_str(), file_content.length(), globals); + + for (auto &global : globals) { + global.filepath = filepath; + db.insertGlobal(global); + std::cout << "Global: " << global.name << " @ " << global.address + << " in " << filepath << std::endl; + } + + std::cout << "Processed " << globals.size() << " globals from " + << filepath << std::endl; + + ts_tree_delete(tree); + ts_parser_delete(parser); + return true; +} + int main(int argc, char *argv[]) { CLI::App app{ - "C++ Function Parser - Extracts function addresses from C++ files"}; + "C++ Function/Global Parser - Extracts function addresses or global variable addresses from C++ files"}; std::vector input_files; std::string list_file; std::string db_path = "functions.db"; + std::string mode = "functions"; app.add_option("files", input_files, "Input C++ files to parse (supports @listfile.txt syntax)"); @@ -336,6 +543,9 @@ int main(int argc, char *argv[]) { "File containing list of files to process"); app.add_option("-d,--database", db_path, "SQLite database path") ->default_val("functions.db"); + app.add_option("-m,--mode", mode, "Processing mode: 'functions' or 'globals'") + ->default_val("functions") + ->check(CLI::IsMember({"functions", "globals"})); CLI11_PARSE(app, argc, argv); @@ -375,7 +585,14 @@ int main(int argc, char *argv[]) { for (const auto &filepath : files_to_process) { std::cout << "\n=== Processing: " << filepath << " ===" << std::endl; - if (processFile(filepath, db)) + bool success = false; + if (mode == "functions") { + success = processFile(filepath, db); + } else if (mode == "globals") { + success = processGlobalsFile(filepath, db); + } + + if (success) processed_count++; if (++current_batch >= batch_size) { @@ -396,6 +613,7 @@ int main(int argc, char *argv[]) { std::cout << "\n=== Summary ===" << std::endl; std::cout << "Processed " << processed_count << " files successfully" << std::endl; + std::cout << "Mode: " << mode << std::endl; std::cout << "Database saved to: " << db_path << std::endl; } catch (const std::exception &e) {