packages.ufo: fix

This commit is contained in:
2024-09-13 14:56:51 +08:00
parent 9baa47e5b0
commit 317f94a875
6 changed files with 156 additions and 5 deletions

View File

@@ -118,9 +118,16 @@ namespace biu
template<> std::string read<char>(const std::filesystem::path& path);
template<> std::vector<std::byte> read<std::byte>(std::istream& input);
template<> std::string read<char>(std::istream& input);
namespace detail_
{
struct ToLvalueHelper {};
template <typename T> T& operator|(T&& obj, const ToLvalueHelper&);
}
constexpr detail_::ToLvalueHelper toLvalue;
}
using common::hash, common::unused, common::block_forever, common::is_interactive, common::env, common::int128_t,
common::uint128_t, common::Empty, common::CaseInsensitiveStringLessComparator, common::RemoveMemberPointer,
common::MoveQualifiers, common::FallbackIfNoTypeDeclared, common::exec, common::serialize, common::deserialize,
common::sequence, common::read;
common::sequence, common::read, common::toLvalue;
}

View File

@@ -70,4 +70,5 @@ namespace biu::common
for (std::size_t i = 0; i < from.size(); i++) from[i] = 0;
return sequence(from, to);
}
template <typename T> T& detail_::operator|(T&& obj, const ToLvalueHelper&) { return static_cast<T&>(obj); }
}

View File

@@ -5,4 +5,9 @@ int main()
using namespace biu::literals;
for (auto [a, b] : biu::sequence(std::array{2, 2, 2}))
std::cout << "{} {}\n"_f(a, b);
std::optional<std::vector<int>> a;
auto b = a.value_or(std::vector<int>{1, 2, 3})
| biu::toLvalue
| ranges::views::transform([](int i){ return i + 1; })
| ranges::to_vector;
}

View File

@@ -12,9 +12,11 @@ namespace ufo
void fold(std::string config_file);
void unfold(std::string config_file);
void plot(std::string config_file);
void plot_band(std::string config_file);
void plot_point(std::string config_file);
// unfold 和 plot 都需要用到这个,所以写出来
// TODO: 把输入的数据也保留进来
struct UnfoldOutput
{
Eigen::Matrix3d PrimativeCell;

View File

@@ -8,8 +8,10 @@ int main(int argc, const char** argv)
ufo::fold(argv[2]);
else if (argv[1] == std::string("unfold"))
ufo::unfold(argv[2]);
else if (argv[1] == std::string("plot"))
ufo::plot(argv[2]);
else if (argv[1] == std::string("plot-band"))
ufo::plot_band(argv[2]);
else if (argv[1] == std::string("plot-point"))
ufo::plot_point(argv[2]);
else
throw std::runtime_error(fmt::format("Unknown task: {}", argv[1]));
}

View File

@@ -2,7 +2,7 @@
# include <matplot/matplot.h>
# include <boost/container/flat_map.hpp>
void ufo::plot(std::string config_file)
void ufo::plot_band(std::string config_file)
{
struct Input
{
@@ -258,3 +258,137 @@ void ufo::plot(std::string config_file)
.write("Resolution", std::vector{input.Resolution.X, input.Resolution.Y})
.write("Range", std::vector{input.FrequencyRange.Min, input.FrequencyRange.Max});
}
void ufo::plot_point(std::string config_file)
{
struct Input
{
std::string UnfoldedDataFile;
// 要画图的 q 点
Eigen::Vector3d Qpoint;
// x 方向为频率y 方向没有用
struct { std::size_t X, Y; } Resolution;
// 画图的频率范围
struct { double Min, Max; } FrequencyRange;
// 搜索 q 点时的阈值,单位为埃^-1
std::optional<double> ThresholdWhenSearchingQpoints;
// 是否要在 z 轴上作一些标记
std::optional<std::vector<double>> XTicks;
// 是否输出图片
std::optional<std::string> OutputPictureFile;
// 是否输出数据,可以进一步使用 matplotlib 画图
std::optional<std::string> OutputDataFile;
};
// 根据 q 点路径, 搜索要使用的 q 点,返回的是 q 点在 QpointData 中的索引
auto search_qpoints = []
(
const Eigen::Matrix3d& primative_cell,
const Eigen::Vector3d& qpoint, const std::vector<Eigen::Vector3d>& qpoints,
double threshold
)
{
biu::Logger::Guard log(qpoint);
// 对于 output 中的每一个点, 检查这个点是否与所寻找的点足够近,如果足够近则返回
for (std::size_t i = 0; i < qpoints.size(); i++)
for (auto cell_shift
: biu::sequence(Eigen::Vector3i(-1, -1, -1), Eigen::Vector3i(2, 2, 2)))
{
auto this_qpoint
= (primative_cell.reverse().transpose() * (qpoints[i] + cell_shift.first.cast<double>())).eval();
if ((this_qpoint - primative_cell.reverse().transpose() * qpoint).norm() < threshold) return log.rtn(i);
}
throw std::runtime_error("No q points found");
};
// 根据搜索到的 q 点, 计算图中每个点的值
auto calculate_values = []
(
// q 点的数据(需要用到它的频率和权重)
const UnfoldOutput::QpointDataType& qpoint,
// 用于插值的分辨率和范围
unsigned resolution,
const std::pair<double, double>& frequency_range
)
{
biu::Logger::Guard log;
std::vector<double> result(resolution);
for (auto& mode : qpoint.ModeData)
{
int index = mode.Frequency - frequency_range.first / (frequency_range.second - frequency_range.first)
* resolution;
if (index >= 0 && index < static_cast<int>(resolution)) result[index] += mode.Weight;
}
return log.rtn(result);
};
// 根据数值, 画图
auto plot = []
(
const std::vector<double>& values, const std::string& filename,
const std::vector<double>& x_ticks, unsigned y_resolution
)
{
biu::Logger::Guard log;
std::vector<std::vector<double>>
r(y_resolution, std::vector<double>(values.size(), 0)),
g(y_resolution, std::vector<double>(values.size(), 0)),
b(y_resolution, std::vector<double>(values.size(), 0)),
a(y_resolution, std::vector<double>(values.size(), 0));
for (unsigned i = 0; i < y_resolution; i++) for (unsigned j = 0; j < values.size(); j++)
{
auto v = values[j];
if (v < 0.05) v = 0;
a[i][j] = v * 100 * 255;
if (a[i][j] > 255) a[i][j] = 255;
r[i][j] = 255 - v * 2 * 255;
if (r[i][j] < 0) r[i][j] = 0;
g[i][j] = 255 - v * 2 * 255;
if (g[i][j] < 0) g[i][j] = 0;
b[i][j] = 255;
}
auto f = matplot::figure(true);
auto ax = f->current_axes();
auto image = ax->image(std::tie(r, g, b));
image->matrix_a(a);
ax->y_axis().reverse(false);
ax->x_axis().tick_values(x_ticks);
ax->x_axis().tick_length(1);
f->save(filename, "png");
};
biu::Logger::Guard log;
auto input = YAML::LoadFile(config_file).as<Input>();
auto unfolded_data = biu::deserialize<UnfoldOutput>
(biu::read<std::byte>(input.UnfoldedDataFile));
auto qpoint_index = search_qpoints
(
unfolded_data.PrimativeCell, input.Qpoint,
unfolded_data.QpointData
| ranges::views::transform(&UnfoldOutput::QpointDataType::Qpoint)
| ranges::to_vector,
input.ThresholdWhenSearchingQpoints.value_or(0.001)
);
auto values = calculate_values
(
unfolded_data.QpointData[qpoint_index],
input.Resolution.X, {input.FrequencyRange.Min, input.FrequencyRange.Max}
);
auto x_ticks = input.XTicks.value_or(std::vector<double>{})
| biu::toLvalue
| ranges::views::transform([&](auto i)
{
return (i - input.FrequencyRange.Min) / (input.FrequencyRange.Max - input.FrequencyRange.Min)
* input.Resolution.X;
})
| ranges::to_vector;
if (input.OutputPictureFile)
plot(values, input.OutputPictureFile.value(), x_ticks, input.Resolution.Y);
if (input.OutputDataFile)
biu::Hdf5file(input.OutputDataFile.value(), true)
.write("Values", values)
.write("XTicks", x_ticks)
.write("Resolution", std::vector{input.Resolution.X, input.Resolution.Y})
.write("Range", std::vector{input.FrequencyRange.Min, input.FrequencyRange.Max});
}