Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: EleutherAI/elk
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 6803fc5f4157678da66179edb04afcd8d8a2eab3
Choose a base ref
..
head repository: EleutherAI/elk
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 5ad74e9c86b31cc760a3075388cb6d5306ae0050
Choose a head ref
Showing with 9 additions and 11 deletions.
  1. +1 −1 .pre-commit-config.yaml
  2. +0 −1 elk/extraction/extraction.py
  3. +3 −3 elk/plotting/visualize.py
  4. +3 −5 elk/promptsource/templates.py
  5. +2 −1 elk/training/platt_scaling.py
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.9.9'
rev: 'v0.9.10'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
1 change: 0 additions & 1 deletion elk/extraction/extraction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Functions for extracting the hidden states of a model."""

import os
from collections import defaultdict
from dataclasses import InitVar, dataclass, replace
6 changes: 3 additions & 3 deletions elk/plotting/visualize.py
Original file line number Diff line number Diff line change
@@ -78,9 +78,9 @@ def render(
y=dataset_data["auroc_estimate"],
mode="lines",
name=ensemble,
showlegend=(
False if dataset_name != unique_datasets[0] else True
),
showlegend=False
if dataset_name != unique_datasets[0]
else True,
line=dict(color=color_map[ensemble]),
),
row=row,
8 changes: 3 additions & 5 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
@@ -215,11 +215,9 @@ def _escape_pipe(cls, example):
# Replaces any occurrences of the "|||" separator in the example, which
# which will be replaced back after splitting
protected_example = {
key: (
value.replace("|||", cls.pipe_protector)
if isinstance(value, str)
else value
)
key: value.replace("|||", cls.pipe_protector)
if isinstance(value, str)
else value
for key, value in example.items()
}
return protected_example
3 changes: 2 additions & 1 deletion elk/training/platt_scaling.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,8 @@ class PlattMixin(ABC):
scale: nn.Parameter

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any: ...
def __call__(self, *args: Any, **kwds: Any) -> Any:
...

def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100):
"""Fit the scale and bias terms to data with LBFGS.