#include "tool.hpp"

#include <cstdint>
#include <cstring>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

#include <spdlog/spdlog.h>

// Global address regex pattern.
// Matches `//`, optional whitespace, then eight hexadecimal digits; the eight
// digits are capture group 1. The pattern is not anchored at the end, so a
// longer hex run still matches on its first eight digits.
const std::regex ADDRESS_REGEX(R"(//\s*([0-9a-fA-F]{8}))");

namespace
{

// Walks a sibling chain starting at `sibling` and returns the text of the first
// "comment" node that contains an address pattern. Comments without an address
// and "ERROR" nodes are stepped over; any other node type ends the scan.
// `backwards` selects previous-sibling (true) or next-sibling (false) stepping.
// Returns an empty string when nothing is found.
std::string findAddressCommentAmongSiblings(TSNode sibling, const char *source_code, bool backwards)
{
    while (!ts_node_is_null(sibling))
    {
        const char *type = ts_node_type(sibling);
        if (strcmp(type, "comment") == 0)
        {
            std::string comment_text = extractNodeText(sibling, source_code);
            if (hasAddressPattern(comment_text))
            {
                return comment_text;
            }
        }
        else if (strcmp(type, "ERROR") != 0)
        {
            // Real (non-comment, non-error) content: stop searching this chain.
            break;
        }
        sibling = backwards ? ts_node_prev_sibling(sibling) : ts_node_next_sibling(sibling);
    }
    return "";
}

// Returns true when `type` is one of the declarator node types that
// getGlobalName() descends into while looking for the declared name.
bool isDeclaratorNodeType(const char *type)
{
    return strcmp(type, "init_declarator") == 0 ||
           strcmp(type, "declarator") == 0 ||
           strcmp(type, "reference_declarator") == 0 ||
           strcmp(type, "pointer_declarator") == 0 ||
           strcmp(type, "parenthesized_declarator") == 0 ||
           strcmp(type, "array_declarator") == 0;
}

} // namespace

// Helper function to convert string to FileType.
// Accepts exactly "auto", "fix", "stub" and "ref" (case-sensitive).
// Throws std::invalid_argument for any other input.
FileType stringToFileType(const std::string &type_str)
{
    if (type_str == "auto")
        return FileType::Auto;
    if (type_str == "fix")
        return FileType::Fix;
    if (type_str == "stub")
        return FileType::Stub;
    if (type_str == "ref")
        return FileType::Ref;
    throw std::invalid_argument("Invalid file type: " + type_str);
}

// Helper function to convert FileType to string (inverse of stringToFileType).
// Throws std::logic_error when `type` is not one of the four handled values.
std::string fileTypeToString(FileType type)
{
    switch (type)
    {
    case FileType::Auto:
        return "auto";
    case FileType::Fix:
        return "fix";
    case FileType::Stub:
        return "stub";
    case FileType::Ref:
        return "ref";
    default:
        throw std::logic_error("Invalid file type: " + std::to_string(static_cast<int>(type)));
    }
}

// Helper function to check if a comment contains an address
// (i.e. ADDRESS_REGEX matches somewhere in `comment`).
bool hasAddressPattern(const std::string &comment)
{
    return std::regex_search(comment, ADDRESS_REGEX);
}

// Returns the eight hex digits captured by ADDRESS_REGEX in `comment`,
// or an empty string when the pattern does not match.
std::string extractAddress(const std::string &comment)
{
    std::smatch match;
    return std::regex_search(comment, match, ADDRESS_REGEX) ? match[1].str() : "";
}

// Helper function to extract text from a TSNode: copies the bytes between the
// node's start and end byte offsets out of `source_code`.
// `source_code` must be the buffer the node's offsets refer to.
std::string extractNodeText(TSNode node, const char *source_code)
{
    uint32_t start = ts_node_start_byte(node);
    uint32_t end = ts_node_end_byte(node);
    return std::string(source_code + start, end - start);
}

// Helper function to find first identifier in a node.
// Only direct children are inspected; returns "" when none is an "identifier".
std::string findIdentifierInNode(TSNode node, const char *source_code)
{
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        TSNode child = ts_node_child(node, i);
        if (strcmp(ts_node_type(child), "identifier") == 0)
        {
            return extractNodeText(child, source_code);
        }
    }
    return "";
}

// Helper function to recursively find identifier in any declarator.
// Depth-first, left-to-right: returns the text of the first "identifier" node
// found at or below `node`, or "" when the subtree contains none.
std::string findIdentifierInDeclarator(TSNode node, const char *source_code)
{
    // If this is an identifier, return it
    if (strcmp(ts_node_type(node), "identifier") == 0)
    {
        return extractNodeText(node, source_code);
    }

    // Recursively search all children
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        std::string result = findIdentifierInDeclarator(ts_node_child(node, i), source_code);
        if (!result.empty())
        {
            return result;
        }
    }
    return "";
}

// Returns the name declared by `node`, scanning its direct children in order:
//   - "function_declarator": its first direct "identifier" child, if any;
//   - "identifier": that identifier's text;
//   - "pointer_declarator": the result of recursing into it, if non-empty.
// Returns "" when no child yields a name.
// NOTE(review): a direct "identifier" or one nested only in pointer declarators
// is accepted without checking that a function declarator is present, so plain
// variable declarations may also yield a name here - confirm this is intended.
std::string getFunctionName(TSNode node, const char *source_code)
{
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);

        if (strcmp(type, "function_declarator") == 0)
        {
            std::string name = findIdentifierInNode(child, source_code);
            if (!name.empty())
                return name;
        }
        else if (strcmp(type, "identifier") == 0)
        {
            return extractNodeText(child, source_code);
        }
        else if (strcmp(type, "pointer_declarator") == 0)
        {
            std::string name = getFunctionName(child, source_code);
            if (!name.empty())
                return name;
        }
    }
    return "";
}

// Returns the name of the variable declared by `node`. Direct children are
// scanned in order: declarator-like children are searched recursively for the
// first identifier, and a direct "identifier" child is returned as-is.
// Returns "" when no identifier is found.
std::string getGlobalName(TSNode node, const char *source_code)
{
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);

        // Look for any kind of declarator and recursively search for identifier
        if (isDeclaratorNodeType(type))
        {
            std::string name = findIdentifierInDeclarator(child, source_code);
            if (!name.empty())
            {
                return name;
            }
        }
        // Direct identifier child
        else if (strcmp(type, "identifier") == 0)
        {
            return extractNodeText(child, source_code);
        }
    }
    return "";
}

// Finds the nearest comment carrying an address pattern next to `node`.
//
// search_before == false: scans the siblings that follow `node`.
// search_before == true:  scans the siblings that precede `node`; when that
//                         chain yields nothing, repeats the scan for each
//                         ancestor in turn until the root has been tried.
//
// In both directions comments without an address and "ERROR" nodes are
// stepped over, and any other node ends the scan of that sibling chain.
// Returns the full comment text, or "" when no matching comment is found.
// `source_length` is accepted for interface compatibility and is not read.
std::string getComment(TSNode node, const char *source_code, uint32_t source_length, bool search_before)
{
    (void)source_length;

    if (!search_before)
    {
        // Look for comments after the current node
        return findAddressCommentAmongSiblings(ts_node_next_sibling(node), source_code, false);
    }

    // Look for comments before the current node, then before each ancestor
    for (TSNode current = node; !ts_node_is_null(current); current = ts_node_parent(current))
    {
        std::string found = findAddressCommentAmongSiblings(ts_node_prev_sibling(current), source_code, true);
        if (!found.empty())
        {
            return found;
        }
    }
    return "";
}

// Returns true only when `node` is a "function_definition" that has a direct
// "compound_statement" child; false for every other node.
bool hasFunctionBody(TSNode node)
{
    if (strcmp(ts_node_type(node), "function_definition") != 0)
        return false;

    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        if (strcmp(ts_node_type(ts_node_child(node, i)), "compound_statement") == 0)
        {
            return true;
        }
    }
    return false;
}

// Detects the calling convention of a function declaration or definition.
//
// Direct children are scanned in order. A child whose text is one of
// __fastcall/__FASTCALL, __stdcall/__STDCALL or __cdecl/__CDECL decides the
// result immediately. Specifier and declarator children are searched
// recursively, and a non-cdecl result from them is returned at once.
// Falls back to CallingConvention::Cdecl when nothing is found.
CallingConvention getCallingConvention(TSNode node, const char *source_code)
{
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);

        // Calling-convention keywords may surface either as plain identifiers
        // (e.g. macro-style spellings) or as "ms_call_modifier" nodes, which is
        // how the tree-sitter C/C++ grammars tag the MSVC keywords.
        // NOTE(review): confirm the node name against the grammar version in use.
        if (strcmp(type, "identifier") == 0 || strcmp(type, "ms_call_modifier") == 0)
        {
            std::string text = extractNodeText(child, source_code);
            if (text == "__fastcall" || text == "__FASTCALL")
            {
                return CallingConvention::Fastcall;
            }
            else if (text == "__stdcall" || text == "__STDCALL")
            {
                return CallingConvention::Stdcall;
            }
            else if (text == "__cdecl" || text == "__CDECL")
            {
                return CallingConvention::Cdecl;
            }
        }
        // Also check in specifiers and declarators (the calling convention can
        // appear in various positions). A cdecl result from a subtree is
        // indistinguishable from "not found", so the scan simply continues.
        else if (strcmp(type, "type_specifier") == 0 ||
                 strcmp(type, "declaration_specifier") == 0 ||
                 strcmp(type, "specifier_qualifier_list") == 0 ||
                 strcmp(type, "function_declarator") == 0 ||
                 strcmp(type, "pointer_declarator") == 0)
        {
            CallingConvention conv = getCallingConvention(child, source_code);
            if (conv != CallingConvention::Cdecl)
            {
                return conv;
            }
        }
    }
    return CallingConvention::Cdecl; // Default
}

// Recursively collects address-annotated functions below `node`.
//
// For every "function_definition" or "declaration" node that yields a name:
//   - the address is taken from a comment following the node; for definitions
//     only, a comment preceding the node is tried as a fallback;
//   - when an address is found, a FunctionInfo is appended to `functions`,
//     tagged with `file_type` and the detected calling convention;
//   - the subtree is not searched further, whether or not an address was found.
// When no name can be extracted an error is logged and the search continues
// into the node's children. All other nodes are simply descended into.
void findFunctions(TSNode node, const char *source_code, uint32_t source_length, std::vector<FunctionInfo> &functions, FileType file_type)
{
    const char *type = ts_node_type(node);
    const bool is_definition = strcmp(type, "function_definition") == 0;

    if (is_definition || strcmp(type, "declaration") == 0)
    {
        std::string func_name = getFunctionName(node, source_code);
        if (!func_name.empty())
        {
            std::string address = extractAddress(getComment(node, source_code, source_length, false));
            if (address.empty() && is_definition)
            {
                address = extractAddress(getComment(node, source_code, source_length, true));
            }

            if (!address.empty())
            {
                // Detect calling convention
                CallingConvention calling_conv = getCallingConvention(node, source_code);

                // hasFunctionBody() is false for anything that is not a
                // function_definition, so declarations always count as bodiless.
                const bool lacks_body = !hasFunctionBody(node);

                FunctionInfo func{func_name, address, "", lacks_body, file_type};
                func.calling_convention = calling_conv; // Set the calling convention

                spdlog::debug("Found function: {} at {} with calling convention: {}",
                              func_name, address, callingConventionToString(calling_conv));
                functions.push_back(func);
            }
            // We'll never nest function declarations
            return;
        }
        else
        {
            spdlog::error("Failed to get function name for {}", extractNodeText(node, source_code));
        }
    }

    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        findFunctions(ts_node_child(node, i), source_code, source_length, functions, file_type);
    }
}

// Recursively collects address-annotated `extern` globals below `node`.
//
// A "declaration" node counts as extern when one of its direct
// "storage_class_specifier" children has the text "extern". For such a node:
//   - when a name is found and a comment following the declaration carries an
//     address, a GlobalInfo is appended to `globals`;
//   - when no name is found, an error is logged;
//   - the subtree is not searched further in either case.
// Every other node is descended into.
void findGlobals(TSNode node, const char *source_code, uint32_t source_length, std::vector<GlobalInfo> &globals)
{
    const char *type = ts_node_type(node);

    // Look for extern declarations
    if (strcmp(type, "declaration") == 0)
    {
        // Check if this is an extern declaration
        uint32_t specifier_count = ts_node_child_count(node);
        bool is_extern = false;
        for (uint32_t i = 0; i < specifier_count; i++)
        {
            TSNode child = ts_node_child(node, i);
            if (strcmp(ts_node_type(child), "storage_class_specifier") == 0)
            {
                std::string storage_class = extractNodeText(child, source_code);
                if (storage_class == "extern")
                {
                    is_extern = true;
                    break;
                }
            }
        }

        if (is_extern)
        {
            std::string global_name = getGlobalName(node, source_code);
            if (!global_name.empty())
            {
                // Look for address comment after the declaration
                std::string address = extractAddress(getComment(node, source_code, source_length, false));
                if (!address.empty())
                {
                    GlobalInfo global{global_name, address, ""};
                    globals.push_back(global);
                }
            }
            else
            {
                std::string src = extractNodeText(node, source_code);
                spdlog::error("Failed to get global name for {}", src);
            }
            return;
        }
    }

    // Recursively search child nodes
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++)
    {
        findGlobals(ts_node_child(node, i), source_code, source_length, globals);
    }
}