Fix spacing

This commit is contained in:
Guus Waals 2025-05-27 00:19:30 +08:00
parent dfe339bcd3
commit e4325de092
1 changed files with 348 additions and 291 deletions

View File

@ -10,6 +10,7 @@
#include <memory> #include <memory>
#include <tree_sitter/api.h> #include <tree_sitter/api.h>
#include <CLI11.hpp> #include <CLI11.hpp>
extern "C" TSLanguage *tree_sitter_cpp(); extern "C" TSLanguage *tree_sitter_cpp();
struct FunctionInfo { struct FunctionInfo {
@ -21,24 +22,35 @@ struct FunctionInfo {
class PreparedStatements { class PreparedStatements {
private: private:
sqlite3* db; sqlite3 *db;
sqlite3_stmt* delete_functions_stmt; sqlite3_stmt *delete_functions_stmt;
sqlite3_stmt* delete_imports_stmt; sqlite3_stmt *delete_imports_stmt;
sqlite3_stmt* insert_functions_stmt; sqlite3_stmt *insert_functions_stmt;
sqlite3_stmt* insert_imports_stmt; sqlite3_stmt *insert_imports_stmt;
void prepareStatement(const char* sql, sqlite3_stmt** stmt, const std::string& error_msg) { void prepareStatement(const char *sql, sqlite3_stmt **stmt,
const std::string &error_msg) {
if (sqlite3_prepare_v2(db, sql, -1, stmt, nullptr) != SQLITE_OK) { if (sqlite3_prepare_v2(db, sql, -1, stmt, nullptr) != SQLITE_OK) {
throw std::runtime_error(error_msg + ": " + sqlite3_errmsg(db)); throw std::runtime_error(error_msg + ": " + sqlite3_errmsg(db));
} }
} }
public: public:
// Prepares the four statements used to delete and insert rows in the
// Functions and Imports tables on `database`.
//
// database: open SQLite connection; it is stored, not owned.
//
// Throws std::runtime_error if any statement fails to compile. Because the
// destructor does not run for an object whose constructor throws, statements
// that were already prepared are finalized here before the exception is
// re-thrown.
PreparedStatements(sqlite3 *database)
    : db(database), delete_functions_stmt(nullptr),
      delete_imports_stmt(nullptr), insert_functions_stmt(nullptr),
      insert_imports_stmt(nullptr) {
  try {
    prepareStatement("DELETE FROM Functions WHERE filepath = ?",
                     &delete_functions_stmt,
                     "Failed to prepare delete functions statement");
    prepareStatement("DELETE FROM Imports WHERE filepath = ?",
                     &delete_imports_stmt,
                     "Failed to prepare delete imports statement");
    prepareStatement("INSERT OR REPLACE INTO Functions (filepath, name, "
                     "address) VALUES (?, ?, ?)",
                     &insert_functions_stmt,
                     "Failed to prepare insert functions statement");
    prepareStatement("INSERT OR REPLACE INTO Imports (filepath, name, address) "
                     "VALUES (?, ?, ?)",
                     &insert_imports_stmt,
                     "Failed to prepare insert imports statement");
  } catch (...) {
    // Handles that were never prepared are still nullptr, and
    // sqlite3_finalize(nullptr) is a harmless no-op.
    sqlite3_finalize(delete_functions_stmt);
    sqlite3_finalize(delete_imports_stmt);
    sqlite3_finalize(insert_functions_stmt);
    sqlite3_finalize(insert_imports_stmt);
    throw;
  }
}
~PreparedStatements() { ~PreparedStatements() {
@ -48,7 +60,7 @@ public:
sqlite3_finalize(insert_imports_stmt); sqlite3_finalize(insert_imports_stmt);
} }
void clearEntriesForFile(const std::string& filepath) { void clearEntriesForFile(const std::string &filepath) {
for (auto stmt : {delete_functions_stmt, delete_imports_stmt}) { for (auto stmt : {delete_functions_stmt, delete_imports_stmt}) {
sqlite3_reset(stmt); sqlite3_reset(stmt);
sqlite3_bind_text(stmt, 1, filepath.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 1, filepath.c_str(), -1, SQLITE_STATIC);
@ -56,8 +68,9 @@ public:
} }
} }
void insertFunction(const FunctionInfo& func) { void insertFunction(const FunctionInfo &func) {
sqlite3_stmt* stmt = func.is_import ? insert_imports_stmt : insert_functions_stmt; sqlite3_stmt *stmt =
func.is_import ? insert_imports_stmt : insert_functions_stmt;
sqlite3_reset(stmt); sqlite3_reset(stmt);
sqlite3_bind_text(stmt, 1, func.filepath.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 1, func.filepath.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 2, func.name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 2, func.name.c_str(), -1, SQLITE_STATIC);
@ -68,18 +81,18 @@ public:
class DatabaseManager { class DatabaseManager {
private: private:
sqlite3* db; sqlite3 *db;
std::unique_ptr<PreparedStatements> prepared_stmts; std::unique_ptr<PreparedStatements> prepared_stmts;
public: public:
DatabaseManager(const std::string& db_path) : db(nullptr) { DatabaseManager(const std::string &db_path) : db(nullptr) {
if (sqlite3_open(db_path.c_str(), &db) != SQLITE_OK) { if (sqlite3_open(db_path.c_str(), &db) != SQLITE_OK) {
std::cerr << "Can't open database: " << sqlite3_errmsg(db) << std::endl; std::cerr << "Can't open database: " << sqlite3_errmsg(db) << std::endl;
sqlite3_close(db); sqlite3_close(db);
throw std::runtime_error("Failed to open database"); throw std::runtime_error("Failed to open database");
} }
const char* create_tables = R"( const char *create_tables = R"(
CREATE TABLE IF NOT EXISTS Functions (filepath TEXT, name TEXT, address TEXT, PRIMARY KEY (name, filepath)); CREATE TABLE IF NOT EXISTS Functions (filepath TEXT, name TEXT, address TEXT, PRIMARY KEY (name, filepath));
CREATE TABLE IF NOT EXISTS Imports (filepath TEXT, name TEXT, address TEXT, PRIMARY KEY (name, filepath)); CREATE TABLE IF NOT EXISTS Imports (filepath TEXT, name TEXT, address TEXT, PRIMARY KEY (name, filepath));
)"; )";
@ -89,28 +102,39 @@ public:
} }
// Releases the prepared statements, then closes the database connection.
//
// Members are destroyed only after the destructor body has run, so without
// the explicit reset() the statements would still be unfinalized when
// sqlite3_close() is called. In that state sqlite3_close() returns
// SQLITE_BUSY and leaves the connection open, leaking it.
~DatabaseManager() {
  prepared_stmts.reset();
  if (db)
    sqlite3_close(db);
}
void clearEntriesForFile(const std::string& filepath) { prepared_stmts->clearEntriesForFile(filepath); } void clearEntriesForFile(const std::string &filepath) {
void insertFunction(const FunctionInfo& func) { prepared_stmts->insertFunction(func); } prepared_stmts->clearEntriesForFile(filepath);
void beginTransaction() { sqlite3_exec(db, "BEGIN TRANSACTION", nullptr, nullptr, nullptr); } }
void commitTransaction() { sqlite3_exec(db, "COMMIT", nullptr, nullptr, nullptr); } void insertFunction(const FunctionInfo &func) {
void rollbackTransaction() { sqlite3_exec(db, "ROLLBACK", nullptr, nullptr, nullptr); } prepared_stmts->insertFunction(func);
}
void beginTransaction() {
sqlite3_exec(db, "BEGIN TRANSACTION", nullptr, nullptr, nullptr);
}
void commitTransaction() {
sqlite3_exec(db, "COMMIT", nullptr, nullptr, nullptr);
}
void rollbackTransaction() {
sqlite3_exec(db, "ROLLBACK", nullptr, nullptr, nullptr);
}
}; };
// Extracts an address from a `//` comment.
//
// comment: text to search; it does not have to start with the comment marker.
//
// Returns the first run of exactly eight hexadecimal digits that follows
// `//` and optional whitespace, or an empty string when there is no match.
std::string extractAddress(const std::string &comment) {
  // Compiled once and reused: constructing a std::regex is expensive, and the
  // pattern never changes between calls.
  static const std::regex addr_regex(R"(//\s*([0-9a-fA-F]{8}))");

  std::smatch match;
  if (!std::regex_search(comment, match, addr_regex)) {
    return "";
  }
  return match[1].str();
}
std::string getFunctionName(TSNode node, const char* source_code) { std::string getFunctionName(TSNode node, const char *source_code) {
uint32_t child_count = ts_node_child_count(node); uint32_t child_count = ts_node_child_count(node);
for (uint32_t i = 0; i < child_count; i++) { for (uint32_t i = 0; i < child_count; i++) {
TSNode child = ts_node_child(node, i); TSNode child = ts_node_child(node, i);
const char* type = ts_node_type(child); const char *type = ts_node_type(child);
if (strcmp(type, "function_declarator") == 0) { if (strcmp(type, "function_declarator") == 0) {
uint32_t declarator_children = ts_node_child_count(child); uint32_t declarator_children = ts_node_child_count(child);
@ -122,8 +146,7 @@ std::string getFunctionName(TSNode node, const char* source_code) {
return std::string(source_code + start, end - start); return std::string(source_code + start, end - start);
} }
} }
} } else if (strcmp(type, "identifier") == 0) {
else if (strcmp(type, "identifier") == 0) {
uint32_t start = ts_node_start_byte(child); uint32_t start = ts_node_start_byte(child);
uint32_t end = ts_node_end_byte(child); uint32_t end = ts_node_end_byte(child);
return std::string(source_code + start, end - start); return std::string(source_code + start, end - start);
@ -132,10 +155,12 @@ std::string getFunctionName(TSNode node, const char* source_code) {
return ""; return "";
} }
std::string getComment(TSNode node, const char* source_code, uint32_t source_length, bool search_before) { std::string getComment(TSNode node, const char *source_code,
uint32_t source_length, bool search_before) {
if (search_before) { if (search_before) {
uint32_t start_byte = ts_node_start_byte(node); uint32_t start_byte = ts_node_start_byte(node);
if (start_byte == 0) return ""; if (start_byte == 0)
return "";
std::string before_text(source_code, start_byte); std::string before_text(source_code, start_byte);
std::regex addr_regex(R"(//\s*([0-9a-fA-F]{8}))"); std::regex addr_regex(R"(//\s*([0-9a-fA-F]{8}))");
@ -144,16 +169,22 @@ std::string getComment(TSNode node, const char* source_code, uint32_t source_len
size_t search_pos = before_text.length(); size_t search_pos = before_text.length();
while (search_pos > 0) { while (search_pos > 0) {
size_t comment_pos = before_text.rfind("//", search_pos - 1); size_t comment_pos = before_text.rfind("//", search_pos - 1);
if (comment_pos == std::string::npos) break; if (comment_pos == std::string::npos)
break;
size_t line_end = before_text.find('\n', comment_pos); size_t line_end = before_text.find('\n', comment_pos);
if (line_end == std::string::npos) line_end = before_text.length(); if (line_end == std::string::npos)
line_end = before_text.length();
std::string comment_line = before_text.substr(comment_pos, line_end - comment_pos); std::string comment_line =
before_text.substr(comment_pos, line_end - comment_pos);
if (std::regex_search(comment_line, match, addr_regex)) { if (std::regex_search(comment_line, match, addr_regex)) {
size_t newlines_between = std::count(before_text.begin() + comment_pos, before_text.begin() + start_byte, '\n'); size_t newlines_between =
if (newlines_between <= 20) return comment_line; std::count(before_text.begin() + comment_pos,
before_text.begin() + start_byte, '\n');
if (newlines_between <= 20)
return comment_line;
} }
search_pos = comment_pos; search_pos = comment_pos;
} }
@ -164,7 +195,8 @@ std::string getComment(TSNode node, const char* source_code, uint32_t source_len
size_t comment_pos = remaining.find("//"); size_t comment_pos = remaining.find("//");
if (comment_pos != std::string::npos) { if (comment_pos != std::string::npos) {
size_t line_end = remaining.find('\n', comment_pos); size_t line_end = remaining.find('\n', comment_pos);
if (line_end == std::string::npos) line_end = remaining.length(); if (line_end == std::string::npos)
line_end = remaining.length();
return remaining.substr(comment_pos, line_end - comment_pos); return remaining.substr(comment_pos, line_end - comment_pos);
} }
} }
@ -172,32 +204,40 @@ std::string getComment(TSNode node, const char* source_code, uint32_t source_len
} }
// Reports whether `node` is a function definition that carries a body.
//
// Returns true only when the node's type is "function_definition" and one of
// its direct children has type "compound_statement"; false otherwise.
bool hasFunctionBody(TSNode node) {
  const bool is_definition =
      strcmp(ts_node_type(node), "function_definition") == 0;
  if (!is_definition) {
    return false;
  }

  const uint32_t total = ts_node_child_count(node);
  uint32_t idx = 0;
  while (idx < total) {
    const char *kind = ts_node_type(ts_node_child(node, idx));
    if (strcmp(kind, "compound_statement") == 0) {
      return true;
    }
    ++idx;
  }
  return false;
}
void findFunctions(TSNode node, const char* source_code, uint32_t source_length, std::vector<FunctionInfo>& functions) { void findFunctions(TSNode node, const char *source_code, uint32_t source_length,
const char* type = ts_node_type(node); std::vector<FunctionInfo> &functions) {
const char *type = ts_node_type(node);
if (strcmp(type, "function_definition") == 0 || strcmp(type, "declaration") == 0) { if (strcmp(type, "function_definition") == 0 ||
strcmp(type, "declaration") == 0) {
std::string func_name = getFunctionName(node, source_code); std::string func_name = getFunctionName(node, source_code);
if (!func_name.empty()) { if (!func_name.empty()) {
std::string address = extractAddress(getComment(node, source_code, source_length, false)); std::string address =
extractAddress(getComment(node, source_code, source_length, false));
if (address.empty() && strcmp(type, "function_definition") == 0) { if (address.empty() && strcmp(type, "function_definition") == 0) {
address = extractAddress(getComment(node, source_code, source_length, true)); address =
extractAddress(getComment(node, source_code, source_length, true));
} }
if (!address.empty()) { if (!address.empty()) {
FunctionInfo func{func_name, address, "", FunctionInfo func{func_name, address, "",
strcmp(type, "function_definition") == 0 ? !hasFunctionBody(node) : true}; strcmp(type, "function_definition") == 0
? !hasFunctionBody(node)
: true};
functions.push_back(func); functions.push_back(func);
} }
} }
@ -205,11 +245,12 @@ void findFunctions(TSNode node, const char* source_code, uint32_t source_length,
uint32_t child_count = ts_node_child_count(node); uint32_t child_count = ts_node_child_count(node);
for (uint32_t i = 0; i < child_count; i++) { for (uint32_t i = 0; i < child_count; i++) {
findFunctions(ts_node_child(node, i), source_code, source_length, functions); findFunctions(ts_node_child(node, i), source_code, source_length,
functions);
} }
} }
std::vector<std::string> readFileList(const std::string& list_file) { std::vector<std::string> readFileList(const std::string &list_file) {
std::vector<std::string> files; std::vector<std::string> files;
std::ifstream file(list_file); std::ifstream file(list_file);
if (!file.is_open()) { if (!file.is_open()) {
@ -219,7 +260,8 @@ std::vector<std::string> readFileList(const std::string& list_file) {
std::string line; std::string line;
while (std::getline(file, line)) { while (std::getline(file, line)) {
if (line.empty() || line[0] == '#') continue; if (line.empty() || line[0] == '#')
continue;
if (line.find('*') != std::string::npos) { if (line.find('*') != std::string::npos) {
std::cout << "Skipping wildcard pattern: " << line << std::endl; std::cout << "Skipping wildcard pattern: " << line << std::endl;
@ -235,19 +277,21 @@ std::vector<std::string> readFileList(const std::string& list_file) {
return files; return files;
} }
bool processFile(const std::string& filepath, DatabaseManager& db) { bool processFile(const std::string &filepath, DatabaseManager &db) {
std::ifstream file(filepath); std::ifstream file(filepath);
if (!file.is_open()) { if (!file.is_open()) {
std::cerr << "Error: Could not open file " << filepath << std::endl; std::cerr << "Error: Could not open file " << filepath << std::endl;
return false; return false;
} }
std::string file_content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>()); std::string file_content((std::istreambuf_iterator<char>(file)),
std::istreambuf_iterator<char>());
TSParser *parser = ts_parser_new(); TSParser *parser = ts_parser_new();
ts_parser_set_language(parser, tree_sitter_cpp()); ts_parser_set_language(parser, tree_sitter_cpp());
TSTree *tree = ts_parser_parse_string(parser, nullptr, file_content.c_str(), file_content.length()); TSTree *tree = ts_parser_parse_string(parser, nullptr, file_content.c_str(),
file_content.length());
TSNode root_node = ts_tree_root_node(tree); TSNode root_node = ts_tree_root_node(tree);
if (ts_node_is_null(root_node)) { if (ts_node_is_null(root_node)) {
@ -260,32 +304,38 @@ bool processFile(const std::string& filepath, DatabaseManager& db) {
db.clearEntriesForFile(filepath); db.clearEntriesForFile(filepath);
std::vector<FunctionInfo> functions; std::vector<FunctionInfo> functions;
findFunctions(root_node, file_content.c_str(), file_content.length(), functions); findFunctions(root_node, file_content.c_str(), file_content.length(),
functions);
for (auto& func : functions) { for (auto &func : functions) {
func.filepath = filepath; func.filepath = filepath;
db.insertFunction(func); db.insertFunction(func);
std::cout << (func.is_import ? "Import: " : "Function: ") std::cout << (func.is_import ? "Import: " : "Function: ") << func.name
<< func.name << " @ " << func.address << " in " << filepath << std::endl; << " @ " << func.address << " in " << filepath << std::endl;
} }
std::cout << "Processed " << functions.size() << " functions/imports from " << filepath << std::endl; std::cout << "Processed " << functions.size() << " functions/imports from "
<< filepath << std::endl;
ts_tree_delete(tree); ts_tree_delete(tree);
ts_parser_delete(parser); ts_parser_delete(parser);
return true; return true;
} }
int main(int argc, char* argv[]) { int main(int argc, char *argv[]) {
CLI::App app{"C++ Function Parser - Extracts function addresses from C++ files"}; CLI::App app{
"C++ Function Parser - Extracts function addresses from C++ files"};
std::vector<std::string> input_files; std::vector<std::string> input_files;
std::string list_file; std::string list_file;
std::string db_path = "functions.db"; std::string db_path = "functions.db";
app.add_option("files", input_files, "Input C++ files to parse (supports @listfile.txt syntax)"); app.add_option("files", input_files,
app.add_option("-l,--list", list_file, "File containing list of files to process"); "Input C++ files to parse (supports @listfile.txt syntax)");
app.add_option("-d,--database", db_path, "SQLite database path")->default_val("functions.db"); app.add_option("-l,--list", list_file,
"File containing list of files to process");
app.add_option("-d,--database", db_path, "SQLite database path")
->default_val("functions.db");
CLI11_PARSE(app, argc, argv); CLI11_PARSE(app, argc, argv);
@ -293,13 +343,15 @@ int main(int argc, char* argv[]) {
if (!list_file.empty()) { if (!list_file.empty()) {
auto list_files = readFileList(list_file); auto list_files = readFileList(list_file);
files_to_process.insert(files_to_process.end(), list_files.begin(), list_files.end()); files_to_process.insert(files_to_process.end(), list_files.begin(),
list_files.end());
} }
for (const auto& input : input_files) { for (const auto &input : input_files) {
if (input.starts_with("@")) { if (input.starts_with("@")) {
auto list_files = readFileList(input.substr(1)); auto list_files = readFileList(input.substr(1));
files_to_process.insert(files_to_process.end(), list_files.begin(), list_files.end()); files_to_process.insert(files_to_process.end(), list_files.begin(),
list_files.end());
} else if (std::filesystem::exists(input)) { } else if (std::filesystem::exists(input)) {
files_to_process.push_back(input); files_to_process.push_back(input);
} else { } else {
@ -308,7 +360,8 @@ int main(int argc, char* argv[]) {
} }
if (files_to_process.empty()) { if (files_to_process.empty()) {
std::cerr << "No files to process. Use --help for usage information." << std::endl; std::cerr << "No files to process. Use --help for usage information."
<< std::endl;
return 1; return 1;
} }
@ -320,13 +373,15 @@ int main(int argc, char* argv[]) {
db.beginTransaction(); db.beginTransaction();
for (const auto& filepath : files_to_process) { for (const auto &filepath : files_to_process) {
std::cout << "\n=== Processing: " << filepath << " ===" << std::endl; std::cout << "\n=== Processing: " << filepath << " ===" << std::endl;
if (processFile(filepath, db)) processed_count++; if (processFile(filepath, db))
processed_count++;
if (++current_batch >= batch_size) { if (++current_batch >= batch_size) {
db.commitTransaction(); db.commitTransaction();
std::cout << "Committed batch of " << current_batch << " files to database" << std::endl; std::cout << "Committed batch of " << current_batch
<< " files to database" << std::endl;
db.beginTransaction(); db.beginTransaction();
current_batch = 0; current_batch = 0;
} }
@ -334,14 +389,16 @@ int main(int argc, char* argv[]) {
if (current_batch > 0) { if (current_batch > 0) {
db.commitTransaction(); db.commitTransaction();
std::cout << "Committed final batch of " << current_batch << " files to database" << std::endl; std::cout << "Committed final batch of " << current_batch
<< " files to database" << std::endl;
} }
std::cout << "\n=== Summary ===" << std::endl; std::cout << "\n=== Summary ===" << std::endl;
std::cout << "Processed " << processed_count << " files successfully" << std::endl; std::cout << "Processed " << processed_count << " files successfully"
<< std::endl;
std::cout << "Database saved to: " << db_path << std::endl; std::cout << "Database saved to: " << db_path << std::endl;
} catch (const std::exception& e) { } catch (const std::exception &e) {
std::cerr << "Database error: " << e.what() << std::endl; std::cerr << "Database error: " << e.what() << std::endl;
return 1; return 1;
} }