From cfb41094d5216ababba8363e9fb8787d86b1e6bd Mon Sep 17 00:00:00 2001 From: Guus Waals <_@guusw.nl> Date: Sun, 1 Jun 2025 23:18:47 +0800 Subject: [PATCH] Add parameter names and types --- tooling/database.cpp | 16 +++-- tooling/parser.cpp | 161 ++++++++++++++++++++++++++++++++++++++++--- tooling/tool.hpp | 6 ++ 3 files changed, 170 insertions(+), 13 deletions(-) diff --git a/tooling/database.cpp b/tooling/database.cpp index 597471ab..ac51ea4c 100644 --- a/tooling/database.cpp +++ b/tooling/database.cpp @@ -27,11 +27,11 @@ public: &delete_imports_stmt, "Failed to prepare delete imports statement"); prepareStatement("INSERT OR REPLACE INTO Functions (filepath, name, " - "address, type, calling_convention) VALUES (?, ?, ?, ?, ?)", + "address, type, calling_convention, parameter_names, parameter_types, return_type) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", &insert_functions_stmt, "Failed to prepare insert functions statement"); prepareStatement("INSERT OR REPLACE INTO Imports (filepath, name, address, " - "type, calling_convention) VALUES (?, ?, ?, ?, ?)", + "type, calling_convention, parameter_names, parameter_types, return_type) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", &insert_imports_stmt, "Failed to prepare insert imports statement"); prepareStatement("DELETE FROM Globals WHERE filepath = ?", @@ -80,8 +80,8 @@ DatabaseManager::DatabaseManager(const std::string &db_path) : db(nullptr) { } const char *create_tables = R"( - CREATE TABLE IF NOT EXISTS Functions (filepath TEXT, name TEXT, address TEXT, type INTEGER DEFAULT 0, calling_convention INTEGER DEFAULT 0, PRIMARY KEY (name, filepath)); - CREATE TABLE IF NOT EXISTS Imports (filepath TEXT, name TEXT, address TEXT, type INTEGER DEFAULT 0, calling_convention INTEGER DEFAULT 0, PRIMARY KEY (name, filepath)); + CREATE TABLE IF NOT EXISTS Functions (filepath TEXT, name TEXT, address TEXT, type INTEGER DEFAULT 0, calling_convention INTEGER DEFAULT 0, parameter_names TEXT DEFAULT '', parameter_types TEXT DEFAULT '', return_type TEXT DEFAULT '', PRIMARY KEY (name, filepath)); + CREATE TABLE IF NOT EXISTS Imports (filepath TEXT, name TEXT, address TEXT, type INTEGER DEFAULT 0, calling_convention INTEGER DEFAULT 0, parameter_names TEXT DEFAULT '', parameter_types TEXT DEFAULT '', return_type TEXT DEFAULT '', PRIMARY KEY (name, filepath)); CREATE TABLE IF NOT EXISTS Globals (filepath TEXT, name TEXT, address TEXT); )"; @@ -119,6 +119,9 @@ void DatabaseManager::insertFunction(const FunctionInfo &func) { sqlite3_bind_text(stmt, 3, func.address.c_str(), -1, SQLITE_STATIC); sqlite3_bind_int(stmt, 4, static_cast(func.type)); sqlite3_bind_int(stmt, 5, static_cast(func.calling_convention)); + sqlite3_bind_text(stmt, 6, func.parameter_names.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 7, func.parameter_types.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 8, func.return_type.c_str(), -1, SQLITE_STATIC); sqlite3_step(stmt); } @@ -241,7 +244,7 @@ std::vector DatabaseManager::getFunctionsByType(FileType type) { std::vector functions; const char *sql = R"( - SELECT name, address, filepath, calling_convention + SELECT name, address, filepath, calling_convention, parameter_names, parameter_types, return_type FROM Functions WHERE type = ? AND address != '' ORDER BY address; @@ -261,6 +264,9 @@ std::vector DatabaseManager::getFunctionsByType(FileType type) { func.address = (const char *)sqlite3_column_text(stmt, 1); func.filepath = (const char *)sqlite3_column_text(stmt, 2); func.calling_convention = static_cast(sqlite3_column_int(stmt, 3)); + func.parameter_names = (const char *)sqlite3_column_text(stmt, 4); + func.parameter_types = (const char *)sqlite3_column_text(stmt, 5); + func.return_type = (const char *)sqlite3_column_text(stmt, 6); func.type = type; func.is_import = false; // Functions table contains non-imports diff --git a/tooling/parser.cpp b/tooling/parser.cpp index a844d05e..3b5e134f 100644 --- a/tooling/parser.cpp +++ b/tooling/parser.cpp @@ -281,15 +281,24 @@ void findFunctions(TSNode node, const char *source_code, uint32_t source_length, // Detect calling convention CallingConvention calling_conv = getCallingConvention(node, source_code); - FunctionInfo func{func_name, address, "", - strcmp(type, "function_definition") == 0 - ? !hasFunctionBody(node) - : true, - file_type}; // Add file_type parameter - func.calling_convention = calling_conv; // Set the calling convention + // Extract parameter information and return type + std::string parameter_names = getParameterNames(node, source_code); + std::string parameter_types = getParameterTypes(node, source_code); + std::string return_type = getReturnType(node, source_code); - spdlog::debug("Found function: {} at {} with calling convention: {}", - func_name, address, callingConventionToString(calling_conv)); + FunctionInfo func; + func.name = func_name; + func.address = address; + func.filepath = ""; + func.parameter_names = parameter_names; + func.parameter_types = parameter_types; + func.return_type = return_type; + func.is_import = strcmp(type, "function_definition") == 0 ? !hasFunctionBody(node) : true; + func.type = file_type; + func.calling_convention = calling_conv; + + spdlog::debug("Found function: {} at {} with calling convention: {}, return type: {}, param names: {}, param types: {}", + func_name, address, callingConventionToString(calling_conv), return_type, parameter_names, parameter_types); functions.push_back(func); } @@ -354,3 +363,139 @@ void findGlobals(TSNode node, const char *source_code, uint32_t source_length, findGlobals(ts_node_child(node, i), source_code, source_length, globals); } } + +// Add helper functions to extract parameter information and return type +std::string getParameterNames(TSNode node, const char *source_code) { + std::vector names; + uint32_t child_count = ts_node_child_count(node); + + for (uint32_t i = 0; i < child_count; i++) { + TSNode child = ts_node_child(node, i); + const char *type = ts_node_type(child); + + if (strcmp(type, "function_declarator") == 0) { + // Look for parameter_list within function_declarator + uint32_t func_child_count = ts_node_child_count(child); + for (uint32_t j = 0; j < func_child_count; j++) { + TSNode func_child = ts_node_child(child, j); + if (strcmp(ts_node_type(func_child), "parameter_list") == 0) { + // Extract each parameter_declaration + uint32_t param_list_count = ts_node_child_count(func_child); + for (uint32_t k = 0; k < param_list_count; k++) { + TSNode param_node = ts_node_child(func_child, k); + if (strcmp(ts_node_type(param_node), "parameter_declaration") == 0) { + // Find the identifier (parameter name) in this parameter + std::string param_name = findIdentifierInDeclarator(param_node, source_code); + if (!param_name.empty()) { + names.push_back(param_name); + } + } + } + break; + } + } + break; + } + } + + // Join names with semicolon + std::string result; + for (size_t i = 0; i < names.size(); i++) { + if (i > 0) result += ";"; + result += names[i]; + } + return result; +} + +std::string getParameterTypes(TSNode node, const char *source_code) { + std::vector types; + uint32_t child_count = ts_node_child_count(node); + + for (uint32_t i = 0; i < child_count; i++) { + TSNode child = ts_node_child(node, i); + const char *type = ts_node_type(child); + + if (strcmp(type, "function_declarator") == 0) { + // Look for parameter_list within function_declarator + uint32_t func_child_count = ts_node_child_count(child); + for (uint32_t j = 0; j < func_child_count; j++) { + TSNode func_child = ts_node_child(child, j); + if (strcmp(ts_node_type(func_child), "parameter_list") == 0) { + // Extract each parameter_declaration + uint32_t param_list_count = ts_node_child_count(func_child); + for (uint32_t k = 0; k < param_list_count; k++) { + TSNode param_node = ts_node_child(func_child, k); + if (strcmp(ts_node_type(param_node), "parameter_declaration") == 0) { + // Extract type part by getting everything except the identifier + std::string full_param = extractNodeText(param_node, source_code); + std::string param_name = findIdentifierInDeclarator(param_node, source_code); + + // Remove the parameter name from the full parameter text to get just the type + std::string param_type = full_param; + if (!param_name.empty()) { + size_t name_pos = param_type.rfind(param_name); + if (name_pos != std::string::npos) { + param_type = param_type.substr(0, name_pos); + // Clean up trailing whitespace and pointer/reference symbols after the type + param_type.erase(param_type.find_last_not_of(" \t\n\r*&") + 1); + + // Add back any pointer/reference symbols that were part of the type + size_t type_end = name_pos; + while (type_end < full_param.length() && + (full_param[type_end] == ' ' || full_param[type_end] == '\t')) { + type_end++; + } + if (type_end > name_pos + param_name.length()) { + param_type += full_param.substr(name_pos + param_name.length(), + type_end - (name_pos + param_name.length())); + } + } + } + + // Clean up any extra whitespace + param_type.erase(0, param_type.find_first_not_of(" \t\n\r")); + param_type.erase(param_type.find_last_not_of(" \t\n\r") + 1); + + if (!param_type.empty()) { + types.push_back(param_type); + } + } + } + break; + } + } + break; + } + } + + // Join types with semicolon + std::string result; + for (size_t i = 0; i < types.size(); i++) { + if (i > 0) result += ";"; + result += types[i]; + } + return result; +} + +std::string getReturnType(TSNode node, const char *source_code) { + uint32_t child_count = ts_node_child_count(node); + + for (uint32_t i = 0; i < child_count; i++) { + TSNode child = ts_node_child(node, i); + const char *type = ts_node_type(child); + + // Look for type specifiers (return type comes before function declarator) + if (strcmp(type, "type_identifier") == 0 || + strcmp(type, "primitive_type") == 0 || + strcmp(type, "sized_type_specifier") == 0) { + return extractNodeText(child, source_code); + } + // Handle more complex types + else if (strcmp(type, "qualified_identifier") == 0 || + strcmp(type, "template_type") == 0) { + return extractNodeText(child, source_code); + } + } + + return "void"; // Default return type +} diff --git a/tooling/tool.hpp b/tooling/tool.hpp index d0552a90..2b705258 100644 --- a/tooling/tool.hpp +++ b/tooling/tool.hpp @@ -22,6 +22,9 @@ struct FunctionInfo { std::string name; std::string address; std::string filepath; + std::string parameter_names; // Semicolon-separated parameter names + std::string parameter_types; // Semicolon-separated parameter types + std::string return_type; // Function return type bool is_import; FileType type; CallingConvention calling_convention = CallingConvention::Cdecl; // Default to cdecl @@ -50,6 +53,9 @@ std::string findIdentifierInDeclarator(TSNode node, const char *source_code); std::string getFunctionName(TSNode node, const char *source_code); std::string getGlobalName(TSNode node, const char *source_code); std::string getComment(TSNode node, const char *source_code, uint32_t source_length, bool search_before); +std::string getParameterNames(TSNode node, const char *source_code); +std::string getParameterTypes(TSNode node, const char *source_code); +std::string getReturnType(TSNode node, const char *source_code); bool hasFunctionBody(TSNode node); void findFunctions(TSNode node, const char *source_code, uint32_t source_length, std::vector &functions, FileType file_type);