Skip to content

Commit 9e5eb87

Browse files
Make Native Chat Handlers Overridable via Entry Points (#1249)
* make native chat handlers customizable * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove-ci-error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add-disabled-check-and-sort-entrypoints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor Chat Handlers to Simplify Initialization (#1257) * simplify-entrypoints-loading * fix-lint * fix-tests * add-retriever-typing * remove-retriever-from-base * fix-circular-import(ydoc-import) * fix-tests * fix-type-check-failure * refactor-retriever-init * Allow chat handlers to be initialized in any order (#1268) * lazy-initialize-retriever * add-retriever-property * rebase-into-main * update-docs * update-documentation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 45ab90f commit 9e5eb87

File tree

3 files changed

+79
-37
lines changed

3 files changed

+79
-37
lines changed

docs/source/developers/index.md

+32
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,38 @@ custom = "custom_package:CustomChatHandler"
461461
Then, install your package so that Jupyter AI adds custom chat handlers
462462
to the existing chat handlers.
463463

464+
## Overriding or disabling a built-in slash command
465+
466+
You can define a custom implementation of a built-in slash command by following the steps above on building a custom slash command. This will involve creating and installing a new package. Then, to override a chat handler with this custom implementation, provide an entry point with a name matching the ID of the chat handler to override.
467+
468+
For example, to override `/ask` with a `CustomAskChatHandler` class, add the following to `pyproject.toml` and re-install the new package:
469+
470+
```python
471+
[project.entry-points."jupyter_ai.chat_handlers"]
472+
ask = "<module-path>:CustomAskChatHandler"
473+
```
474+
475+
You can also disable a built-in slash command by providing a mostly-empty chat handler with `disabled = True`. For example, to disable the default `ask` chat handler of Jupyter AI, define a new `DisabledAskChatHandler`:
476+
477+
```python
478+
class DisabledAskChatHandler:
479+
id = 'ask'
480+
disabled = True
481+
```
482+
483+
Then, provide this as an entry point in your custom package:
484+
```python
485+
[project.entry-points."jupyter_ai.chat_handlers"]
486+
ask = "<module-path>:DisabledAskChatHandler"
487+
```
488+
489+
Finally, re-install your custom package. After starting JupyterLab, the `/ask` command should now be disabled.
490+
491+
:::{warning}
492+
:name: entry-point-name
493+
To override or disable a built-in slash command via an entry point, the name of the entry point (left of the `=` symbol) must match the chat handler ID exactly.
494+
:::
495+
464496
## Streaming output from custom slash commands
465497

466498
Jupyter AI supports streaming output in the chat session. When a response is

packages/jupyter-ai/jupyter_ai/extension.py

+40-37
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,7 @@
1919
from tornado.web import StaticFileHandler
2020
from traitlets import Integer, List, Unicode
2121

22-
from .chat_handlers import (
23-
AskChatHandler,
24-
BaseChatHandler,
25-
DefaultChatHandler,
26-
GenerateChatHandler,
27-
HelpChatHandler,
28-
LearnChatHandler,
29-
)
22+
from .chat_handlers.base import BaseChatHandler
3023
from .completions.handlers import DefaultInlineCompletionHandler
3124
from .config_manager import ConfigManager
3225
from .constants import BOT
@@ -459,36 +452,36 @@ def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
459452
assert self.serverapp
460453

461454
eps = entry_points()
462-
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
455+
all_chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
456+
457+
# Override native chat handlers if duplicates are present
458+
sorted_eps = sorted(
459+
all_chat_handler_eps, key=lambda ep: ep.dist.name != "jupyter_ai"
460+
)
461+
seen = {}
462+
for ep in sorted_eps:
463+
seen[ep.name] = ep
464+
chat_handler_eps = list(seen.values())
465+
463466
chat_handlers: Dict[str, BaseChatHandler] = {}
464467
llm_chat_memory = YChatHistory(ychat, k=self.default_max_chat_history)
465468

466469
chat_handler_kwargs = {
467470
"log": self.log,
468-
"config_manager": self.settings["jai_config_manager"],
469-
"model_parameters": self.settings["model_parameters"],
471+
"config_manager": self.settings.get("jai_config_manager"),
472+
"model_parameters": self.settings.get("model_parameters"),
470473
"llm_chat_memory": llm_chat_memory,
471474
"root_dir": self.serverapp.root_dir,
472-
"dask_client_future": self.settings["dask_client_future"],
475+
"dask_client_future": self.settings.get("dask_client_future"),
473476
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
474477
"help_message_template": self.help_message_template,
475478
"chat_handlers": chat_handlers,
476-
"context_providers": self.settings["jai_context_providers"],
477-
"message_interrupted": self.settings["jai_message_interrupted"],
479+
"context_providers": self.settings.get("jai_context_providers"),
480+
"message_interrupted": self.settings.get("jai_message_interrupted"),
478481
"ychat": ychat,
479482
"log_dir": self.error_logs_dir,
480483
}
481484

482-
default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
483-
generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs)
484-
learn_chat_handler = LearnChatHandler(**chat_handler_kwargs)
485-
ask_chat_handler = AskChatHandler(**chat_handler_kwargs)
486-
487-
chat_handlers["default"] = default_chat_handler
488-
chat_handlers["/ask"] = ask_chat_handler
489-
chat_handlers["/generate"] = generate_chat_handler
490-
chat_handlers["/learn"] = learn_chat_handler
491-
492485
slash_command_pattern = r"^[a-zA-Z0-9_]+$"
493486
for chat_handler_ep in chat_handler_eps:
494487
try:
@@ -500,21 +493,34 @@ def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
500493
)
501494
continue
502495

496+
# Skip disabled entrypoints
497+
ep_disabled = getattr(chat_handler, "disabled", False)
498+
if ep_disabled:
499+
self.log.warn(
500+
f"Skipping registration of chat handler `{chat_handler_ep.name}` as it is explicitly disabled."
501+
)
502+
continue
503+
503504
if chat_handler.routing_type.routing_method == "slash_command":
504-
# Each slash ID must be used only once.
505-
# Slash IDs may contain only alphanumerics and underscores.
506-
slash_id = chat_handler.routing_type.slash_id
505+
# Set default slash_id if it's the default chat handler
506+
slash_id = (
507+
"default"
508+
if chat_handler.id == "default"
509+
else chat_handler.routing_type.slash_id
510+
)
507511

508-
if slash_id is None:
512+
if not slash_id:
509513
self.log.error(
510514
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
511515
+ f"`None`; only the default chat handler may use this"
512516
)
513517
continue
514518

515-
# Validate slash ID (/^[A-Za-z0-9_]+$/)
519+
# Validate the slash command name
516520
if re.match(slash_command_pattern, slash_id):
517-
command_name = f"/{slash_id}"
521+
command_name = (
522+
"default" if slash_id == "default" else f"/{slash_id}"
523+
)
518524
else:
519525
self.log.error(
520526
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
@@ -524,20 +530,17 @@ def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
524530
continue
525531

526532
if command_name in chat_handlers:
527-
self.log.error(
528-
f"Unable to register chat handler `{chat_handler.id}` because command `{command_name}` already has a handler"
533+
self.log.warn(
534+
f"Overriding existing handler `{command_name}` with `{chat_handler.id}`."
529535
)
530-
continue
531536

532-
# The entry point is a class; we need to instantiate the class to send messages to it
537+
# Registering chat handler
533538
chat_handlers[command_name] = chat_handler(**chat_handler_kwargs)
539+
534540
self.log.info(
535541
f"Registered chat handler `{chat_handler.id}` with command `{command_name}`."
536542
)
537543

538-
# Make help always appear as the last command
539-
chat_handlers["/help"] = HelpChatHandler(**chat_handler_kwargs)
540-
541544
return chat_handlers
542545

543546
def _init_context_providers(self):

packages/jupyter-ai/pyproject.toml

+7
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ dynamic = ["version", "description", "authors", "urls", "keywords"]
4343
[project.entry-points."jupyter_ai.default_tasks"]
4444
core_default_tasks = "jupyter_ai:tasks"
4545

46+
[project.entry-points."jupyter_ai.chat_handlers"]
47+
default = "jupyter_ai.chat_handlers.default:DefaultChatHandler"
48+
ask = "jupyter_ai.chat_handlers.ask:AskChatHandler"
49+
generate = "jupyter_ai.chat_handlers.generate:GenerateChatHandler"
50+
learn = "jupyter_ai.chat_handlers.learn:LearnChatHandler"
51+
help = "jupyter_ai.chat_handlers.help:HelpChatHandler"
52+
4653
[project.optional-dependencies]
4754
test = [
4855
"jupyter-server[test]>=1.6,<3",

0 commit comments

Comments
 (0)