#include "tool.hpp"

// NOTE(review): the targets of the three #include directives that followed
// "tool.hpp" were lost from this file. The headers below cover every name this
// file uses directly; confirm them against the original include list.
#include <cstdint>
#include <cstring>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

#include <spdlog/spdlog.h>

// Global address regex pattern: a "//" comment opener, optional whitespace,
// then eight hexadecimal digits (captured as group 1).
const std::regex ADDRESS_REGEX(R"(//\s*([0-9a-fA-F]{8}))");

// Converts the textual file type ("auto", "fix", "stub", "ref") to FileType.
// Throws std::invalid_argument for any other string.
FileType stringToFileType(const std::string &type_str) {
    if (type_str == "auto") return FileType::Auto;
    if (type_str == "fix") return FileType::Fix;
    if (type_str == "stub") return FileType::Stub;
    if (type_str == "ref") return FileType::Ref;
    throw std::invalid_argument("Invalid file type: " + type_str);
}

// Converts a FileType back to its textual form.
// Throws std::logic_error if the value is not one of the four known enumerators.
std::string fileTypeToString(FileType type) {
    switch (type) {
    case FileType::Auto:
        return "auto";
    case FileType::Fix:
        return "fix";
    case FileType::Stub:
        return "stub";
    case FileType::Ref:
        return "ref";
    default:
        throw std::logic_error("Invalid file type: " + std::to_string(static_cast<int>(type)));
    }
}

// Returns true if the comment text contains an address matching ADDRESS_REGEX.
bool hasAddressPattern(const std::string &comment) {
    return std::regex_search(comment, ADDRESS_REGEX);
}

// Returns the eight hex digits of the first address found in the comment,
// or an empty string when the comment carries no address.
std::string extractAddress(const std::string &comment) {
    std::smatch match;
    return std::regex_search(comment, match, ADDRESS_REGEX) ? match[1].str() : "";
}

// Returns the source text covered by a node, using its start/end byte offsets
// into source_code.
std::string extractNodeText(TSNode node, const char *source_code) {
    uint32_t start = ts_node_start_byte(node);
    uint32_t end = ts_node_end_byte(node);
    return std::string(source_code + start, end - start);
}

// Returns the text of the first *direct* child of type "identifier",
// or an empty string if there is none.
std::string findIdentifierInNode(TSNode node, const char *source_code) {
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        TSNode child = ts_node_child(node, i);
        if (std::strcmp(ts_node_type(child), "identifier") == 0) {
            return extractNodeText(child, source_code);
        }
    }
    return "";
}

// Depth-first search for the first "identifier" node at or below `node`.
// Returns its text, or an empty string if the subtree has no identifier.
std::string findIdentifierInDeclarator(TSNode node, const char *source_code) {
    if (std::strcmp(ts_node_type(node), "identifier") == 0) {
        return extractNodeText(node, source_code);
    }

    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        std::string result = findIdentifierInDeclarator(ts_node_child(node, i), source_code);
        if (!result.empty()) {
            return result;
        }
    }
    return "";
}

// Returns the declared name for a function definition/declaration node.
// Looks at direct children in order: a function_declarator's direct identifier,
// a bare identifier, or (recursively) a pointer_declarator. Returns an empty
// string when no name is found.
std::string getFunctionName(TSNode node, const char *source_code) {
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);

        if (std::strcmp(type, "function_declarator") == 0) {
            std::string name = findIdentifierInNode(child, source_code);
            if (!name.empty()) return name;
        } else if (std::strcmp(type, "identifier") == 0) {
            return extractNodeText(child, source_code);
        } else if (std::strcmp(type, "pointer_declarator") == 0) {
            std::string name = getFunctionName(child, source_code);
            if (!name.empty()) return name;
        }
    }
    return "";
}

// Returns the variable name of a declaration node: the first identifier found
// inside any declarator child, or a direct identifier child. Returns an empty
// string when no name is found.
std::string getGlobalName(TSNode node, const char *source_code) {
    static const char *const kDeclaratorTypes[] = {
        "init_declarator",      "declarator",               "reference_declarator",
        "pointer_declarator",   "parenthesized_declarator", "array_declarator",
    };

    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);

        bool is_declarator = false;
        for (const char *declarator_type : kDeclaratorTypes) {
            if (std::strcmp(type, declarator_type) == 0) {
                is_declarator = true;
                break;
            }
        }

        if (is_declarator) {
            // Any kind of declarator: search its subtree for the identifier.
            std::string name = findIdentifierInDeclarator(child, source_code);
            if (!name.empty()) {
                return name;
            }
        } else if (std::strcmp(type, "identifier") == 0) {
            // Direct identifier child.
            return extractNodeText(child, source_code);
        }
    }
    return "";
}

// Walks the siblings of `start` (backwards or forwards) and returns the text of
// the first comment that carries an address. Comments without an address and
// ERROR nodes are skipped; any other node type ends the scan.
static std::string scanSiblingsForAddressComment(TSNode start, bool backwards, const char *source_code) {
    TSNode sibling = backwards ? ts_node_prev_sibling(start) : ts_node_next_sibling(start);
    while (!ts_node_is_null(sibling)) {
        const char *type = ts_node_type(sibling);
        if (std::strcmp(type, "comment") == 0) {
            std::string comment_text = extractNodeText(sibling, source_code);
            if (hasAddressPattern(comment_text)) {
                return comment_text;
            }
        } else if (std::strcmp(type, "ERROR") != 0) {
            break;
        }
        sibling = backwards ? ts_node_prev_sibling(sibling) : ts_node_next_sibling(sibling);
    }
    return "";
}

// Finds the address-bearing comment associated with `node`.
//   search_before == true : scans the preceding siblings of the node, then of
//                           each ancestor in turn, until a match is found or the
//                           root is passed.
//   search_before == false: scans only the following siblings of the node.
// Returns the full comment text, or an empty string if none is found.
// source_length is currently unused.
std::string getComment(TSNode node, const char *source_code, uint32_t source_length, bool search_before) {
    (void)source_length;

    if (!search_before) {
        return scanSiblingsForAddressComment(node, false, source_code);
    }

    for (TSNode current = node; !ts_node_is_null(current); current = ts_node_parent(current)) {
        std::string comment_text = scanSiblingsForAddressComment(current, true, source_code);
        if (!comment_text.empty()) {
            return comment_text;
        }
    }
    return "";
}

// Returns true only for a function_definition node that has a direct
// compound_statement child.
bool hasFunctionBody(TSNode node) {
    if (std::strcmp(ts_node_type(node), "function_definition") != 0) return false;

    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        if (std::strcmp(ts_node_type(ts_node_child(node, i)), "compound_statement") == 0) {
            return true;
        }
    }
    return false;
}

// Maps a calling-convention keyword (lower- or upper-case spelling) to its enum
// value. Returns false, leaving *out untouched, for any other text.
static bool parseCallingConvention(const std::string &text, CallingConvention *out) {
    if (text == "__fastcall" || text == "__FASTCALL") {
        *out = CallingConvention::Fastcall;
        return true;
    }
    if (text == "__stdcall" || text == "__STDCALL") {
        *out = CallingConvention::Stdcall;
        return true;
    }
    if (text == "__cdecl" || text == "__CDECL") {
        *out = CallingConvention::Cdecl;
        return true;
    }
    return false;
}

// Detects the calling convention of a function declaration/definition node.
// Direct children are checked in order: ms_call_modifier and identifier nodes
// are matched against the known keywords; specifier and declarator children are
// searched recursively. Falls back to CallingConvention::Cdecl.
CallingConvention getCallingConvention(TSNode node, const char *source_code) {
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);

        if (std::strcmp(type, "ms_call_modifier") == 0 || std::strcmp(type, "identifier") == 0) {
            // Microsoft modifier node, or a plain identifier as a fallback spelling.
            CallingConvention conv = CallingConvention::Cdecl;
            if (parseCallingConvention(extractNodeText(child, source_code), &conv)) {
                return conv;
            }
        } else if (std::strcmp(type, "type_specifier") == 0 ||
                   std::strcmp(type, "declaration_specifier") == 0 ||
                   std::strcmp(type, "specifier_qualifier_list") == 0 ||
                   std::strcmp(type, "function_declarator") == 0 ||
                   std::strcmp(type, "pointer_declarator") == 0) {
            // The keyword can sit inside specifiers or declarators. A nested result
            // of Cdecl is indistinguishable from "not found", so keep scanning.
            CallingConvention conv = getCallingConvention(child, source_code);
            if (conv != CallingConvention::Cdecl) {
                return conv;
            }
        }
    }
    return CallingConvention::Cdecl; // Default
}

// Recursively collects every function definition/declaration that has an
// address comment and appends a FunctionInfo for it to `functions`.
// The address is taken from a comment after the node; for definitions only, a
// comment before the node is tried as a fallback. Nodes without an address are
// skipped. Recursion stops at any node for which a function name was found.
// NOTE(review): a "declaration" node for which no name is found (for example
// one whose only declarator is an init_declarator) is logged as an error
// before recursing; confirm this is intended.
void findFunctions(TSNode node, const char *source_code, uint32_t source_length,
                   std::vector<FunctionInfo> &functions, FileType file_type) {
    const char *type = ts_node_type(node);
    const bool is_definition = std::strcmp(type, "function_definition") == 0;

    if (is_definition || std::strcmp(type, "declaration") == 0) {
        std::string func_name = getFunctionName(node, source_code);
        if (!func_name.empty()) {
            std::string address = extractAddress(getComment(node, source_code, source_length, false));
            if (address.empty() && is_definition) {
                address = extractAddress(getComment(node, source_code, source_length, true));
            }

            if (!address.empty()) {
                // Detect calling convention
                CallingConvention calling_conv = getCallingConvention(node, source_code);

                // Extract parameter information and return type
                std::string parameter_names = getParameterNames(node, source_code);
                std::string parameter_types = getParameterTypes(node, source_code);
                std::string return_type = getReturnType(node, source_code);

                FunctionInfo func;
                func.name = func_name;
                func.address = address;
                func.filepath = "";
                func.parameter_names = parameter_names;
                func.parameter_types = parameter_types;
                func.return_type = return_type;
                // Plain declarations are always imports; definitions only when they lack a body.
                func.is_import = is_definition ? !hasFunctionBody(node) : true;
                func.type = file_type;
                func.calling_convention = calling_conv;

                spdlog::debug("Found function: {} at {} with calling convention: {}, return type: {}, param names: {}, param types: {}",
                              func_name, address, callingConventionToString(calling_conv), return_type,
                              parameter_names, parameter_types);

                functions.push_back(func);
            }
            // We'll never nest function declarations
            return;
        } else {
            spdlog::error("Failed to get function name for {}", extractNodeText(node, source_code));
        }
    }

    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        findFunctions(ts_node_child(node, i), source_code, source_length, functions, file_type);
    }
}

// Recursively collects every `extern` declaration that is followed by an
// address comment and appends a GlobalInfo{name, address, ""} to `globals`.
// Extern declarations are never descended into; a missing name is logged.
void findGlobals(TSNode node, const char *source_code, uint32_t source_length,
                 std::vector<GlobalInfo> &globals) {
    if (std::strcmp(ts_node_type(node), "declaration") == 0) {
        // Check if this is an extern declaration
        bool is_extern = false;
        uint32_t specifier_count = ts_node_child_count(node);
        for (uint32_t i = 0; i < specifier_count; i++) {
            TSNode child = ts_node_child(node, i);
            if (std::strcmp(ts_node_type(child), "storage_class_specifier") == 0 &&
                extractNodeText(child, source_code) == "extern") {
                is_extern = true;
                break;
            }
        }

        if (is_extern) {
            std::string global_name = getGlobalName(node, source_code);
            if (!global_name.empty()) {
                // Look for address comment after the declaration
                std::string address = extractAddress(getComment(node, source_code, source_length, false));
                if (!address.empty()) {
                    GlobalInfo global{global_name, address, ""};
                    globals.push_back(global);
                }
            } else {
                std::string src = extractNodeText(node, source_code);
                spdlog::error("Failed to get global name for {}", src);
            }
            return;
        }
    }

    // Recursively search child nodes
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        findGlobals(ts_node_child(node, i), source_code, source_length, globals);
    }
}

// Locates the parameter_list of the first function_declarator reachable from
// `node`. getFunctionName descends through pointer_declarator children to find
// the function's name, so this lookup descends through them as well; otherwise
// a function whose name was found that way would report no parameters.
// Returns true and stores the node in *out on success.
static bool findParameterList(TSNode node, TSNode *out) {
    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);

        if (std::strcmp(type, "function_declarator") == 0) {
            uint32_t declarator_child_count = ts_node_child_count(child);
            for (uint32_t j = 0; j < declarator_child_count; j++) {
                TSNode candidate = ts_node_child(child, j);
                if (std::strcmp(ts_node_type(candidate), "parameter_list") == 0) {
                    *out = candidate;
                    return true;
                }
            }
            // Only the first function_declarator is considered.
            return false;
        }
        if (std::strcmp(type, "pointer_declarator") == 0 && findParameterList(child, out)) {
            return true;
        }
    }
    return false;
}

// Joins the parts with ';' (no trailing separator).
static std::string joinWithSemicolons(const std::vector<std::string> &parts) {
    std::string result;
    for (size_t i = 0; i < parts.size(); i++) {
        if (i > 0) result += ";";
        result += parts[i];
    }
    return result;
}

// Returns the parameter names of a function node joined with ';'.
// Parameters without an identifier are skipped; the result is empty when the
// node has no parameter list.
std::string getParameterNames(TSNode node, const char *source_code) {
    std::vector<std::string> names;

    TSNode param_list;
    if (findParameterList(node, &param_list)) {
        uint32_t param_count = ts_node_child_count(param_list);
        for (uint32_t k = 0; k < param_count; k++) {
            TSNode param_node = ts_node_child(param_list, k);
            if (std::strcmp(ts_node_type(param_node), "parameter_declaration") != 0) continue;

            // Find the identifier (parameter name) in this parameter
            std::string param_name = findIdentifierInDeclarator(param_node, source_code);
            if (!param_name.empty()) {
                names.push_back(param_name);
            }
        }
    }

    return joinWithSemicolons(names);
}

// Returns the parameter types of a function node joined with ';'.
// For each parameter_declaration the type is the declaration text that
// precedes the parameter name, with surrounding whitespace trimmed. Pointer and
// reference symbols in front of the name are kept ("char *p" -> "char *").
// Text after the name (e.g. array extents) is dropped. A parameter without a
// name contributes its full text.
std::string getParameterTypes(TSNode node, const char *source_code) {
    static const char *const kWhitespace = " \t\n\r";
    std::vector<std::string> types;

    TSNode param_list;
    if (findParameterList(node, &param_list)) {
        uint32_t param_count = ts_node_child_count(param_list);
        for (uint32_t k = 0; k < param_count; k++) {
            TSNode param_node = ts_node_child(param_list, k);
            if (std::strcmp(ts_node_type(param_node), "parameter_declaration") != 0) continue;

            std::string param_type = extractNodeText(param_node, source_code);
            std::string param_name = findIdentifierInDeclarator(param_node, source_code);

            if (!param_name.empty()) {
                // The name is the last thing of interest in the declaration, so search
                // from the end to avoid matching the same characters inside the type.
                size_t name_pos = param_type.rfind(param_name);
                if (name_pos != std::string::npos) {
                    param_type.erase(name_pos);
                }
            }

            // Trim surrounding whitespace; an all-whitespace string becomes empty.
            param_type.erase(0, param_type.find_first_not_of(kWhitespace));
            param_type.erase(param_type.find_last_not_of(kWhitespace) + 1);

            if (!param_type.empty()) {
                types.push_back(param_type);
            }
        }
    }

    return joinWithSemicolons(types);
}

// Returns the text of the first direct child that is a type node
// (type_identifier, primitive_type, sized_type_specifier, qualified_identifier
// or template_type). Returns "void" when there is none.
// NOTE(review): only the type node itself is returned; pointer declarators and
// type qualifiers are not reflected in the result. Confirm whether callers
// need them.
std::string getReturnType(TSNode node, const char *source_code) {
    static const char *const kTypeNodes[] = {
        "type_identifier",      "primitive_type", "sized_type_specifier",
        "qualified_identifier", "template_type",
    };

    uint32_t child_count = ts_node_child_count(node);
    for (uint32_t i = 0; i < child_count; i++) {
        TSNode child = ts_node_child(node, i);
        const char *type = ts_node_type(child);
        for (const char *type_node : kTypeNodes) {
            if (std::strcmp(type, type_node) == 0) {
                return extractNodeText(child, source_code);
            }
        }
    }
    return "void"; // Default return type
}