From 86e85f87e3e35ff3b40f699f55b953c1589c7866 Mon Sep 17 00:00:00 2001 From: chn Date: Sat, 4 May 2024 11:57:13 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=A0=A1=E9=AA=8C=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CMakeLists.txt | 1 + flake.nix | 4 +-- include/hpcstat/sql.hpp | 20 +++++++++++--- src/lfs.cpp | 4 +-- src/main.cpp | 17 +++++++++--- src/sql.cpp | 59 +++++++++++++++++++++++++++++++++++++++-- src/ssh.cpp | 2 +- 7 files changed, 94 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 94b485c..af6c5da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ find_package(zxorm REQUIRED) find_package(nlohmann_json REQUIRED) find_path(ZPP_BITS_INCLUDE_DIR zpp_bits.h REQUIRED) find_package(range-v3 REQUIRED) +find_path(NAMEOF_INCLUDE_DIR nameof.hpp REQUIRED) add_executable(hpcstat src/main.cpp src/env.cpp src/keys.cpp src/ssh.cpp src/sql.cpp src/lfs.cpp src/common.cpp) target_compile_features(hpcstat PUBLIC cxx_std_23) diff --git a/flake.nix b/flake.nix index 8d15180..54951ad 100644 --- a/flake.nix +++ b/flake.nix @@ -23,7 +23,7 @@ name = "hpcstat"; src = ./.; buildInputs = with pkgs.pkgsStatic; - [ boost fmt localPackages.zxorm nlohmann_json localPackages.zpp-bits range-v3 ]; + [ boost fmt localPackages.zxorm nlohmann_json localPackages.zpp-bits range-v3 localPackages.nameof ]; nativeBuildInputs = with pkgs; [ cmake pkg-config ]; postInstall = "cp ${openssh}/bin/{ssh-add,ssh-keygen} $out/bin"; }; @@ -35,7 +35,7 @@ { nativeBuildInputs = with pkgs; [ pkg-config cmake clang-tools_18 ]; buildInputs = (with pkgs.pkgsStatic; - [ fmt boost localPackages.zxorm nlohmann_json localPackages.zpp-bits range-v3 ]); + [ fmt boost localPackages.zxorm nlohmann_json localPackages.zpp-bits range-v3 localPackages.nameof ]); # hardeningDisable = [ "all" ]; # NIX_DEBUG = "1"; CMAKE_EXPORT_COMPILE_COMMANDS = "1"; diff --git a/include/hpcstat/sql.hpp b/include/hpcstat/sql.hpp index a616b06..651d829 100644 --- a/include/hpcstat/sql.hpp +++ b/include/hpcstat/sql.hpp @@ -12,6 +12,7 @@ namespace hpcstat::sql std::optional Subaccount, Ip; bool Interactive; using serialize = zpp::bits::members<8>; + bool operator==(const LoginData& other) const = default; }; using LoginTable = zxorm::Table < @@ -25,7 +26,13 @@ namespace hpcstat::sql zxorm::Column<"ip", &LoginData::Ip>, zxorm::Column<"interactive", &LoginData::Interactive> >; - struct LogoutData { unsigned Id = 0; long Time; std::string SessionId; }; + struct LogoutData + { + unsigned Id = 0; + long Time; + std::string SessionId; + bool operator==(const LogoutData& other) const = default; + }; using LogoutTable = zxorm::Table < "logout", LogoutData, @@ -41,6 +48,7 @@ namespace hpcstat::sql std::string Key, SessionId, SubmitDir, JobCommand, Signature = ""; std::optional Subaccount, Ip; using serialize = zpp::bits::members<10>; + bool operator==(const SubmitJobData& other) const = default; }; using SubmitJobTable = zxorm::Table < @@ -61,9 +69,10 @@ namespace hpcstat::sql unsigned Id = 0; long Time; unsigned JobId; - std::string JobResult, SubmitTime, JobDetail, Signature = ""; + std::string JobResult, SubmitTime, JobDetail, Key, Signature = ""; double CpuTime; - using serialize = zpp::bits::members<8>; + using serialize = zpp::bits::members<9>; + bool operator==(const FinishJobData& other) const = default; }; using FinishJobTable = zxorm::Table < @@ -74,6 +83,7 @@ namespace hpcstat::sql zxorm::Column<"job_result", &FinishJobData::JobResult>, zxorm::Column<"submit_time", &FinishJobData::SubmitTime>, zxorm::Column<"job_detail", &FinishJobData::JobDetail>, + zxorm::Column<"key", &FinishJobData::Key>, zxorm::Column<"signature", &FinishJobData::Signature>, zxorm::Column<"cpu_time", &FinishJobData::CpuTime> >; @@ -85,4 +95,8 @@ namespace hpcstat::sql bool writedb(auto value); // 查询 bjobs -a 的结果中,有哪些是已经被写入到数据库中的(按照任务 id 和提交时间计算),返回未被写入的任务 id std::optional> finishjob_remove_existed(std::map jobid_submit_time); + // 检查数据库中已经有的数据是否被修改过,如果有修改过,返回 std::nullopt,否则返回新增的数据,用于校验签名 + // 三个字符串分别是序列化后的数据,签名,指纹 + std::optional>> + verify(std::string old_db, std::string new_db); } diff --git a/src/lfs.cpp b/src/lfs.cpp index 37f74b4..babce83 100644 --- a/src/lfs.cpp +++ b/src/lfs.cpp @@ -17,7 +17,7 @@ namespace hpcstat::lfs else { std::set valid_args = { "J", "q", "n", "R", "o" }; - for (auto it = args.begin(); it != args.end(); it++) + for (auto it = args.begin(); it != args.end(); ++it) { if (it->length() > 0 && (*it)[0] == '-') { @@ -29,7 +29,7 @@ namespace hpcstat::lfs "please submit issue on [github](https://github.com/CHN-beta/hpcstat) or contact chn@chn.moe.\n"; return std::nullopt; } - else if (it + 1 != args.end() && ((it + 1)->length() == 0 || (*(it + 1))[0] != '-')) it++; + else if (it + 1 != args.end() && ((it + 1)->length() == 0 || (*(it + 1))[0] != '-')) ++it; } else break; } diff --git a/src/main.cpp b/src/main.cpp index f41ba23..f13024d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,3 +1,4 @@ +# include # include # include # include @@ -10,6 +11,7 @@ int main(int argc, const char** argv) { using namespace hpcstat; + using namespace std::literals; std::vector args(argv, argv + argc); if (args.size() == 1) { std::cout << "Usage: hpcstat initdb|login|logout|submitjob|finishjob\n"; return 1; } @@ -20,6 +22,7 @@ int main(int argc, const char** argv) else if (args[1] == "login") { if (env::interactive()) std::cout << "Communicating with the agent..." << std::flush; + std::this_thread::sleep_for(1s); // might silly but it tells everyone that we are doing something if (auto fp = ssh::fingerprint(); !fp) return 1; else if (auto session = env::env("XDG_SESSION_ID", true); !session) return 1; @@ -34,7 +37,8 @@ int main(int argc, const char** argv) if (!signature) return 1; data.Signature = *signature; sql::writedb(data); - if (env::interactive()) std::cout << fmt::format("\33[2K\rLogged in as {}.\n", Keys[*fp].Username); + if (env::interactive()) + std::cout << fmt::format("\33[2K\rLogged in as {} (fingerprint: SHA256:{}).\n", Keys[*fp].Username, *fp); } } else if (args[1] == "logout") @@ -94,7 +98,7 @@ int main(int argc, const char** argv) sql::FinishJobData data { .Time = now(), .JobId = jobid, .JobResult = std::get<1>(all_jobs->at(jobid)), - .SubmitTime = std::get<0>(all_jobs->at(jobid)), .JobDetail = *detail, + .SubmitTime = std::get<0>(all_jobs->at(jobid)), .JobDetail = *detail, .Key = *fp, .CpuTime = std::get<2>(all_jobs->at(jobid)), }; if @@ -107,7 +111,14 @@ int main(int argc, const char** argv) } } } + else if (args[1] == "verify") + { + if (args.size() < 4) { std::cerr << "Usage: hpcstat verify \n"; return 1; } + if (auto db_verify_result = sql::verify(args[2], args[3]); !db_verify_result) return 1; + else for (auto& data : *db_verify_result) + if (!std::apply(ssh::verify, data)) + { std::cerr << fmt::format("Failed to verify data: {}\n", std::get<0>(data)); return 1; } + } else { std::cerr << "Unknown command.\n"; return 1; } - return 0; } diff --git a/src/sql.cpp b/src/sql.cpp index 956e5a0..148e099 100644 --- a/src/sql.cpp +++ b/src/sql.cpp @@ -4,6 +4,8 @@ # include # include # include +# include +# include namespace hpcstat::sql { @@ -17,9 +19,12 @@ namespace hpcstat::sql template std::string serialize(LoginData); template std::string serialize(SubmitJobData); template std::string serialize(FinishJobData); - std::optional> connect() + std::optional> connect + (std::optional dbfile = std::nullopt) { - if (auto datadir = env::env("HPCSTAT_DATADIR", true); !datadir) + if (dbfile) return std::make_optional> + (dbfile->c_str()); + else if (auto datadir = env::env("HPCSTAT_DATADIR", true); !datadir) return std::nullopt; else { @@ -52,4 +57,54 @@ namespace hpcstat::sql return not_logged_job; } } + std::optional>> + verify(std::string old_db, std::string new_db) + { + auto old_conn = connect(old_db), new_conn = connect(new_db); + if (!old_conn || !new_conn) { std::cerr << "Failed to connect to database.\n"; return std::nullopt; } + else + { + auto check_one = [&]() + -> std::optional>> + { + auto old_query = old_conn->select_query().many().exec(), + new_query = new_conn->select_query().many().exec(); + auto old_data_it = old_query.begin(), new_data_it = new_query.begin(); + for (; old_data_it != old_query.end() && new_data_it != new_query.end(); ++old_data_it, ++new_data_it) + if (*old_data_it != *new_data_it) + { + std::cerr << fmt::format + ("Data mismatch: {} {} != {}.\n", nameof::nameof_type(), (*old_data_it).Id, (*new_data_it).Id); + return std::nullopt; + } + if (old_data_it != old_query.end() && new_data_it == new_query.end()) + { + std::cerr << fmt::format("Data mismatch in {}.\n", nameof::nameof_type()); + return std::nullopt; + } + else if constexpr (requires(T data) { data.Signature; }) + { + std::vector> diff; + for (; old_data_it != old_query.end(); ++old_data_it) + { + auto data = *old_data_it; + data.Signature = ""; + data.Id = 0; + diff.push_back({ serialize(data), (*old_data_it).Signature, (*old_data_it).Key }); + } + return diff; + } + else return std::vector>{}; + }; + auto check_many = [&](auto&& self) + -> std::optional>> + { + if (auto diff = check_one.operator()(); !diff) return std::nullopt; + else if constexpr (sizeof...(Ts) == 0) return diff; + else if (auto diff2 = self.template operator()(self); !diff2) return std::nullopt; + else { diff->insert(diff->end(), diff2->begin(), diff2->end()); return diff; } + }; + return check_many.operator()(check_many); + } + } } diff --git a/src/ssh.cpp b/src/ssh.cpp index 49ea70d..46e0c14 100644 --- a/src/ssh.cpp +++ b/src/ssh.cpp @@ -30,7 +30,7 @@ namespace hpcstat::ssh for ( auto i = std::sregex_iterator(output->begin(), output->end(), pattern); - i != std::sregex_iterator(); i++ + i != std::sregex_iterator(); ++i ) if (Keys.contains(i->str(1))) return i->str(1); std::cerr << fmt::format("No valid fingerprint found in:\n{}\n", *output);