Skip to content

Commit 1ece941

Browse files
committed
Fix cuda context bug caused by PreTrainedModel
Signed-off-by: Vibhu Jawa <[email protected]>
1 parent ef21071 commit 1ece941

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

crossfit/backend/torch/hf/memory_curve_utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
1516

1617
import joblib
1718
import numpy as np
1819
import torch
20+
import transformers
1921
from sklearn.linear_model import LinearRegression
2022
from tqdm import tqdm
21-
from transformers import PreTrainedModel
2223

2324
from crossfit.utils.model_adapter import adapt_model_input
2425
from crossfit.utils.torch_utils import (
@@ -29,7 +30,7 @@
2930

3031

3132
def fit_memory_estimate_curve(
32-
model: PreTrainedModel,
33+
model: "transformers.PreTrainedModel",
3334
path_or_name: str,
3435
start_batch_size: int = 1,
3536
end_batch_size: int = 2048,

0 commit comments

Comments
 (0)