We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ef21071 commit 1ece941 (Copy full SHA for 1ece941)
crossfit/backend/torch/hf/memory_curve_utils.py
@@ -12,13 +12,14 @@
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
+from __future__ import annotations
16
17
import joblib
18
import numpy as np
19
import torch
20
+import transformers
21
from sklearn.linear_model import LinearRegression
22
from tqdm import tqdm
-from transformers import PreTrainedModel
23
24
from crossfit.utils.model_adapter import adapt_model_input
25
from crossfit.utils.torch_utils import (
@@ -29,7 +30,7 @@
29
30
31
32
def fit_memory_estimate_curve(
- model: PreTrainedModel,
33
+ model: "transformers.PreTrainedModel",
34
path_or_name: str,
35
start_batch_size: int = 1,
36
end_batch_size: int = 2048,
0 commit comments