502 lines
17 KiB
C++
502 lines
17 KiB
C++
#include "tool.hpp"
|
|
#include <cstring>
|
|
#include <stdexcept>
|
|
#include <spdlog/spdlog.h>
|
|
|
|
// Global address regex pattern
|
|
// Matches "//", optional whitespace, then eight hexadecimal digits; capture
// group 1 holds those eight digits. The pattern is not anchored, so it can
// match anywhere in the searched text and a longer hex run matches on its
// first eight digits.
const std::regex ADDRESS_REGEX(R"(//\s*([0-9a-fA-F]{8}))");
|
|
|
|
// Helper function to convert string to FileType
|
|
// Maps a textual type name onto its FileType value.
//
// Recognised names are exactly "auto", "fix", "stub" and "ref" (the comparison
// is case-sensitive). Any other input throws std::invalid_argument whose
// message contains the rejected text.
FileType stringToFileType(const std::string &type_str) {
  struct NamedType {
    const char *label;
    FileType value;
  };
  static const NamedType kKnownTypes[] = {
      {"auto", FileType::Auto},
      {"fix", FileType::Fix},
      {"stub", FileType::Stub},
      {"ref", FileType::Ref},
  };

  for (const NamedType &entry : kKnownTypes) {
    if (type_str == entry.label) {
      return entry.value;
    }
  }

  throw std::invalid_argument("Invalid file type: " + type_str);
}
|
|
|
|
// Helper function to convert FileType to string
|
|
// Returns the lowercase name ("auto", "fix", "stub" or "ref") for a FileType.
//
// A value outside those four enumerators throws std::logic_error; the message
// includes the value converted to int.
std::string fileTypeToString(FileType type) {
  if (type == FileType::Auto) {
    return "auto";
  }
  if (type == FileType::Fix) {
    return "fix";
  }
  if (type == FileType::Stub) {
    return "stub";
  }
  if (type == FileType::Ref) {
    return "ref";
  }

  const int raw_value = static_cast<int>(type);
  throw std::logic_error("Invalid file type: " + std::to_string(raw_value));
}
|
|
|
|
// Helper function to check if a comment contains an address
|
|
bool hasAddressPattern(const std::string &comment) {
|
|
return std::regex_search(comment, ADDRESS_REGEX);
|
|
}
|
|
|
|
std::string extractAddress(const std::string &comment) {
|
|
std::smatch match;
|
|
return std::regex_search(comment, match, ADDRESS_REGEX) ? match[1].str() : "";
|
|
}
|
|
|
|
// Helper function to extract text from a TSNode
|
|
// Copies the bytes of `source_code` that lie between the node's start byte
// and end byte (as reported by tree-sitter) into a new string.
std::string extractNodeText(TSNode node, const char *source_code) {
  const uint32_t first_byte = ts_node_start_byte(node);
  const uint32_t past_last_byte = ts_node_end_byte(node);
  const char *text_begin = source_code + first_byte;
  return std::string(text_begin, past_last_byte - first_byte);
}
|
|
|
|
// Helper function to find first identifier in a node
|
|
// Returns the text of the first *direct* child of `node` whose type is
// "identifier". Grandchildren are not inspected. Yields "" when no such child
// exists.
std::string findIdentifierInNode(TSNode node, const char *source_code) {
  const uint32_t total = ts_node_child_count(node);
  for (uint32_t idx = 0; idx < total; ++idx) {
    const TSNode candidate = ts_node_child(node, idx);
    if (strcmp(ts_node_type(candidate), "identifier") != 0) {
      continue;
    }
    return extractNodeText(candidate, source_code);
  }
  return "";
}
|
|
|
|
// Helper function to recursively find identifier in any declarator
|
|
// Depth-first, left-to-right search of the subtree rooted at `node` for an
// "identifier" node; returns the first non-empty identifier text found, or ""
// when the subtree has none.
std::string findIdentifierInDeclarator(TSNode node, const char *source_code) {
  // The node itself may already be the identifier.
  if (strcmp(ts_node_type(node), "identifier") == 0) {
    return extractNodeText(node, source_code);
  }

  // Otherwise descend; the loop stops as soon as a child subtree yields a name.
  std::string found;
  const uint32_t total = ts_node_child_count(node);
  for (uint32_t idx = 0; idx < total && found.empty(); ++idx) {
    found = findIdentifierInDeclarator(ts_node_child(node, idx), source_code);
  }
  return found;
}
|
|
|
|
// Derives a name from the direct children of `node`, visited in order:
//   - "function_declarator": its first direct "identifier" child, if any;
//   - "identifier":          that child's text (returned even when empty);
//   - "pointer_declarator":  the result of applying this function to it.
// The first child that produces a result wins. Returns "" when none does.
std::string getFunctionName(TSNode node, const char *source_code) {
  const uint32_t total = ts_node_child_count(node);

  for (uint32_t idx = 0; idx < total; ++idx) {
    const TSNode child = ts_node_child(node, idx);
    const char *kind = ts_node_type(child);

    if (strcmp(kind, "identifier") == 0) {
      return extractNodeText(child, source_code);
    }

    std::string candidate;
    if (strcmp(kind, "function_declarator") == 0) {
      candidate = findIdentifierInNode(child, source_code);
    } else if (strcmp(kind, "pointer_declarator") == 0) {
      candidate = getFunctionName(child, source_code);
    }

    if (!candidate.empty()) {
      return candidate;
    }
  }

  return "";
}
|
|
|
|
// Finds the declared name among the direct children of `node`.
//
// A child whose type is one of the declarator kinds listed below is searched
// recursively for an identifier; a direct "identifier" child is returned as
// is. Children are visited in order and the first hit wins. Returns "" when
// nothing is found.
std::string getGlobalName(TSNode node, const char *source_code) {
  static const char *const kDeclaratorKinds[] = {
      "init_declarator",     "declarator",
      "reference_declarator", "pointer_declarator",
      "parenthesized_declarator", "array_declarator",
  };

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t idx = 0; idx < total; ++idx) {
    const TSNode child = ts_node_child(node, idx);
    const char *kind = ts_node_type(child);

    bool is_declarator = false;
    for (const char *declarator_kind : kDeclaratorKinds) {
      if (strcmp(kind, declarator_kind) == 0) {
        is_declarator = true;
        break;
      }
    }

    if (is_declarator) {
      const std::string name = findIdentifierInDeclarator(child, source_code);
      if (!name.empty()) {
        return name;
      }
      // A declarator without an identifier: keep looking at later children.
      continue;
    }

    if (strcmp(kind, "identifier") == 0) {
      return extractNodeText(child, source_code);
    }
  }

  return "";
}
|
|
|
|
// Looks for a "comment" sibling near `node` whose text satisfies
// hasAddressPattern() and returns that comment's full text ("" if none).
//
// search_before == false: walks the next-siblings of `node`.
// search_before == true:  walks the previous-siblings of `node`, then of its
//                         parent, grandparent, and so on up to the root.
//
// In either direction a sibling walk skips comments without an address and
// nodes of type "ERROR", and stops at the first sibling of any other type.
// Note that in the search_before case the climb to the parent still happens
// after a walk was stopped that way.
//
// `source_length` is accepted but not used here.
std::string getComment(TSNode node, const char *source_code,
                       uint32_t source_length, bool search_before) {
  // Walks siblings starting at `start` in the requested direction and returns
  // the first address-bearing comment, or an empty string.
  const auto scanSiblings = [source_code](TSNode start,
                                          bool backwards) -> std::string {
    TSNode sibling = start;
    while (!ts_node_is_null(sibling)) {
      const char *kind = ts_node_type(sibling);

      if (strcmp(kind, "comment") == 0) {
        std::string text = extractNodeText(sibling, source_code);
        if (hasAddressPattern(text)) {
          return text;
        }
      } else if (strcmp(kind, "ERROR") != 0) {
        // Anything that is neither a comment nor an ERROR node ends the walk.
        break;
      }

      sibling = backwards ? ts_node_prev_sibling(sibling)
                          : ts_node_next_sibling(sibling);
    }
    return std::string();
  };

  if (!search_before) {
    return scanSiblings(ts_node_next_sibling(node), false);
  }

  for (TSNode anchor = node; !ts_node_is_null(anchor);
       anchor = ts_node_parent(anchor)) {
    std::string hit = scanSiblings(ts_node_prev_sibling(anchor), true);
    if (!hit.empty()) {
      return hit;
    }
  }

  return "";
}
|
|
|
|
// True only when `node` is a "function_definition" that has a direct child of
// type "compound_statement".
bool hasFunctionBody(TSNode node) {
  const bool is_definition =
      strcmp(ts_node_type(node), "function_definition") == 0;
  if (!is_definition) {
    return false;
  }

  bool body_found = false;
  const uint32_t total = ts_node_child_count(node);
  for (uint32_t idx = 0; idx < total && !body_found; ++idx) {
    const TSNode child = ts_node_child(node, idx);
    body_found = strcmp(ts_node_type(child), "compound_statement") == 0;
  }
  return body_found;
}
|
|
|
|
// Add this helper function to detect calling convention from a function declaration
|
|
// Determines the calling convention written on a function declaration.
//
// Direct children of `node` are visited in order:
//   - "ms_call_modifier" or "identifier" children are compared against the
//     keywords __fastcall / __stdcall / __cdecl (all-lowercase or
//     all-uppercase spelling only); a match is returned immediately, including
//     an explicit __cdecl.
//   - specifier and declarator children (see kNestedKinds) are searched
//     recursively; a nested result is used only if it is not Cdecl.
// Falls back to CallingConvention::Cdecl when nothing matches.
CallingConvention getCallingConvention(TSNode node, const char *source_code) {
  struct Keyword {
    const char *lower;
    const char *upper;
    CallingConvention convention;
  };
  static const Keyword kKeywords[] = {
      {"__fastcall", "__FASTCALL", CallingConvention::Fastcall},
      {"__stdcall", "__STDCALL", CallingConvention::Stdcall},
      {"__cdecl", "__CDECL", CallingConvention::Cdecl},
  };
  static const char *const kNestedKinds[] = {
      "type_specifier",      "declaration_specifier",
      "specifier_qualifier_list", "function_declarator",
      "pointer_declarator",
  };

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t idx = 0; idx < total; ++idx) {
    const TSNode child = ts_node_child(node, idx);
    const char *kind = ts_node_type(child);

    // Leaf tokens that may themselves be the calling-convention keyword.
    if (strcmp(kind, "ms_call_modifier") == 0 ||
        strcmp(kind, "identifier") == 0) {
      const std::string token = extractNodeText(child, source_code);
      for (const Keyword &keyword : kKeywords) {
        if (token == keyword.lower || token == keyword.upper) {
          return keyword.convention;
        }
      }
      continue;
    }

    // Containers in which the keyword may be nested.
    bool is_nested_kind = false;
    for (const char *nested_kind : kNestedKinds) {
      if (strcmp(kind, nested_kind) == 0) {
        is_nested_kind = true;
        break;
      }
    }
    if (is_nested_kind) {
      const CallingConvention nested = getCallingConvention(child, source_code);
      if (nested != CallingConvention::Cdecl) {
        return nested;
      }
    }
  }

  return CallingConvention::Cdecl;  // Default
}
|
|
|
|
// Walks the syntax tree rooted at `node` and appends a FunctionInfo to
// `functions` for every "function_definition" or "declaration" node that has
// both a resolvable name (getFunctionName) and an address comment.
//
// The address is taken from a comment following the node; for definitions
// only, a comment preceding the node is tried as a fallback. A node with a
// name but no address is skipped silently. Once a name has been resolved the
// node's children are not visited. When no name can be resolved an error is
// logged and the walk continues into the children.
//
// `file_type` is copied into each record; `filepath` is left empty.
void findFunctions(TSNode node, const char *source_code, uint32_t source_length,
                   std::vector<FunctionInfo> &functions, FileType file_type) {
  const char *kind = ts_node_type(node);
  const bool is_definition = strcmp(kind, "function_definition") == 0;
  const bool is_candidate = is_definition || strcmp(kind, "declaration") == 0;

  if (is_candidate) {
    const std::string func_name = getFunctionName(node, source_code);

    if (func_name.empty()) {
      spdlog::error("Failed to get function name for {}",
                    extractNodeText(node, source_code));
      // Fall through: the children are still searched below.
    } else {
      // Trailing comment first, leading comment as fallback for definitions.
      std::string address =
          extractAddress(getComment(node, source_code, source_length, false));
      if (address.empty() && is_definition) {
        address =
            extractAddress(getComment(node, source_code, source_length, true));
      }

      if (!address.empty()) {
        const CallingConvention calling_conv =
            getCallingConvention(node, source_code);
        const std::string parameter_names = getParameterNames(node, source_code);
        const std::string parameter_types = getParameterTypes(node, source_code);
        const std::string return_type = getReturnType(node, source_code);

        FunctionInfo info;
        info.name = func_name;
        info.address = address;
        info.filepath = "";
        info.parameter_names = parameter_names;
        info.parameter_types = parameter_types;
        info.return_type = return_type;
        // Plain declarations are always imports; definitions only when they
        // lack a body.
        info.is_import = !is_definition || !hasFunctionBody(node);
        info.type = file_type;
        info.calling_convention = calling_conv;

        spdlog::debug("Found function: {} at {} with calling convention: {}, return type: {}, param names: {}, param types: {}",
                      func_name, address, callingConventionToString(calling_conv), return_type, parameter_names, parameter_types);

        functions.push_back(info);
      }

      // We'll never nest function declarations
      return;
    }
  }

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t idx = 0; idx < total; ++idx) {
    findFunctions(ts_node_child(node, idx), source_code, source_length,
                  functions, file_type);
  }
}
|
|
|
|
// Walks the syntax tree rooted at `node` and appends a GlobalInfo to `globals`
// for every "declaration" node that carries an `extern` storage-class
// specifier, has a resolvable name (getGlobalName) and is followed by an
// address comment.
//
// An extern declaration is never descended into, whether or not it produced a
// record. If its name cannot be resolved an error is logged. All other nodes
// are searched recursively.
void findGlobals(TSNode node, const char *source_code, uint32_t source_length,
                 std::vector<GlobalInfo> &globals) {
  const uint32_t total = ts_node_child_count(node);

  if (strcmp(ts_node_type(node), "declaration") == 0) {
    // Look for a direct `extern` storage-class specifier.
    bool is_extern = false;
    for (uint32_t idx = 0; idx < total && !is_extern; ++idx) {
      const TSNode child = ts_node_child(node, idx);
      if (strcmp(ts_node_type(child), "storage_class_specifier") != 0) {
        continue;
      }
      is_extern = extractNodeText(child, source_code) == "extern";
    }

    if (is_extern) {
      const std::string global_name = getGlobalName(node, source_code);
      if (global_name.empty()) {
        const std::string src = extractNodeText(node, source_code);
        spdlog::error("Failed to get global name for {}", src);
        return;
      }

      // Only a comment after the declaration is considered.
      const std::string address =
          extractAddress(getComment(node, source_code, source_length, false));
      if (!address.empty()) {
        GlobalInfo record{global_name, address, ""};
        globals.push_back(record);
      }
      return;
    }
  }

  for (uint32_t idx = 0; idx < total; ++idx) {
    findGlobals(ts_node_child(node, idx), source_code, source_length, globals);
  }
}
|
|
|
|
// Add helper functions to extract parameter information and return type
|
|
// Returns the parameter names of the function declared by `node`, joined with
// ';' (e.g. "a;b"). Returns "" when no function declarator or no named
// parameter is found.
//
// The "function_declarator" is looked up among the direct children of `node`;
// if there is none, the search descends through "pointer_declarator" children,
// because a function returning a pointer (`char *f(int a)`) has its
// function_declarator nested below one pointer_declarator per `*`. This
// mirrors the pointer_declarator handling in getFunctionName.
//
// Only the first "parameter_list" of the declarator is read. Parameters
// without an identifier (e.g. `void` or an unnamed `int`) contribute no entry,
// so the result can have fewer entries than getParameterTypes() yields.
std::string getParameterNames(TSNode node, const char *source_code) {
  // Step 1: locate the function_declarator, unwrapping pointer declarators.
  TSNode scope = node;
  TSNode declarator = node;
  bool found_declarator = false;
  for (;;) {
    TSNode pointer_child = scope;
    bool has_pointer_child = false;

    const uint32_t scope_count = ts_node_child_count(scope);
    for (uint32_t i = 0; i < scope_count && !found_declarator; ++i) {
      const TSNode child = ts_node_child(scope, i);
      const char *kind = ts_node_type(child);
      if (strcmp(kind, "function_declarator") == 0) {
        declarator = child;
        found_declarator = true;
      } else if (!has_pointer_child &&
                 strcmp(kind, "pointer_declarator") == 0) {
        // Remember the first pointer wrapper in case no direct
        // function_declarator exists at this level.
        pointer_child = child;
        has_pointer_child = true;
      }
    }

    if (found_declarator || !has_pointer_child) {
      break;
    }
    scope = pointer_child;
  }

  // Step 2: collect the identifier of every parameter_declaration.
  std::vector<std::string> names;
  if (found_declarator) {
    const uint32_t declarator_count = ts_node_child_count(declarator);
    for (uint32_t j = 0; j < declarator_count; ++j) {
      const TSNode list = ts_node_child(declarator, j);
      if (strcmp(ts_node_type(list), "parameter_list") != 0) {
        continue;
      }

      const uint32_t list_count = ts_node_child_count(list);
      for (uint32_t k = 0; k < list_count; ++k) {
        const TSNode param_node = ts_node_child(list, k);
        if (strcmp(ts_node_type(param_node), "parameter_declaration") != 0) {
          continue;
        }
        std::string param_name =
            findIdentifierInDeclarator(param_node, source_code);
        if (!param_name.empty()) {
          names.push_back(param_name);
        }
      }
      break;  // only the first parameter_list is relevant
    }
  }

  // Step 3: join names with semicolon.
  std::string result;
  for (size_t i = 0; i < names.size(); ++i) {
    if (i > 0) {
      result += ";";
    }
    result += names[i];
  }
  return result;
}
|
|
|
|
// Returns the parameter types of the function declared by `node`, joined with
// ';' (e.g. "int;char *"). Returns "" when no function declarator is found or
// the parameter list has no parameter_declaration entries.
//
// The "function_declarator" is looked up among the direct children of `node`;
// if there is none, the search descends through "pointer_declarator" children,
// because a function returning a pointer (`char *f(int a)`) has its
// function_declarator nested below one pointer_declarator per `*`. This
// mirrors the pointer_declarator handling in getFunctionName.
//
// A parameter's type is its source text with the parameter name removed and
// surrounding whitespace trimmed, so pointer/reference symbols and any
// declarator suffix are preserved: `char *p` -> "char *", `const T &r` ->
// "const T &", `int a[4]` -> "int [4]". A parameter without a name
// contributes its whole text.
//
// The name is located with a textual rfind() of the identifier returned by
// findIdentifierInDeclarator(); this is a heuristic and can pick the wrong
// occurrence when the same character sequence appears later in the parameter
// text (for instance inside a function-pointer parameter's own parameters).
std::string getParameterTypes(TSNode node, const char *source_code) {
  // Step 1: locate the function_declarator, unwrapping pointer declarators.
  TSNode scope = node;
  TSNode declarator = node;
  bool found_declarator = false;
  for (;;) {
    TSNode pointer_child = scope;
    bool has_pointer_child = false;

    const uint32_t scope_count = ts_node_child_count(scope);
    for (uint32_t i = 0; i < scope_count && !found_declarator; ++i) {
      const TSNode child = ts_node_child(scope, i);
      const char *kind = ts_node_type(child);
      if (strcmp(kind, "function_declarator") == 0) {
        declarator = child;
        found_declarator = true;
      } else if (!has_pointer_child &&
                 strcmp(kind, "pointer_declarator") == 0) {
        // Remember the first pointer wrapper in case no direct
        // function_declarator exists at this level.
        pointer_child = child;
        has_pointer_child = true;
      }
    }

    if (found_declarator || !has_pointer_child) {
      break;
    }
    scope = pointer_child;
  }

  // Step 2: derive a type string for every parameter_declaration.
  std::vector<std::string> types;
  if (found_declarator) {
    const uint32_t declarator_count = ts_node_child_count(declarator);
    for (uint32_t j = 0; j < declarator_count; ++j) {
      const TSNode list = ts_node_child(declarator, j);
      if (strcmp(ts_node_type(list), "parameter_list") != 0) {
        continue;
      }

      const uint32_t list_count = ts_node_child_count(list);
      for (uint32_t k = 0; k < list_count; ++k) {
        const TSNode param_node = ts_node_child(list, k);
        if (strcmp(ts_node_type(param_node), "parameter_declaration") != 0) {
          continue;
        }

        const std::string full_param = extractNodeText(param_node, source_code);
        const std::string param_name =
            findIdentifierInDeclarator(param_node, source_code);

        // Remove only the name; everything before it (base type, qualifiers,
        // '*' / '&') and after it (array bounds, function-pointer tail) is
        // part of the type.
        std::string param_type = full_param;
        if (!param_name.empty()) {
          const size_t name_pos = full_param.rfind(param_name);
          if (name_pos != std::string::npos) {
            param_type = full_param.substr(0, name_pos);
            param_type += full_param.substr(name_pos + param_name.length());
          }
        }

        // Trim surrounding whitespace. find_last_not_of() returns npos for an
        // all-whitespace string, and npos + 1 wraps to 0, which clears it.
        param_type.erase(0, param_type.find_first_not_of(" \t\n\r"));
        param_type.erase(param_type.find_last_not_of(" \t\n\r") + 1);

        if (!param_type.empty()) {
          types.push_back(param_type);
        }
      }
      break;  // only the first parameter_list is relevant
    }
  }

  // Step 3: join types with semicolon.
  std::string result;
  for (size_t i = 0; i < types.size(); ++i) {
    if (i > 0) {
      result += ";";
    }
    result += types[i];
  }
  return result;
}
|
|
|
|
// Returns the text of the first direct child of `node` whose type is one of
// the type-specifier kinds listed below. When no such child exists the
// function reports "void".
//
// Only the matching child's own text is returned; other children of `node`
// are not consulted.
std::string getReturnType(TSNode node, const char *source_code) {
  static const char *const kTypeKinds[] = {
      "type_identifier",      "primitive_type", "sized_type_specifier",
      "qualified_identifier", "template_type",
  };

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t idx = 0; idx < total; ++idx) {
    const TSNode child = ts_node_child(node, idx);
    const char *kind = ts_node_type(child);

    for (const char *type_kind : kTypeKinds) {
      if (strcmp(kind, type_kind) == 0) {
        return extractNodeText(child, source_code);
      }
    }
  }

  return "void";  // Default return type
}
|