Skip to content

Commit 7c0d60a

Browse files
committed
Implement std::formatter for enums
1 parent 1e160c5 commit 7c0d60a

File tree

3 files changed

+98
-31
lines changed

3 files changed

+98
-31
lines changed

argument_parser.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ namespace partdiff {
128128
this->add_argument_description(
129129
"method", method,
130130
std::format("calculation method (1 .. 2)\n"
131-
"{0}{1:d}: Gauß-Seidel\n"
132-
"{0}{2:d}: Jacobi",
133-
indent, to_underlying(calculation_method::gauss_seidel), to_underlying(calculation_method::jacobi)),
131+
"{0}{1:d}: {1:s}\n"
132+
"{0}{2:d}: {2:s}",
133+
indent, calculation_method::gauss_seidel, calculation_method::jacobi),
134134
[method] { return (*method == calculation_method::gauss_seidel || *method == calculation_method::jacobi); });
135135

136136
auto interlines = &(this->options.interlines);
@@ -141,27 +141,26 @@ namespace partdiff {
141141
[interlines] { return (interlines_bounds.contains(*interlines)); });
142142

143143
auto pert_func = &(this->options.pert_func);
144-
this->add_argument_description(
145-
"func", pert_func,
146-
std::format("perturbation function (1 .. 2)\n"
147-
"{0}{1:d}: f(x,y) = 0\n"
148-
"{0}{2:d}: f(x,y) = 2 * pi^2 * sin(pi * x) * sin(pi * y)",
149-
indent, to_underlying(perturbation_function::f0), to_underlying(perturbation_function::fpisin)),
150-
[pert_func] {
151-
return (*pert_func == perturbation_function::f0 || *pert_func == perturbation_function::fpisin);
152-
});
144+
this->add_argument_description("func", pert_func,
145+
std::format("perturbation function (1 .. 2)\n"
146+
"{0}{1:d}: {1:s}\n"
147+
"{0}{2:d}: {2:s}",
148+
indent, perturbation_function::f0, perturbation_function::fpisin),
149+
[pert_func] {
150+
return (*pert_func == perturbation_function::f0 ||
151+
*pert_func == perturbation_function::fpisin);
152+
});
153153

154154
auto termination = &(this->options.termination);
155-
this->add_argument_description("term", termination,
156-
std::format("termination condition ( 1.. 2)\n"
157-
"{0}{1:d}: sufficient accuracy\n"
158-
"{0}{2:d}: number of iterations",
159-
indent, to_underlying(termination_condition::accuracy),
160-
to_underlying(termination_condition::iterations)),
161-
[termination] {
162-
return (*termination == termination_condition::accuracy ||
163-
*termination == termination_condition::iterations);
164-
});
155+
this->add_argument_description(
156+
"term", termination,
157+
std::format("termination condition ( 1.. 2)\n"
158+
"{0}{1:d}: {1:s}\n"
159+
"{0}{2:d}: {2:s}",
160+
indent, termination_condition::accuracy, termination_condition::iterations),
161+
[termination] {
162+
return (*termination == termination_condition::accuracy || *termination == termination_condition::iterations);
163+
});
165164

166165
this->add_argument_description("acc/iter", std::format("depending on term:\n"
167166
"{0}accuracy: {1:.0e}\n"

enums.hpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,84 @@
11
#pragma once
22

3+
#include <array>
4+
#include <concepts>
35
#include <cstdint>
6+
#include <format>
7+
#include <iostream>
8+
#include <stdexcept>
9+
#include <string_view>
10+
#include <type_traits>
11+
#include <utility>
412

513
namespace partdiff {
614
enum class calculation_method : uint64_t { gauss_seidel = 1, jacobi = 2 };
715
enum class perturbation_function : uint64_t { f0 = 1, fpisin = 2 };
816
enum class termination_condition : uint64_t { accuracy = 1, iterations = 2 };
917

1018
} // namespace partdiff
19+
20+
template <typename Enum>
21+
struct enum_member_names;
22+
23+
template <>
24+
struct enum_member_names<partdiff::calculation_method> {
25+
static constexpr std::array names = {
26+
std::pair{partdiff::calculation_method::gauss_seidel, "Gauß-Seidel"},
27+
std::pair{partdiff::calculation_method::jacobi, "Jacobi"},
28+
};
29+
};
30+
31+
template <>
32+
struct enum_member_names<partdiff::perturbation_function> {
33+
static constexpr std::array names = {
34+
std::pair{partdiff::perturbation_function::f0, "f(x,y) = 0"},
35+
std::pair{partdiff::perturbation_function::fpisin, "f(x,y) = 2 * pi^2 * sin(pi * x) * sin(pi * y)"},
36+
};
37+
};
38+
39+
template <>
40+
struct enum_member_names<partdiff::termination_condition> {
41+
static constexpr std::array names = {
42+
std::pair{partdiff::termination_condition::accuracy, "Required accuracy"},
43+
std::pair{partdiff::termination_condition::iterations, "Number of iterations"},
44+
};
45+
};
46+
47+
// Some template magic to print the above enum classes with "{:d}" and "{:s}"
48+
49+
template <typename Enum>
50+
requires std::is_enum_v<Enum> && requires { enum_member_names<Enum>::names; }
51+
struct std::formatter<Enum> : std::formatter<std::string_view> {
52+
enum class mode { string, number } fmt_mode = mode::string;
53+
constexpr auto parse(std::format_parse_context &ctx) {
54+
auto it = ctx.begin();
55+
auto end = ctx.end();
56+
57+
if (it != end) {
58+
if (*it == 's') {
59+
fmt_mode = mode::string;
60+
++it;
61+
} else if (*it == 'd') {
62+
fmt_mode = mode::number;
63+
++it;
64+
} else if (*it != '}')
65+
throw std::format_error("invalid format specifier for enum");
66+
}
67+
return it;
68+
}
69+
70+
template <class FormatContext>
71+
auto format(Enum value, FormatContext &ctx) const {
72+
if (fmt_mode == mode::number) {
73+
return std::format_to(ctx.out(), "{}", static_cast<std::underlying_type_t<Enum>>(value));
74+
}
75+
std::string_view name = "unknown";
76+
for (auto &&[v, n] : enum_member_names<Enum>::names) {
77+
if (v == value) {
78+
name = n;
79+
break;
80+
}
81+
}
82+
return std::format_to(ctx.out(), "{}", name);
83+
}
84+
};

partdiff.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,13 @@ namespace partdiff {
9595
const int N = arguments.N;
9696
const double time = std::chrono::duration<double>(results.end_time - results.start_time).count();
9797
const double memory_consumption = (N + 1) * (N + 1) * sizeof(double) * arguments.num_matrices / 1024.0 / 1024.0;
98-
const std::string_view calculation_method_display =
99-
options.method == calculation_method::gauss_seidel ? "Gauß-Seidel" : "Jacobi";
100-
const std::string_view perturbation_function_display =
101-
options.pert_func == perturbation_function::f0 ? "f(x,y) = 0" : "f(x,y) = 2 * pi^2 * sin(pi * x) * sin(pi * y)";
102-
const std::string_view termination_display =
103-
options.termination == termination_condition::accuracy ? "Required accuracy" : "Number of iterations";
10498

10599
std::println("Calculation time: {:0.6f} s", time);
106100
std::println("Memory usage: {:0.6f} MiB", memory_consumption);
107-
std::println("Calculation method: {:s}", calculation_method_display);
101+
std::println("Calculation method: {:s}", options.method);
108102
std::println("Interlines: {:d}", options.interlines);
109-
std::println("Perturbation function: {:s}", perturbation_function_display);
110-
std::println("Termination: {:s}", termination_display);
103+
std::println("Perturbation function: {:s}", options.pert_func);
104+
std::println("Termination: {:s}", options.termination);
111105
std::println("Number of iterations: {:d}", results.stat_iteration);
112106
std::println("Residuum: {:e}", results.stat_accuracy);
113107
}

0 commit comments

Comments
 (0)