Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
hualxie committed Feb 21, 2025
1 parent 0995c8d commit 93ae4c4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
File renamed without changes.
28 changes: 15 additions & 13 deletions examples/bge/user_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ def __init__(self, model, session):
self.model = model
self.session = session
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
self.total = 0
self.max_len = 0

def encode(self, corpus: List, **kwargs):
if self.model.framework == Framework.ONNX:
Expand All @@ -44,28 +42,30 @@ def encode(self, corpus: List, **kwargs):
"attention_mask": encoded_input.attention_mask,
"token_type_ids": encoded_input.token_type_ids,
}
self.max_len = max(self.max_len, model_inputs["input_ids"].shape[1])
print(self.max_len)
with torch.no_grad():
model_output = self.model.run_session(self.session, model_inputs)
model_output = model_output.last_hidden_state.numpy()
# select the last hidden state of the first token (i.e., [CLS]) as the sentence embedding.
model_output = model_output[:, 0, :]
self.total += len(corpus)
print(self.total)
return model_output
return model_output[:, 0, :]

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'model_output' may be used before it is initialized.


def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
sess = model.prepare_session(inference_settings=None, device=device, execution_providers=execution_providers)

evaluation = mteb.MTEB(tasks=tasks)
oliveEncoder = OliveEncoder(model, sess)
results = evaluation.run(oliveEncoder, output_folder=None)
olive_encoder = OliveEncoder(model, sess)
results = evaluation.run(olive_encoder, output_folder=None)
return results[0].scores["test"][0]["main_score"]


if __name__ == "__main__":
import logging
import sys

logger = logging.getLogger("bge")
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.INFO)

# Greedy search for the best combination of ops to quantize
all_ops = [
"Mul",
Expand All @@ -87,15 +87,17 @@ def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
target_accuracy = 0.8
with Path("bge-small-en-v1.5.json").open() as fin:
olive_config = json.load(fin)
for i, op in enumerate(all_ops):
for op in all_ops:
if op in olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]:
continue
olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].append(op)
result = olive_run(olive_config)
footprint: Footprint = next(iter(result.values()))
node: FootprintNode = next(iter(footprint.nodes.values()))
accuracy = node.metrics.value["accuracy-accuracy_custom"].value
print(f"Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]} Accuracy: {accuracy}")
logger.info(
"Ops: %s Accuracy: %f", olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"], accuracy
)
if accuracy < target_accuracy:
olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].remove(op)
print(f"Used Ops: {olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]}")
logger.info("Final Ops: %s", olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"])

0 comments on commit 93ae4c4

Please sign in to comment.