From 29cdc32ec0e19285d74240a804974de0b1f82686 Mon Sep 17 00:00:00 2001 From: chn Date: Thu, 5 Oct 2023 16:20:53 +0800 Subject: [PATCH] merge plot into main --- CMakeLists.txt | 7 +- include/ufo/plot.hpp | 60 +++++++ include/ufo/solver.hpp | 2 +- src/main.cpp | 12 +- src/plot.cpp | 344 +++++++++++++++++++++++------------------ src/unfold.cpp | 18 --- 6 files changed, 264 insertions(+), 179 deletions(-) create mode 100644 include/ufo/plot.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e026da1..f02465c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,16 +22,11 @@ find_package(TBB REQUIRED) find_package(Matplot++ REQUIRED) find_path(ZPP_BITS_INCLUDE_DIR zpp_bits.h REQUIRED) -add_executable(ufo src/unfold.cpp src/main.cpp) +add_executable(ufo src/unfold.cpp src/plot.cpp src/main.cpp) target_include_directories(ufo PRIVATE ${PROJECT_SOURCE_DIR}/include ${ZPP_BITS_INCLUDE_DIR}) target_link_libraries(ufo PRIVATE yaml-cpp Eigen3::Eigen fmt::fmt concurrencpp::concurrencpp HighFive_HighFive TBB::tbb Matplot++::matplot) -add_executable(plot src/plot.cpp) -target_include_directories(plot PRIVATE ${PROJECT_SOURCE_DIR}/include ${ZPP_BITS_INCLUDE_DIR}) -target_link_libraries(plot PRIVATE - yaml-cpp Eigen3::Eigen fmt::fmt concurrencpp::concurrencpp HighFive_HighFive TBB::tbb Matplot++::matplot) - get_property(ImportedTargets DIRECTORY "${CMAKE_SOURCE_DIR}" PROPERTY IMPORTED_TARGETS) message("Imported targets: ${ImportedTargets}") message("List of compile features: ${CMAKE_CXX_COMPILE_FEATURES}") diff --git a/include/ufo/plot.hpp b/include/ufo/plot.hpp new file mode 100644 index 0000000..33012fd --- /dev/null +++ b/include/ufo/plot.hpp @@ -0,0 +1,60 @@ +# pragma once +# include + +namespace ufo +{ + class PlotSolver : public Solver + { + public: + struct InputType + { + Eigen::Matrix3d PrimativeCell; + + struct FigureConfigType + { + std::vector> Qpoints; + std::pair Resolution; + std::pair Range; + std::string Filename; + }; + std::vector Figures; + + struct SourceType : public UnfoldSolver::OutputType + { + SourceType(std::string filename); + SourceType() = default; + }; + std::string SourceFilename; + SourceType Source; + + InputType(std::string config_file); + }; + protected: + InputType Input_; + public: + PlotSolver(std::string config_file); + PlotSolver& operator()() override; + + // 根据 q 点路径, 搜索要使用的 q 点 + static std::vector> search_qpoints + ( + const std::pair& path, + const decltype(InputType::SourceType::QPointData)& available_qpoints, + double threshold, bool exclude_endpoint = false + ); + // 根据搜索到的 q 点, 计算每个点的数值 + static std::vector> calculate_values + ( + const std::vector>& path, + const std::vector>>& qpoints, + const decltype(InputType::FigureConfigType::Resolution)& resolution, + const decltype(InputType::FigureConfigType::Range)& range + ); + // 根据数值, 画图 + static void plot + ( + const std::vector>& values, + const decltype(InputType::FigureConfigType::Filename)& filename + ); + }; +} diff --git a/include/ufo/solver.hpp b/include/ufo/solver.hpp index 2f72654..6275aa6 100644 --- a/include/ufo/solver.hpp +++ b/include/ufo/solver.hpp @@ -50,7 +50,7 @@ namespace ufo { public: virtual Solver& operator()() = 0; - ~Solver() = default; + virtual ~Solver() = default; inline static concurrencpp::generator, unsigned>> triplet_sequence(Eigen::Vector range) diff --git a/src/main.cpp b/src/main.cpp index 638944f..c53b33a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,9 +1,13 @@ -# include +# include int main(int argc, const char** argv) { if (argc != 2) - throw std::runtime_error(fmt::format("Usage: {} config.yaml", argv[0])); - - ufo::UnfoldSolver{argv[1]}(); + throw std::runtime_error(fmt::format("Usage: {} task config.yaml", argv[0])); + if (argv[1] == std::string("unfold")) + ufo::UnfoldSolver{argv[1]}(); + else if (argv[1] == std::string("plot")) + ufo::PlotSolver{argv[1]}(); + else + throw std::runtime_error(fmt::format("Unknown task: {}", argv[1])); } diff --git a/src/plot.cpp b/src/plot.cpp index 55e3c6c..c468a93 100644 --- a/src/plot.cpp +++ b/src/plot.cpp @@ -1,171 +1,215 @@ -# include +# include -// 要被用来画图的路径 -std::vector Qpoints = +namespace ufo { - { 0, 0, 0 }, - { 0.5, 0, 0 }, - { 1. / 3, 1. / 3, 0 }, - { 0, 0, 0 }, - { 0, 0, 0.5 }, - { 0.5, 0, 0.5 }, - { 1. / 3, 1. / 3, 0.5 }, - { 0, 0, 0.5 } -}; -double Threshold = 0.001; - -struct Point -{ - Eigen::Vector3d QPoint; - Eigen::VectorXd Frequency, Weight; - double Distance; -}; - -int main(int argc, char** argv) -{ - if (argc != 2) - throw std::runtime_error("Usage: plot output.zpp"); - - Output output(argv[1]); - - std::vector Points; - - double total_distance = 0; - // 对于每一条路径进行搜索 - for (unsigned i = 0; i < Qpoints.size() - 1; i++) + PlotSolver::InputType::SourceType::SourceType(std::string filename) { - std::vector point_of_this_path; - // 对于 output 中的每一个点, 检查这个点是否在路径上. 如果在, 把它加入到 point_of_this_path 中 - for (auto& qpoint : output.QPointData) + auto input = std::ifstream(filename, std::ios::binary | std::ios::in); + input.exceptions(std::ios::badbit | std::ios::failbit); + static_assert(sizeof(std::byte) == sizeof(char)); + std::vector data; + { + std::vector string(std::istreambuf_iterator(input), {}); + data.assign + ( + reinterpret_cast(string.data()), + reinterpret_cast(string.data() + string.size()) + ); + } + auto in = zpp::bits::in(data); + UnfoldSolver::OutputType output; + in(output).or_throw(); + static_cast(*this) = std::move(output); + } + + PlotSolver::InputType::InputType(std::string config_file) + { + auto input = YAML::LoadFile(config_file); + for (unsigned i = 0; i < 3; i++) + for (unsigned j = 0; j < 3; j++) + PrimativeCell(i, j) = input["PrimativeCell"][i][j].as(); + for (auto& figure : input["Figures"].as>()) + { + Figures.emplace_back(); + auto qpoints = figure["Qpoints"] + .as>>>(); + for (auto& line : qpoints) + { + Figures.back().Qpoints.emplace_back(); + for (auto& point : line) + Figures.back().Qpoints.back().emplace_back(point.at(0), point.at(1), point.at(2)); + if (Figures.back().Qpoints.back().size() < 2) + throw std::runtime_error("Not enough points in a line"); + } + if (Figures.back().Qpoints.size() < 1) + throw std::runtime_error("Not enough lines in a figure"); + Figures.back().Resolution = figure["Resolution"].as>(); + Figures.back().Range = figure["Range"].as>(); + Figures.back().Filename = figure["Filename"].as(); + } + SourceFilename = input["SourceFilename"].as(); + Source = SourceType(SourceFilename); + } + + PlotSolver::PlotSolver(std::string config_file) : Input_(config_file) {} + + PlotSolver& PlotSolver::operator()() + { + for (auto& figure : Input_.Figures) + { + // 外层表示不同的线段的端点,内层表示这个线段上的 q 点 + std::vector>> qpoints; + std::vector> lines; + for (auto& path : figure.Qpoints) + for (unsigned i = 0; i < path.size() - 1; i++) + { + lines.emplace_back(path[i], path[i + 1]); + qpoints.push_back(search_qpoints + ( + lines.back(), Input_.Source.QPointData, + 0.001, i != path.size() - 2 + )); + } + auto values = calculate_values(lines, qpoints, figure.Resolution, figure.Range); + plot(values, figure.Filename); + } + return *this; + } + + std::vector> PlotSolver::search_qpoints + ( + const std::pair& path, + const decltype(InputType::SourceType::QPointData)& available_qpoints, + double threshold, bool exclude_endpoint + ) + { + std::multimap> selected_qpoints; + // 对于 output 中的每一个点, 检查这个点是否在路径上. 如果在, 把它加入到 selected_qpoints 中 + for (auto& qpoint : available_qpoints) { // 计算三点围成的三角形的面积的两倍 - auto area = (Qpoints[i + 1] - Qpoints[i]).cross(qpoint.QPoint - Qpoints[i]).norm(); + auto area = (path.second - path.first).cross(qpoint.QPoint - path.first).norm(); // 计算这个点到前两个点所在直线的距离 - auto distance = area / (Qpoints[i + 1] - Qpoints[i]).norm(); + auto distance = area / (path.second - path.first).norm(); // 如果这个点到前两个点所在直线的距离小于阈值, 则认为这个点在路径上 - if (distance < Threshold) + if (distance < threshold) { // 计算这个点到前两个点的距离, 两个距离都应该小于两点之间的距离 - auto distance1 = (qpoint.QPoint - Qpoints[i]).norm(); - auto distance2 = (qpoint.QPoint - Qpoints[i + 1]).norm(); - auto distance3 = (Qpoints[i + 1] - Qpoints[i]).norm(); - if (distance1 < distance3 + Threshold && distance2 < distance3 + Threshold) - // 如果这个点在终点处, 且这条路径不是最后一条, 则不加入 - if (distance2 > Threshold || i == Qpoints.size() - 2) - { - auto& _ = point_of_this_path.emplace_back(); - _.QPoint = qpoint.QPoint; - _.Distance = distance1; - _.Frequency.resize(qpoint.ModeData.size()); - _.Weight.resize(qpoint.ModeData.size()); - for (unsigned j = 0; j < qpoint.ModeData.size(); j++) - { - _.Frequency(j) = qpoint.ModeData[j].Frequency; - _.Weight(j) = qpoint.ModeData[j].Weight; - } - } + auto distance1 = (qpoint.QPoint - path.first).norm(); + auto distance2 = (qpoint.QPoint - path.second).norm(); + auto distance3 = (path.second - path.first).norm(); + if (distance1 < distance3 + threshold && distance2 < distance3 + threshold) + // 如果这个点不在终点处, 或者不排除终点, 则加入 + if (distance2 > threshold || !exclude_endpoint) + selected_qpoints.emplace(distance1, std::ref(qpoint)); } } - // 对筛选结果排序 - std::sort(point_of_this_path.begin(), point_of_this_path.end(), - [](const Point& a, const Point& b) { return a.Distance < b.Distance; }); // 去除非常接近的点 - for (unsigned j = 1; j < point_of_this_path.size(); j++) - while - ( - j < point_of_this_path.size() - && point_of_this_path[j].Distance - point_of_this_path[j - 1].Distance < Threshold - ) - point_of_this_path.erase(point_of_this_path.begin() + j); - // 将结果加入 - for (auto& point : point_of_this_path) - Points.emplace_back(point.QPoint, point.Frequency, point.Weight, point.Distance + total_distance); - total_distance += (Qpoints[i + 1] - Qpoints[i]).norm(); + for (auto it = selected_qpoints.begin(); it != selected_qpoints.end();) + { + auto next = std::next(it); + if (next == selected_qpoints.end()) + break; + else if (next->first - it->first < threshold) + selected_qpoints.erase(next); + else + it = next; + } + if (selected_qpoints.empty()) + throw std::runtime_error("No q points found"); + std::vector> result; + for (auto& qpoint : selected_qpoints) + result.push_back(qpoint.second); + return result; } - // 打印结果看一下 - for (auto& point : Points) - std::cout << point.Distance << " " << point.QPoint.transpose() << std::endl; - - // 对结果插值 - std::vector interpolated_points; - for (unsigned i = 0; i < 1024; i++) + std::vector> PlotSolver::calculate_values + ( + const std::vector>& path, + const std::vector>>& qpoints, + const decltype(InputType::FigureConfigType::Resolution)& resolution, + const decltype(InputType::FigureConfigType::Range)& range + ) { - auto current_distance = i * total_distance / 1024; - auto& _ = interpolated_points.emplace_back(); - _.Distance = current_distance; - auto it = std::lower_bound(Points.begin(), Points.end(), current_distance, - [](const Point& a, double b) { return a.Distance < b; }); - // 如果是开头或者结尾, 直接赋值, 否则插值 - if (it == Points.begin()) + // 整理输入 + std::map> qpoints_with_distance; + double total_distance = 0; + for (unsigned i = 0; i < path.size(); i++) { - _.Frequency = Points.front().Frequency; - _.Weight = Points.front().Weight; + for (auto& _ : qpoints[i]) + qpoints_with_distance.emplace(total_distance + (_.get().QPoint - path[i].first).norm(), _); + total_distance += (path[i].second - path[i].first).norm(); } - else if (it == Points.end() - 1) - { - _.Frequency = Points.back().Frequency; - _.Weight = Points.back().Weight; - } - else - { - _.Frequency = (it->Frequency * (it->Distance - current_distance) - + (it - 1)->Frequency * (current_distance - (it - 1)->Distance)) / (it->Distance - (it - 1)->Distance); - _.Weight = (it->Weight * (it->Distance - current_distance) - + (it - 1)->Weight * (current_distance - (it - 1)->Distance)) / (it->Distance - (it - 1)->Distance); - } - } - // 将结果对应到像素上的值 - std::vector> weight(1024, std::vector(400, 0)); - for (auto& point : interpolated_points) + // 插值 + std::vector> values; + auto blend = [] + ( + const UnfoldSolver::OutputType::QPointDataType& a, + const UnfoldSolver::OutputType::QPointDataType& b, + double ratio, unsigned resolution, std::pair range + ) -> std::vector + { + // 计算插值结果 + std::vector frequency, weight; + for (unsigned i = 0; i < a.ModeData.size(); i++) + { + frequency.push_back(a.ModeData[i].Frequency * ratio + b.ModeData[i].Frequency * (1 - ratio)); + weight.push_back(a.ModeData[i].Weight * ratio + b.ModeData[i].Weight * (1 - ratio)); + } + std::vector result(resolution); + for (unsigned i = 0; i < frequency.size(); i++) + { + int index = (frequency[i] - range.first) / (range.second - range.first) * resolution; + if (index >= 0 && index < static_cast(resolution)) + result[index] += weight[i]; + } + return result; + }; + for (unsigned i = 0; i < resolution.first; i++) + { + auto current_distance = total_distance * i / resolution.first; + auto it = qpoints_with_distance.lower_bound(current_distance); + if (it == qpoints_with_distance.begin()) + values.push_back(blend(it->second.get(), it->second.get(), 1, resolution.second, range)); + else if (it == qpoints_with_distance.end()) + values.push_back(blend(std::prev(it)->second.get(), std::prev(it)->second.get(), 1, resolution.second, + range)); + else + values.push_back(blend + ( + std::prev(it)->second.get(), it->second.get(), + (current_distance - std::prev(it)->first) / (it->first - std::prev(it)->first), + resolution.second, range) + ); + } + return values; + } + void PlotSolver::plot + ( + const std::vector>& values, + const decltype(InputType::FigureConfigType::Filename)& filename + ) { - int x = point.Distance / total_distance * 1024; - if (x < 0) - x = 0; - else if (x >= 1024) - x = 1023; - for (unsigned i = 0; i < point.Frequency.size(); i++) - { - auto y = (point.Frequency(i) + 5) * 10; - if (y < 0) - y = 0; - else if (y >= 400) - y = 399; - weight[x][y] += point.Weight(i); - } + std::vector> + r(values[0].size(), std::vector(values.size(), 0)), + g(values[0].size(), std::vector(values.size(), 0)), + b(values[0].size(), std::vector(values.size(), 0)); + for (unsigned i = 0; i < values[0].size(); i++) + for (unsigned j = 0; j < values.size(); j++) + { + r[i][j] = 255; + g[i][j] = 255 - values[j][i] * 2 * 255; + if (g[i][j] < 0) + g[i][j] = 0; + b[i][j] = 255 - values[j][i] * 2 * 255; + if (b[i][j] < 0) + b[i][j] = 0; + } + auto f = matplot::figure(true); + auto ax = f->current_axes(); + ax->image(std::tie(r, g, b)); + ax->y_axis().reverse(false); + f->save(filename); } - - std::ofstream fout("weight.txt"); - for (unsigned i = 0; i < 400; i++) - fout << fmt::format(" {:.6f}", i * 0.1 - 5); - fout << std::endl; - for (unsigned i = 0; i < 1024; i++) - { - fout << fmt::format("{:.6f} ", total_distance / 1024 * i); - for (unsigned j = 0; j < 400; j++) - fout << fmt::format("{:.6f} ", weight[i][j]); - fout << std::endl; - } - - std::vector> - r(400, std::vector(1024, 0)), - g(400, std::vector(1024, 0)), - b(400, std::vector(1024, 0)); - for (unsigned i = 0; i < 400; i++) - for (unsigned j = 0; j < 1024; j++) - { - r[i][j] = 255; - g[i][j] = 255 - weight[j][i] * 2 * 255; - if (g[i][j] < 0) - g[i][j] = 0; - b[i][j] = 255 - weight[j][i] * 2 * 255; - if (b[i][j] < 0) - b[i][j] = 0; - } - auto f = matplot::figure(true); - auto ax = f->current_axes(); - ax->image(std::tie(r, g, b)); - ax->y_axis().reverse(false); - f->show(); } diff --git a/src/unfold.cpp b/src/unfold.cpp index af60b50..c489514 100644 --- a/src/unfold.cpp +++ b/src/unfold.cpp @@ -478,21 +478,3 @@ namespace ufo return output; } } - -// inline Output::Output(std::string filename) -// { -// auto input = std::ifstream(filename, std::ios::binary | std::ios::in); -// input.exceptions(std::ios::badbit | std::ios::failbit); -// std::vector data; -// { -// std::vector string(std::istreambuf_iterator(input), {}); -// data.assign -// ( -// reinterpret_cast(string.data()), -// reinterpret_cast(string.data() + string.size()) -// ); -// } -// auto in = zpp::bits::in(data); -// in(*this).or_throw(); -// } -