Perplexity colors extension updates #6764

Open · wants to merge 1 commit into base: dev
80 changes: 55 additions & 25 deletions extensions/perplexity_colors/script.py
@@ -96,32 +96,50 @@ def logits_processor_modifier(logits_processor_list, input_ids):
logits_processor_list.append(ppl_logits_processor)


def get_last_token(text, tokens_list, token_ids_list, token_probs_list):
for token, token_id, prob in zip(tokens_list, token_ids_list, token_probs_list):
if text.strip().endswith(token.strip()): # Whitespace could be a problem
return token, token_id, prob
# Unknown?
print("Last token not found in list:", tokens_list)
return '', -1, 0.0


def output_modifier(text):
global ppl_logits_processor
#t0 = time.time()
original_text = text

if not params['active'] or ppl_logits_processor is None:
return text

# Space at the beginning to account for tokenization spaces...
text = ' ' + html.unescape(text)

# TODO: It's probably more efficient to do this above rather than modifying all these lists
# Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation
perplexities = ppl_logits_processor.perplexities_list[:-1]
top_token_ids_list = ppl_logits_processor.top_token_ids_list[:-1]
perplexities = ppl_logits_processor.perplexities_list
top_token_ids_list = ppl_logits_processor.top_token_ids_list
top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids[0]] for top_token_ids in top_token_ids_list]
top_probs_list = ppl_logits_processor.top_probs_list[:-1]
top_probs_list = ppl_logits_processor.top_probs_list
# Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt
gen_token_ids = ppl_logits_processor.generated_token_ids[1:]
# Add last sampled token, if possible (it could be past the end of the top 5 list)
last_token, last_token_id, last_prob = get_last_token(text, top_tokens_list[-1], top_token_ids_list[-1][0], top_probs_list[-1][0])
if last_token_id != -1:
gen_token_ids.append(last_token_id)
gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids]
sel_probs = ppl_logits_processor.selected_probs[1:]
if last_token_id != -1:
sel_probs.append(last_prob)

end_part = '</div></div>' if params['probability_dropdown'] else '</span>' # Helps with finding the index after replacing part of the text.

# Initial space added to deal with some tokenizers...
# Used to find where the message started generating, for working with "continue" generations
# Doesn't work for longer messages... Not sure how I should handle this
full_msg = shared.tokenizer.decode([token_id for token_id in gen_token_ids[:-1]]).strip()
# Space at the beginning to account for tokenization spaces...
text = ' ' + html.unescape(text)

# There was an issue with tab lengths being off by one...
# Seems like it might be model-dependent...
#text = re.sub(r'( {3,})', r'\1 ', text)
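Side note on the hunk above: it keeps the perplexity and top-token lists at full length and instead recovers the final sampled token by matching the tail of the rendered text against the last set of top candidates. A minimal standalone sketch of that suffix match, with made-up tokens and probabilities rather than the extension's real tokenizer output:

def find_last_token(text, tokens, token_ids, probs):
    # Compare stripped strings because detokenized candidates may carry
    # leading/trailing whitespace that the rendered text does not.
    for token, token_id, prob in zip(tokens, token_ids, probs):
        if text.strip().endswith(token.strip()):
            return token, token_id, prob
    return '', -1, 0.0  # sampled token was not among the stored candidates

text = ' The capital of France is Paris'
top_tokens = [' Paris', ' Lyon', ' Nice']   # hypothetical top candidates
top_ids = [3681, 9771, 12095]               # hypothetical token ids
top_probs = [0.91, 0.05, 0.01]
print(find_last_token(text, top_tokens, top_ids, top_probs))
# -> (' Paris', 3681, 0.91)

If no candidate matches, the fallback of (-1, 0.0) means the extension simply skips appending a last token, as the checks on last_token_id above show.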
@@ -137,6 +155,7 @@ def output_modifier(text):
#i = 0
# Add token index for ability to regenerate from there
nonwhitespace_token_found = False
missing_token_count = 0
for index, token, prob, ppl, top_tokens, top_probs in zip(range(len(gen_tokens)), gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
# Somehow this works without issues, but not sure how...
if not nonwhitespace_token_found and token.strip() == '':
@@ -153,14 +172,20 @@ def output_modifier(text):
color = probability_color_scale(prob)
if token.strip() in text[i:]:
if params['probability_dropdown']:
text = text[:i] + text[i:].replace(token.replace('\n', ''), add_dropdown_html(token, index, color, top_tokens, top_probs[0], ppl), 1)
text = text[:i] + text[i:].replace(token.replace('\n', ''), add_dropdown_html(token, index, i, color, top_tokens, top_probs[0], ppl), 1)
else:
text = text[:i] + text[i:].replace(token.replace('\n', ''), add_color_html(token, color), 1)

# This might be slightly inefficient
i += text[i:].find(end_part) + len(end_part)
else:
missing_token_count += 1
print('Missing token:', token, '...', text[i:i+20])
# If there are any missing tokens, then either the tokenization was off, or this is the start of a conversation, or something else went wrong
if missing_token_count > 5:
print("Canceling token coloring...")
return original_text


# Use full perplexity list for calculating the average here.
# Fix issue with mean of empty slice
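For orientation, the replacement loop above walks the rendered text with a cursor, substitutes the first occurrence of each generated token after the cursor with its colored markup, and with this change gives up and returns the original text once more than five tokens fail to line up. A toy sketch of that cursor-advance pattern, with the real wrapper HTML and newline handling simplified away:

def colorize(text, tokens, end_part='</span>'):
    i = 0          # cursor: everything before i has already been wrapped
    missing = 0
    for token in tokens:
        if token.strip() in text[i:]:
            piece = f'<span>{token}</span>'                  # stand-in for the real markup
            text = text[:i] + text[i:].replace(token, piece, 1)
            i += text[i:].find(end_part) + len(end_part)     # jump past this token's markup
        else:
            missing += 1
            if missing > 5:                                  # hopelessly out of sync: abort
                return None
    return text

print(colorize(' Hello world', [' Hello', ' world']))
# -> '<span> Hello</span><span> world</span>'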
@@ -236,11 +261,11 @@ def add_color_html(token, color):
# I think the issue is from HTML elements taking up space in the visible history, and things like history deepcopy add latency proportional to the size of the history.
# Potential solution is maybe to modify the main generation code to send just the internal text and not the visible history, to avoid moving too much around.
# I wonder if we can also avoid using deepcopy here.
def add_dropdown_html(token, index, color, top_tokens, top_probs, perplexity=0):
def add_dropdown_html(token, index, msg_position, color, top_tokens, top_probs, perplexity=0):
#print("Token:", token, token.isspace(), '\n' in token or '\r' in token)
output = ''
# Use the repr to get characters like \n visible. Exclude the quotes around it
output += f'<div class="hoverable" id="tok_{index}"><span style="color: #{color}">{html.escape(repr(token)[1:-1])}</span><div class="dropdown"><table class="dropdown-content"><tbody>'
output += f'<div class="hoverable" name="tok_{index}_{msg_position}"><span style="color: #{color}">{html.escape(repr(token)[1:-1])}</span><div class="dropdown"><table class="dropdown-content"><tbody>'
for i, token_option, prob in zip(range(len(top_tokens)), top_tokens, top_probs):
# TODO: Bold for selected token?
# Using divs prevented the problem of divs inside spans causing issues.
@@ -249,7 +274,7 @@ def add_dropdown_html(token, index, color, top_tokens, top_probs, perplexity=0):
row_color = probability_color_scale(prob)
row_class = ' class="selected"' if token_option == token else ''
# This time we want to include the quotes around it so that we can see where the spaces are.
output += f'<tr{row_class}><td id="opt_{index}_{i}" style="color: #{row_color}">{html.escape(repr(token_option))}</td><td style="color: #{row_color}">{prob:.4f}</td></tr>'
output += f'<tr{row_class}><td name="opt_{index}_{i}_{msg_position}" style="color: #{row_color}">{html.escape(repr(token_option))}</td><td style="color: #{row_color}">{prob:.4f}</td></tr>'
if perplexity != 0:
ppl_color = perplexity_color_scale(perplexity)
output += f'<tr><td>Perplexity:</td><td style="color: #{ppl_color}">{perplexity:.4f}</td></tr>'
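Worth spelling out the attribute change in this hunk: the dropdown cells previously used id attributes of the form opt_{index}_{i}, and they now use name attributes that also carry msg_position (which appears to be the loop cursor i at the moment the markup is inserted), presumably so that the same token index appearing in more than one message can be told apart. A small sketch of building and parsing that naming scheme, with hypothetical values:

# Hypothetical values; the real ones come from the generation loop above.
token_index, option_index, msg_position = 7, 2, 153

tok_name = f'tok_{token_index}_{msg_position}'                  # on the hoverable token div
opt_name = f'opt_{token_index}_{option_index}_{msg_position}'   # on each dropdown row
print(tok_name, opt_name)                                       # tok_7_153 opt_7_2_153

# The click handler splits the name the same way on the JavaScript side.
_, clicked_token, clicked_option, clicked_pos = opt_name.split('_')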
@@ -324,11 +349,12 @@ def custom_js():
// Note that this will only work as intended on the last agent message
document.addEventListener("click", async function(event) {
//console.log(event.target);
const id = event.target.id;
if (id.includes("opt_")) {
const id_parts = id.split("_");
const token_index = id_parts[1];
const option_index = id_parts[2];
const name = event.target.getAttribute("name");
if (name != null && name.includes("opt_")) {
const name_parts = name.split("_");
const token_index = name_parts[1];
const option_index = name_parts[2];
const msg_pos = name_parts[3];
// Exclude the quotes and convert newlines... Not sure about the newlines though
// TODO: Seems like continuing generation from a newline causes problems whether you add it or not!
const token_string = event.target.innerHTML.substring(1, event.target.innerHTML.length-1).replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '');
@@ -341,17 +367,21 @@ def custom_js():
var msg_part = msg_parts[i];
if (msg_part.nodeType === Node.ELEMENT_NODE) {
if (msg_part.nodeName == "DIV") {
var current_token_index = msg_part.id.split("_")[1];
if (current_token_index == token_index) {
// Use the replacement token
// TODO: Don't have access to the tokenizer here, and sometimes there needs to be a space added before this token
msg_text += token_string //.replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '');
break;
}
else {
// Replace here or at the end?
var text = msg_part.firstChild.innerHTML.replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '')
msg_text += text;
msg_part_name = msg_part.getAttribute("name")
if (msg_part_name != null) {
var current_token_index = msg_part_name.split("_")[1];
var current_message_pos = msg_part_name.split("_")[2];
if (current_token_index == token_index && current_message_pos == msg_pos) {
// Use the replacement token
// TODO: Don't have access to the tokenizer here, and sometimes there needs to be a space added before this token
msg_text += token_string //.replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '');
break;
}
else {
// Replace here or at the end?
var text = msg_part.firstChild.innerHTML.replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '')
msg_text += text;
}
}
}
else {
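Conceptually, the rewritten handler walks the message's nodes, keeps the text of every token before the clicked one (matching both the token index and the message position encoded in the name attribute), splices in the alternative the user picked, and presumably hands that text back so generation can continue from there. In Python terms, and with hypothetical tokens, the reconstruction amounts to:

rendered_tokens = [' The', ' capital', ' of', ' France', ' is', ' Lyon']
clicked_token_index = 5        # token whose dropdown row was clicked
replacement = ' Paris'         # alternative chosen from that dropdown

msg_text = ''.join(rendered_tokens[:clicked_token_index]) + replacement
print(msg_text)                # ' The capital of France is Paris'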