From 3c4f60bf8f4f82a355e18f068a7cf0810c3a344e Mon Sep 17 00:00:00 2001 From: chn Date: Thu, 12 Oct 2023 17:58:13 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=9F=E4=B8=80=E5=A4=96=E9=83=A8=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=9A=84=E8=AF=BB=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/ufo/fold.hpp | 2 +- include/ufo/plot.hpp | 8 ++--- include/ufo/solver.hpp | 14 ++++++++ include/ufo/unfold.hpp | 13 ++------ src/fold.cpp | 4 +-- src/plot.cpp | 21 ++++-------- src/solver.cpp | 30 ++++++++++++++++++ src/unfold.cpp | 72 +++++++++++------------------------------- 8 files changed, 78 insertions(+), 86 deletions(-) diff --git a/include/ufo/fold.hpp b/include/ufo/fold.hpp index 7da8e38..6ce66dc 100644 --- a/include/ufo/fold.hpp +++ b/include/ufo/fold.hpp @@ -11,7 +11,7 @@ namespace ufo Eigen::Vector SuperCellMultiplier; std::optional> SuperCellDeformation; std::vector Qpoints; - std::string OutputFilename; + DataFile OutputFile; InputType(std::string config_file); }; diff --git a/include/ufo/plot.hpp b/include/ufo/plot.hpp index bb0fb85..8c203ee 100644 --- a/include/ufo/plot.hpp +++ b/include/ufo/plot.hpp @@ -16,10 +16,8 @@ namespace ufo std::pair Resolution; std::pair Range; std::optional> YTicks; - std::string PictureFilename; - - struct DataFileType { std::string Filename, Format; }; - std::optional> DataFiles; + DataFile PictureFile; + std::optional> DataFiles; }; std::vector Figures; @@ -28,7 +26,7 @@ namespace ufo UnfoldedDataType(std::string filename); UnfoldedDataType() = default; }; - std::string UnfoldedDataFilename; + DataFile UnfoldedDataFile; UnfoldedDataType UnfoldedData; InputType(std::string config_file); diff --git a/include/ufo/solver.hpp b/include/ufo/solver.hpp index 391a012..e1e2c36 100644 --- a/include/ufo/solver.hpp +++ b/include/ufo/solver.hpp @@ -121,6 +121,20 @@ namespace ufo protected: std::optional File_; }; + + struct DataFile + { + std::string Filename; + std::string Format; + std::map ExtraParameters; + inline DataFile() = default; + DataFile + ( + YAML::Node node, std::set supported_format, + std::string config_file, bool allow_same_as_config_file = false + ); + }; + }; } diff --git a/include/ufo/unfold.hpp b/include/ufo/unfold.hpp index a652e29..c53dc00 100644 --- a/include/ufo/unfold.hpp +++ b/include/ufo/unfold.hpp @@ -32,17 +32,10 @@ namespace ufo // 在单胞内取几个平面波的基矢 Eigen::Vector PrimativeCellBasisNumber; - struct InputOutputFile - { - std::string FileName; - std::string Format; - std::map ExtraParameters; - }; - // 从哪个文件读入 AtomPosition, 以及这个文件的格式, 格式可选值包括 "yaml" - InputOutputFile AtomPositionInputFile; + DataFile AtomPositionInputFile; // 从哪个文件读入 QpointData, 以及这个文件的格式, 格式可选值包括 "yaml" 和 "hdf5" - InputOutputFile QpointDataInputFile; + DataFile QpointDataInputFile; // 超胞中原子的坐标,每行表示一个原子的坐标,单位为埃 Eigen::MatrixX3d AtomPosition; @@ -72,7 +65,7 @@ namespace ufo // yaml-human-readable: 使用 yaml 格式输出, 但是输出的结果更适合人类阅读, // 包括合并相近的模式, 去除权重过小的模式, 限制输出的小数位数. // zpp: 使用 zpp-bits 序列化, 可以直接被 plot.cpp 读取 - std::vector QpointDataOutputFile; + std::vector QpointDataOutputFile; // 从文件中读取输入 (包括一个较小的配置文件, 和一个 hdf5 或者一个 yaml 文件), 文件中应当包含: // 单胞的格矢: PrimativeCell 单位为埃 直接从 phonopy 的输出中复制 diff --git a/src/fold.cpp b/src/fold.cpp index bb6958a..9732be9 100644 --- a/src/fold.cpp +++ b/src/fold.cpp @@ -17,7 +17,7 @@ namespace ufo for (auto& qpoint : input["Qpoints"].as>>()) Qpoints.push_back(Eigen::Vector3d {{qpoint.at(0)}, {qpoint.at(1)}, {qpoint.at(2)}}); - OutputFilename = input["OutputFilename"].as(); + OutputFile = DataFile(input["OutputFile"], {"yaml"}, config_file); } void FoldSolver::OutputType::write(std::string filename) const { @@ -44,7 +44,7 @@ namespace ufo Input_.SuperCellDeformation )); } - Output_->write(Input_.OutputFilename); + Output_->write(Input_.OutputFile.Filename); return *this; } diff --git a/src/plot.cpp b/src/plot.cpp index b9fbb70..f607807 100644 --- a/src/plot.cpp +++ b/src/plot.cpp @@ -30,27 +30,20 @@ namespace ufo throw std::runtime_error("Not enough lines in a figure"); Figures.back().Resolution = figure["Resolution"].as>(); Figures.back().Range = figure["Range"].as>(); - Figures.back().PictureFilename = figure["PictureFilename"].as(); + Figures.back().PictureFile + = DataFile(figure["PictureFile"], {"png"}, config_file); if (figure["YTicks"]) Figures.back().YTicks = figure["YTicks"].as>(); if (figure["DataFiles"]) { Figures.back().DataFiles.emplace(); for (auto& data_file : figure["DataFiles"].as>()) - { - Figures.back().DataFiles->emplace_back - ( - data_file["Filename"].as(), - data_file["Format"].as() - ); - if (!std::set{ "hdf5"s, "zpp"s }.contains(Figures.back().DataFiles->back().Format)) - throw std::runtime_error(fmt::format("Unknown data file format: {}", - Figures.back().DataFiles->back().Format)); - } + Figures.back().DataFiles->emplace_back() + = DataFile(data_file, {"hdf5", "zpp"}, config_file); } } - UnfoldedDataFilename = input["UnfoldedDataFilename"].as(); - UnfoldedData = UnfoldedDataType(UnfoldedDataFilename); + UnfoldedDataFile = DataFile(input["UnfoldedDataFile"], {"zpp"}, config_file); + UnfoldedData = UnfoldedDataType(UnfoldedDataFile.Filename); } const PlotSolver::OutputType& PlotSolver::OutputType::write(std::string filename, std::string format) const { @@ -96,7 +89,7 @@ namespace ufo auto y_ticks = figure.YTicks.value_or(std::vector{}); for (auto& _ : y_ticks) _ = (_ - figure.Range.first) / (figure.Range.second - figure.Range.first) * figure.Resolution.second; - plot(values, figure.PictureFilename, x_ticks, y_ticks); + plot(values, figure.PictureFile.Filename, x_ticks, y_ticks); Output_->emplace_back(); Output_->back().Values = std::move(values); Output_->back().XTicks = std::move(x_ticks); diff --git a/src/solver.cpp b/src/solver.cpp index a5e640d..4d6ce5e 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -14,4 +14,34 @@ namespace ufo x * range[1] * range[2] + y * range[2] + z }; } + + Solver::DataFile::DataFile + (YAML::Node node, std::set supported_format, std::string config_file, bool allow_same_as_config_file) + { + if (auto _ = node["SameAsConfigFile"]) + { + auto __ = _.as(); + if (__ && !allow_same_as_config_file) + throw std::runtime_error("\"SameAsConfigFile: true\" is not allowed here."); + ExtraParameters["SameAsConfigFile"] = __; + if (__) + { + Filename = config_file; + Format = "yaml"; + return; + } + } + Filename = node["Filename"].as(); + Format = node["Format"].as(); + if (!supported_format.contains(Format)) + throw std::runtime_error(fmt::format("Unsupported format: \"{}\"", Format)); + if (auto _ = node["RelativeToConfigFile"]) + { + auto __ = _.as(); + ExtraParameters["RelativeToConfigFile"] = __; + if (__) + Filename = std::filesystem::path(config_file).parent_path() / Filename; + } + }; + } diff --git a/src/unfold.cpp b/src/unfold.cpp index 84eb6a1..e3eea4d 100644 --- a/src/unfold.cpp +++ b/src/unfold.cpp @@ -25,67 +25,31 @@ namespace ufo for (unsigned i = 0; i < 3; i++) PrimativeCellBasisNumber(i) = node["PrimativeCellBasisNumber"][i].as(); - auto read_file_config = [filename](YAML::Node source, InputOutputFile& config) - { - if (auto _ = source["SameAsConfigFile"]) - { - auto __ = _.as(); - config.ExtraParameters["SameAsConfigFile"] = __; - if (__) - { - config.FileName = filename; - config.Format = "yaml"; - return; - } - } - config.FileName = source["FileName"].as(); - config.Format = source["Format"].as(); - if (auto _ = source["RelativeToConfigFile"]) - { - auto __ = _.as(); - config.ExtraParameters["RelativeToConfigFile"] = __; - if (__) - config.FileName = std::filesystem::path(filename).parent_path() / config.FileName; - } - }; - read_file_config(node["AtomPositionInputFile"], AtomPositionInputFile); - if (!std::set{"yaml"}.contains(AtomPositionInputFile.Format)) - throw std::runtime_error(fmt::format - ("Unknown AtomPositionInputFile.Format: {}, should be \"yaml\".", AtomPositionInputFile.Format)); - read_file_config(node["QpointDataInputFile"], QpointDataInputFile); - if (!std::set{"yaml", "hdf5"}.contains(QpointDataInputFile.Format)) - throw std::runtime_error(fmt::format - ("Unknown QpointDataInputFile.Format: {}, should be \"yaml\" or \"hdf5\".", QpointDataInputFile.Format)); + AtomPositionInputFile = DataFile + ( + node["AtomPositionInputFile"], {"yaml"}, + filename, true + ); + QpointDataInputFile = DataFile + ( + node["QpointDataInputFile"], {"yaml", "hdf5"}, + filename, true + ); if (auto value = node["QpointDataOutputFile"]) { QpointDataOutputFile.resize(value.size()); for (unsigned i = 0; i < value.size(); i++) - { - read_file_config(value[i], QpointDataOutputFile[i]); - if + QpointDataOutputFile[i] = DataFile ( - QpointDataOutputFile[i].ExtraParameters.contains("SameAsConfigFile") - && std::any_cast(QpointDataOutputFile[i].ExtraParameters["SameAsConfigFile"]) - ) - throw std::runtime_error("QpointDataOutputFile.SameAsConfigFile should not be set."); - if - ( - !std::set{"yaml", "yaml-human-readable", "zpp", "hdf5"} - .contains(QpointDataOutputFile[i].Format) - ) - throw std::runtime_error(fmt::format - ( - "Unknown QpointDataOutputFile[{}].Format: {}, should be " - "\"yaml\", \"yaml-human-readable\", \"zpp\" or \"hdf5\".", - i, QpointDataOutputFile[i].Format - )); - } + value[i], {"yaml", "yaml-human-readable", "zpp", "hdf5"}, + filename, false + ); } } if (AtomPositionInputFile.Format == "yaml") { - auto node = YAML::LoadFile(AtomPositionInputFile.FileName); + auto node = YAML::LoadFile(AtomPositionInputFile.Filename); std::vector points; if (auto _ = node["points"]) points = _.as>(); @@ -101,7 +65,7 @@ namespace ufo } if (QpointDataInputFile.Format == "yaml") { - auto node = YAML::LoadFile(QpointDataInputFile.FileName); + auto node = YAML::LoadFile(QpointDataInputFile.Filename); auto phonon = node["phonon"].as>(); QpointData.resize(phonon.size()); for (unsigned i = 0; i < phonon.size(); i++) @@ -135,7 +99,7 @@ namespace ufo { std::vector>> frequency, path; std::vector>>> eigenvector_vector; - Hdf5file{}.open_for_read(QpointDataInputFile.FileName).read(frequency, "/frequency") + Hdf5file{}.open_for_read(QpointDataInputFile.Filename).read(frequency, "/frequency") .read(eigenvector_vector, "/eigenvector") .read(path, "/path"); std::vector size = { frequency.size(), frequency[0].size(), frequency[0][0].size() }; @@ -163,7 +127,7 @@ namespace ufo (decltype(InputType::QpointDataOutputFile) output_files) const { for (auto& output_file : output_files) - write(output_file.FileName, output_file.Format); + write(output_file.Filename, output_file.Format); } void UnfoldSolver::OutputType::write(std::string filename, std::string format, unsigned percision) const {