File tree 1 file changed +4
-8
lines changed
1 file changed +4
-8
lines changed Original file line number Diff line number Diff line change @@ -954,18 +954,14 @@ def forward(
954
954
955
955
layer_kv_cache_dict = {}
956
956
for b_idx , block in enumerate (self .blocks ):
957
- # Added some assert statements
958
- assert isinstance (block , torch .nn .Module )
959
- assert isinstance (block .norm_attn_norm , torch .nn .Module )
960
- attn_block = block .norm_attn_norm .attn if self .blocks_fuse_norm_attn_norm else block .attn
961
- assert isinstance (attn_block , torch .nn .Module )
957
+ attn_block = block .norm_attn_norm .attn if self .blocks_fuse_norm_attn_norm else block .attn # type: ignore
962
958
if attn_block .reuse_kv_layer_idx is not None : # type: ignore
963
- if attn_block .reuse_kv_layer_idx not in layer_kv_cache_dict :
959
+ if attn_block .reuse_kv_layer_idx not in layer_kv_cache_dict : # type: ignore
964
960
raise KeyError (
965
- f'kv cache for layer { block .reuse_kv_layer_idx } not found in { layer_kv_cache_dict = } .' ,
961
+ f'kv cache for layer { block .reuse_kv_layer_idx } not found in { layer_kv_cache_dict = } .' , # type: ignore
966
962
)
967
963
prev_layer_key_value = layer_kv_cache_dict [
968
- attn_block .reuse_kv_layer_idx ]
964
+ attn_block .reuse_kv_layer_idx ] # type: ignore
969
965
else :
970
966
prev_layer_key_value = None
971
967
if output_hidden_states :
You can’t perform that action at this time.
0 commit comments