Skip to content

Commit 917002c

Browse files
authored
fix: round quantiles when target is integer (#7)
1 parent 91e0c4d commit 917002c

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

README.md

+10-10
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ When the input data is a pandas DataFrame, the output is also a pandas DataFrame
5151

5252
| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 |
5353
|-----------:|--------:|-------:|-------:|-------:|-------:|--------:|
54-
| 1357 | 114783 | 120894 | 131618 | 175760 | 188051 | 205448 |
55-
| 2367 | 67416 | 80073 | 86753 | 117854 | 127582 | 142321 |
56-
| 2822 | 119422 | 132047 | 138724 | 178526 | 197246 | 214205 |
57-
| 2126 | 94030 | 99849 | 110891 | 150249 | 164703 | 182528 |
58-
| 1544 | 68996 | 81516 | 88231 | 121774 | 132425 | 147110 |
54+
| 1357 | 114784 | 120894 | 131618 | 175761 | 188052 | 205449 |
55+
| 2367 | 67417 | 80074 | 86754 | 117854 | 127583 | 142322 |
56+
| 2822 | 119423 | 132048 | 138725 | 178526 | 197246 | 214206 |
57+
| 2126 | 94031 | 99850 | 110891 | 150249 | 164703 | 182528 |
58+
| 1544 | 68996 | 81516 | 88232 | 121774 | 132425 | 147110 |
5959

6060
Let's visualize the predicted quantiles on the test set:
6161

@@ -116,11 +116,11 @@ When the input data is a pandas DataFrame, the output is also a pandas DataFrame
116116

117117
| house_id | 0.025 | 0.975 |
118118
|-----------:|--------:|--------:|
119-
| 1357 | 107202 | 206290 |
120-
| 2367 | 66665 | 146004 |
121-
| 2822 | 115591 | 220314 |
122-
| 2126 | 85288 | 183037 |
123-
| 1544 | 67889 | 150646 |
119+
| 1357 | 107203 | 206290 |
120+
| 2367 | 66665 | 146005 |
121+
| 2822 | 115592 | 220315 |
122+
| 2126 | 85288 | 183038 |
123+
| 1544 | 67890 | 150646 |
124124

125125
## Contributing
126126

src/conformal_tights/_conformal_coherent_quantile_regressor.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,10 @@ def predict_quantiles(
275275
Δŷ_quantiles = Δŷ_quantiles[
276276
np.arange(Δŷ_quantiles.shape[0]), :, np.argmin(dispersion, axis=-1)
277277
]
278-
ŷ_quantiles: FloatMatrix[F] = (ŷ[:, np.newaxis] + Δŷ_quantiles).astype(self.y_dtype_)
278+
ŷ_quantiles: FloatMatrix[F] = ŷ[:, np.newaxis] + Δŷ_quantiles
279+
if self.y_is_integer_:
280+
ŷ_quantiles = np.round(ŷ_quantiles)
281+
ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_)
279282
# Convert ŷ_quantiles to a pandas DataFrame if X is a pandas DataFrame.
280283
if hasattr(X, "dtypes") and hasattr(X, "index"):
281284
try:

0 commit comments

Comments
 (0)