reman3/tooling/parser.cpp

502 lines
17 KiB
C++

#include "tool.hpp"
#include <cstring>
#include <stdexcept>
#include <spdlog/spdlog.h>
// Global address regex pattern: matches a `//` marker, optional whitespace,
// and then eight hexadecimal digits. The eight digits are capture group 1.
const std::regex ADDRESS_REGEX(R"(//\s*([0-9a-fA-F]{8}))");
// Maps the textual file type ("auto", "fix", "stub", "ref") to its FileType
// enumerator. Throws std::invalid_argument for any other string.
FileType stringToFileType(const std::string &type_str) {
  struct Entry {
    const char *label;
    FileType value;
  };
  static const Entry kKnownTypes[] = {
      {"auto", FileType::Auto},
      {"fix", FileType::Fix},
      {"stub", FileType::Stub},
      {"ref", FileType::Ref},
  };

  for (const Entry &entry : kKnownTypes) {
    if (type_str == entry.label) {
      return entry.value;
    }
  }
  throw std::invalid_argument("Invalid file type: " + type_str);
}
// Maps a FileType enumerator to its textual form ("auto", "fix", "stub",
// "ref"). Throws std::logic_error, quoting the numeric value, for any
// enumerator not listed here.
std::string fileTypeToString(FileType type) {
  const char *label = nullptr;
  switch (type) {
  case FileType::Auto:
    label = "auto";
    break;
  case FileType::Fix:
    label = "fix";
    break;
  case FileType::Stub:
    label = "stub";
    break;
  case FileType::Ref:
    label = "ref";
    break;
  default:
    break;
  }

  if (label == nullptr) {
    throw std::logic_error("Invalid file type: " +
                           std::to_string(static_cast<int>(type)));
  }
  return label;
}
// Reports whether ADDRESS_REGEX matches anywhere inside `comment`.
bool hasAddressPattern(const std::string &comment) {
  std::smatch ignored;
  const bool matched = std::regex_search(comment, ignored, ADDRESS_REGEX);
  return matched;
}
// Returns the eight hex digits captured by the first ADDRESS_REGEX match in
// `comment`, or an empty string when the pattern does not occur.
std::string extractAddress(const std::string &comment) {
  std::smatch captures;
  if (!std::regex_search(comment, captures, ADDRESS_REGEX)) {
    return std::string();
  }
  return captures.str(1);
}
// Copies the bytes of `source_code` spanned by `node` (from its start byte up
// to, but not including, its end byte) into a new string.
std::string extractNodeText(TSNode node, const char *source_code) {
  const uint32_t first_byte = ts_node_start_byte(node);
  const uint32_t byte_count = ts_node_end_byte(node) - first_byte;
  const char *text_begin = source_code + first_byte;
  return std::string(text_begin, byte_count);
}
// Returns the text of the first *direct* child of `node` whose type is
// "identifier"; grandchildren are not inspected. Yields an empty string when
// no such child exists.
std::string findIdentifierInNode(TSNode node, const char *source_code) {
  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total; ++index) {
    const TSNode candidate = ts_node_child(node, index);
    if (strcmp(ts_node_type(candidate), "identifier") != 0) {
      continue;
    }
    return extractNodeText(candidate, source_code);
  }
  return std::string();
}
// Depth-first search for an "identifier" node at or below `node`. Returns the
// text of the first one (in child order) whose text is non-empty, or an empty
// string when the subtree contains none.
std::string findIdentifierInDeclarator(TSNode node, const char *source_code) {
  // The node itself may already be the identifier we are after.
  if (strcmp(ts_node_type(node), "identifier") == 0) {
    return extractNodeText(node, source_code);
  }

  // Otherwise try each subtree in order and stop at the first hit.
  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total; ++index) {
    std::string nested =
        findIdentifierInDeclarator(ts_node_child(node, index), source_code);
    if (nested.empty()) {
      continue;
    }
    return nested;
  }
  return std::string();
}
// Derives the declared name from the direct children of `node`:
//  - an "identifier" child is returned as-is;
//  - a "function_declarator" child contributes its first direct identifier;
//  - a "pointer_declarator" child is searched recursively with this function.
// Children are examined in order; an empty string means nothing was found.
std::string getFunctionName(TSNode node, const char *source_code) {
  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total; ++index) {
    const TSNode child = ts_node_child(node, index);
    const char *kind = ts_node_type(child);

    if (strcmp(kind, "identifier") == 0) {
      return extractNodeText(child, source_code);
    }

    std::string nested;
    if (strcmp(kind, "function_declarator") == 0) {
      nested = findIdentifierInNode(child, source_code);
    } else if (strcmp(kind, "pointer_declarator") == 0) {
      nested = getFunctionName(child, source_code);
    }
    if (!nested.empty()) {
      return nested;
    }
  }
  return std::string();
}
// Derives the variable name from the direct children of `node`. A declarator
// child (any of the kinds listed below) is searched recursively for its first
// identifier; a direct "identifier" child is returned as-is. Children are
// examined in order; an empty string means no name was found.
std::string getGlobalName(TSNode node, const char *source_code) {
  static const char *const kDeclaratorKinds[] = {
      "init_declarator",          "declarator",
      "reference_declarator",     "pointer_declarator",
      "parenthesized_declarator", "array_declarator",
  };

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total; ++index) {
    const TSNode child = ts_node_child(node, index);
    const char *kind = ts_node_type(child);

    bool is_declarator = false;
    for (const char *declarator_kind : kDeclaratorKinds) {
      if (strcmp(kind, declarator_kind) == 0) {
        is_declarator = true;
        break;
      }
    }

    if (is_declarator) {
      std::string nested = findIdentifierInDeclarator(child, source_code);
      if (!nested.empty()) {
        return nested;
      }
    } else if (strcmp(kind, "identifier") == 0) {
      return extractNodeText(child, source_code);
    }
  }
  return std::string();
}
// Finds the nearest comment carrying an address pattern next to `node`.
//
// search_before == false: scans the siblings that follow `node`.
// search_before == true:  scans the siblings that precede `node`, then repeats
//                         the backwards scan for each ancestor in turn.
//
// A scan skips comments without an address and "ERROR" nodes, and stops at
// the first sibling of any other type. Returns the full comment text, or an
// empty string when nothing qualifies. `source_length` is not used.
std::string getComment(TSNode node, const char *source_code,
                       uint32_t source_length, bool search_before) {
  // Scans one sibling chain starting next to `origin`, in the given direction.
  const auto scanSiblings = [source_code](TSNode origin,
                                          bool backwards) -> std::string {
    const auto step = [backwards](TSNode from) {
      return backwards ? ts_node_prev_sibling(from)
                       : ts_node_next_sibling(from);
    };
    for (TSNode sibling = step(origin); !ts_node_is_null(sibling);
         sibling = step(sibling)) {
      const char *kind = ts_node_type(sibling);
      if (strcmp(kind, "comment") == 0) {
        std::string text = extractNodeText(sibling, source_code);
        if (hasAddressPattern(text)) {
          return text;
        }
      } else if (strcmp(kind, "ERROR") != 0) {
        // Real content sits between the node and any further comment.
        break;
      }
    }
    return std::string();
  };

  if (!search_before) {
    return scanSiblings(node, false);
  }

  // A matching comment is never empty, so "empty" reliably means "keep going".
  for (TSNode scope = node; !ts_node_is_null(scope);
       scope = ts_node_parent(scope)) {
    std::string hit = scanSiblings(scope, true);
    if (!hit.empty()) {
      return hit;
    }
  }
  return std::string();
}
// True only when `node` is a "function_definition" that has a
// "compound_statement" among its direct children.
bool hasFunctionBody(TSNode node) {
  if (strcmp(ts_node_type(node), "function_definition") != 0) {
    return false;
  }

  bool body_found = false;
  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total && !body_found; ++index) {
    const TSNode child = ts_node_child(node, index);
    body_found = strcmp(ts_node_type(child), "compound_statement") == 0;
  }
  return body_found;
}
// Detects the calling convention written on a function declaration.
//
// Direct children of `node` are examined in order:
//  - "ms_call_modifier" and "identifier" children are compared against the
//    __fastcall / __stdcall / __cdecl keywords (all-lowercase or all-uppercase
//    spelling); the first match is returned;
//  - specifier and declarator children are searched recursively, and a nested
//    result is returned only when it differs from Cdecl.
// Cdecl is the result when nothing else is found.
CallingConvention getCallingConvention(TSNode node, const char *source_code) {
  // Translates a keyword; leaves `out` untouched and returns false if unknown.
  const auto parseKeyword = [](const std::string &word,
                               CallingConvention &out) -> bool {
    if (word == "__fastcall" || word == "__FASTCALL") {
      out = CallingConvention::Fastcall;
      return true;
    }
    if (word == "__stdcall" || word == "__STDCALL") {
      out = CallingConvention::Stdcall;
      return true;
    }
    if (word == "__cdecl" || word == "__CDECL") {
      out = CallingConvention::Cdecl;
      return true;
    }
    return false;
  };

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total; ++index) {
    const TSNode child = ts_node_child(node, index);
    const char *kind = ts_node_type(child);

    const bool may_be_keyword = strcmp(kind, "ms_call_modifier") == 0 ||
                                strcmp(kind, "identifier") == 0;
    if (may_be_keyword) {
      CallingConvention parsed = CallingConvention::Cdecl;
      if (parseKeyword(extractNodeText(child, source_code), parsed)) {
        return parsed;
      }
      continue;
    }

    // The convention can also sit inside specifiers or nested declarators.
    const bool may_nest = strcmp(kind, "type_specifier") == 0 ||
                          strcmp(kind, "declaration_specifier") == 0 ||
                          strcmp(kind, "specifier_qualifier_list") == 0 ||
                          strcmp(kind, "function_declarator") == 0 ||
                          strcmp(kind, "pointer_declarator") == 0;
    if (may_nest) {
      const CallingConvention nested = getCallingConvention(child, source_code);
      if (nested != CallingConvention::Cdecl) {
        return nested;
      }
    }
  }
  return CallingConvention::Cdecl; // Default
}
// Walks the syntax tree rooted at `node` and appends a FunctionInfo to
// `functions` for every "function_definition" or "declaration" node that has
// both a name and an address comment.
//
// The address is taken from a comment following the node; for definitions
// only, a comment preceding the node is used as a fallback. A named node is
// never descended into. A node whose name cannot be determined is logged as
// an error and its children are still visited.
void findFunctions(TSNode node, const char *source_code, uint32_t source_length,
                   std::vector<FunctionInfo> &functions, FileType file_type) {
  const char *node_type = ts_node_type(node);
  const bool is_definition = strcmp(node_type, "function_definition") == 0;
  const bool is_candidate =
      is_definition || strcmp(node_type, "declaration") == 0;

  if (is_candidate) {
    const std::string func_name = getFunctionName(node, source_code);
    if (func_name.empty()) {
      spdlog::error("Failed to get function name for {}",
                    extractNodeText(node, source_code));
      // Fall through to the child traversal below.
    } else {
      std::string address =
          extractAddress(getComment(node, source_code, source_length, false));
      if (address.empty() && is_definition) {
        address =
            extractAddress(getComment(node, source_code, source_length, true));
      }

      if (!address.empty()) {
        const CallingConvention calling_conv =
            getCallingConvention(node, source_code);
        const std::string parameter_names =
            getParameterNames(node, source_code);
        const std::string parameter_types =
            getParameterTypes(node, source_code);
        const std::string return_type = getReturnType(node, source_code);

        FunctionInfo func;
        func.name = func_name;
        func.address = address;
        func.filepath = "";
        func.parameter_names = parameter_names;
        func.parameter_types = parameter_types;
        func.return_type = return_type;
        // Only a definition that actually has a body is not an import.
        func.is_import = !(is_definition && hasFunctionBody(node));
        func.type = file_type;
        func.calling_convention = calling_conv;

        spdlog::debug("Found function: {} at {} with calling convention: {}, return type: {}, param names: {}, param types: {}",
                      func_name, address,
                      callingConventionToString(calling_conv), return_type,
                      parameter_names, parameter_types);
        functions.push_back(func);
      }
      // Function declarations are never nested, so stop here.
      return;
    }
  }

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total; ++index) {
    findFunctions(ts_node_child(node, index), source_code, source_length,
                  functions, file_type);
  }
}
// Walks the syntax tree rooted at `node` and appends a GlobalInfo (name,
// address, empty third field) to `globals` for every `extern` declaration that
// is followed by an address comment.
//
// An extern declaration is never descended into, whether or not it produced
// an entry. One whose name cannot be determined is logged as an error.
void findGlobals(TSNode node, const char *source_code, uint32_t source_length,
                 std::vector<GlobalInfo> &globals) {
  const uint32_t total = ts_node_child_count(node);

  if (strcmp(ts_node_type(node), "declaration") == 0) {
    // The declaration is extern if any storage class specifier says so.
    bool is_extern = false;
    for (uint32_t index = 0; index < total && !is_extern; ++index) {
      const TSNode child = ts_node_child(node, index);
      is_extern =
          strcmp(ts_node_type(child), "storage_class_specifier") == 0 &&
          extractNodeText(child, source_code) == "extern";
    }

    if (is_extern) {
      const std::string global_name = getGlobalName(node, source_code);
      if (global_name.empty()) {
        const std::string declaration_text = extractNodeText(node, source_code);
        spdlog::error("Failed to get global name for {}", declaration_text);
      } else {
        // Only a comment after the declaration is considered.
        const std::string address =
            extractAddress(getComment(node, source_code, source_length, false));
        if (!address.empty()) {
          GlobalInfo entry{global_name, address, ""};
          globals.push_back(entry);
        }
      }
      return;
    }
  }

  // Anything else: keep looking in the children.
  for (uint32_t index = 0; index < total; ++index) {
    findGlobals(ts_node_child(node, index), source_code, source_length,
                globals);
  }
}
// Add helper functions to extract parameter information and return type
std::string getParameterNames(TSNode node, const char *source_code) {
std::vector<std::string> names;
uint32_t child_count = ts_node_child_count(node);
for (uint32_t i = 0; i < child_count; i++) {
TSNode child = ts_node_child(node, i);
const char *type = ts_node_type(child);
if (strcmp(type, "function_declarator") == 0) {
// Look for parameter_list within function_declarator
uint32_t func_child_count = ts_node_child_count(child);
for (uint32_t j = 0; j < func_child_count; j++) {
TSNode func_child = ts_node_child(child, j);
if (strcmp(ts_node_type(func_child), "parameter_list") == 0) {
// Extract each parameter_declaration
uint32_t param_list_count = ts_node_child_count(func_child);
for (uint32_t k = 0; k < param_list_count; k++) {
TSNode param_node = ts_node_child(func_child, k);
if (strcmp(ts_node_type(param_node), "parameter_declaration") == 0) {
// Find the identifier (parameter name) in this parameter
std::string param_name = findIdentifierInDeclarator(param_node, source_code);
if (!param_name.empty()) {
names.push_back(param_name);
}
}
}
break;
}
}
break;
}
}
// Join names with semicolon
std::string result;
for (size_t i = 0; i < names.size(); i++) {
if (i > 0) result += ";";
result += names[i];
}
return result;
}
std::string getParameterTypes(TSNode node, const char *source_code) {
std::vector<std::string> types;
uint32_t child_count = ts_node_child_count(node);
for (uint32_t i = 0; i < child_count; i++) {
TSNode child = ts_node_child(node, i);
const char *type = ts_node_type(child);
if (strcmp(type, "function_declarator") == 0) {
// Look for parameter_list within function_declarator
uint32_t func_child_count = ts_node_child_count(child);
for (uint32_t j = 0; j < func_child_count; j++) {
TSNode func_child = ts_node_child(child, j);
if (strcmp(ts_node_type(func_child), "parameter_list") == 0) {
// Extract each parameter_declaration
uint32_t param_list_count = ts_node_child_count(func_child);
for (uint32_t k = 0; k < param_list_count; k++) {
TSNode param_node = ts_node_child(func_child, k);
if (strcmp(ts_node_type(param_node), "parameter_declaration") == 0) {
// Extract type part by getting everything except the identifier
std::string full_param = extractNodeText(param_node, source_code);
std::string param_name = findIdentifierInDeclarator(param_node, source_code);
// Remove the parameter name from the full parameter text to get just the type
std::string param_type = full_param;
if (!param_name.empty()) {
size_t name_pos = param_type.rfind(param_name);
if (name_pos != std::string::npos) {
param_type = param_type.substr(0, name_pos);
// Clean up trailing whitespace and pointer/reference symbols after the type
param_type.erase(param_type.find_last_not_of(" \t\n\r*&") + 1);
// Add back any pointer/reference symbols that were part of the type
size_t type_end = name_pos;
while (type_end < full_param.length() &&
(full_param[type_end] == ' ' || full_param[type_end] == '\t')) {
type_end++;
}
if (type_end > name_pos + param_name.length()) {
param_type += full_param.substr(name_pos + param_name.length(),
type_end - (name_pos + param_name.length()));
}
}
}
// Clean up any extra whitespace
param_type.erase(0, param_type.find_first_not_of(" \t\n\r"));
param_type.erase(param_type.find_last_not_of(" \t\n\r") + 1);
if (!param_type.empty()) {
types.push_back(param_type);
}
}
}
break;
}
}
break;
}
}
// Join types with semicolon
std::string result;
for (size_t i = 0; i < types.size(); i++) {
if (i > 0) result += ";";
result += types[i];
}
return result;
}
// Returns the text of the first direct child of `node` whose type is one of
// the type-specifier kinds listed below. Falls back to "void" when no such
// child exists.
std::string getReturnType(TSNode node, const char *source_code) {
  static const char *const kTypeKinds[] = {
      "type_identifier",      "primitive_type", "sized_type_specifier",
      "qualified_identifier", "template_type",
  };

  const uint32_t total = ts_node_child_count(node);
  for (uint32_t index = 0; index < total; ++index) {
    const TSNode child = ts_node_child(node, index);
    const char *kind = ts_node_type(child);
    for (const char *type_kind : kTypeKinds) {
      if (strcmp(kind, type_kind) == 0) {
        return extractNodeText(child, source_code);
      }
    }
  }
  return "void"; // Default return type
}