Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #1249 on branch 2.x (Make Native Chat Handlers Overridable via Entry Points) #1274

Merged
merged 5 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions docs/source/developers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,38 @@ custom = "custom_package:CustomChatHandler"
Then, install your package so that Jupyter AI adds custom chat handlers
to the existing chat handlers.

## Overriding or disabling a built-in slash command

You can define a custom implementation of a built-in slash command by following the steps above on building a custom slash command. This will involve creating and installing a new package. Then, to override a chat handler with this custom implementation, provide an entry point with a name matching the ID of the chat handler to override.

For example, to override `/ask` with a `CustomAskChatHandler` class, add the following to `pyproject.toml` and re-install the new package:

```python
[project.entry-points."jupyter_ai.chat_handlers"]
ask = "<module-path>:CustomAskChatHandler"
```

You can also disable a built-in slash command by providing a mostly-empty chat handler with `disabled = True`. For example, to disable the default `ask` chat handler of Jupyter AI, define a new `DisabledAskChatHandler`:

```python
class DisabledAskChatHandler:
id = 'ask'
disabled = True
```

Then, provide this as an entry point in your custom package:
```python
[project.entry-points."jupyter_ai.chat_handlers"]
ask = "<module-path>:DisabledAskChatHandler"
```

Finally, re-install your custom package. After starting JupyterLab, the `/ask` command should now be disabled.

:::{warning}
:name: entry-point-name
To override or disable a built-in slash command via an entry point, the name of the entry point (left of the `=` symbol) must match the chat handler ID exactly.
:::

## Streaming output from custom slash commands

Jupyter AI supports streaming output in the chat session. When a response is
Expand Down
100 changes: 46 additions & 54 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,7 @@
from tornado.web import StaticFileHandler
from traitlets import Dict, Integer, List, Unicode

from .chat_handlers import (
AskChatHandler,
ClearChatHandler,
DefaultChatHandler,
ExportChatHandler,
FixChatHandler,
GenerateChatHandler,
HelpChatHandler,
LearnChatHandler,
)
from .chat_handlers.base import BaseChatHandler
from .completions.handlers import DefaultInlineCompletionHandler
from .config_manager import ConfigManager
from .context_providers import BaseCommandContextProvider, FileContextProvider
Expand Down Expand Up @@ -316,9 +307,7 @@ def _show_help_message(self):
# call `send_help_message()` on any instance of `BaseChatHandler`. The
# `default` chat handler should always exist, so we reference that
# object when calling `send_help_message()`.
default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][
"default"
]
default_chat_handler = self.settings["jai_chat_handlers"]["default"]
default_chat_handler.send_help_message()

async def _get_dask_client(self):
Expand Down Expand Up @@ -350,45 +339,36 @@ async def _stop_extension(self):

def _init_chat_handlers(self):
eps = entry_points()
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
chat_handlers = {}
all_chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")

# Override native chat handlers if duplicates are present
sorted_eps = sorted(
all_chat_handler_eps, key=lambda ep: ep.dist.name != "jupyter_ai"
)
seen = {}
for ep in sorted_eps:
seen[ep.name] = ep
chat_handler_eps = list(seen.values())

chat_handlers: Dict[str, BaseChatHandler] = {}

chat_handler_kwargs = {
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"model_parameters": self.settings["model_parameters"],
"config_manager": self.settings.get("jai_config_manager"),
"model_parameters": self.settings.get("model_parameters"),
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
"chat_history": self.settings["chat_history"],
"llm_chat_memory": self.settings["llm_chat_memory"],
"root_dir": self.serverapp.root_dir,
"dask_client_future": self.settings["dask_client_future"],
"dask_client_future": self.settings.get("dask_client_future"),
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
"help_message_template": self.help_message_template,
"chat_handlers": chat_handlers,
"context_providers": self.settings["jai_context_providers"],
"message_interrupted": self.settings["jai_message_interrupted"],
"context_providers": self.settings.get("jai_context_providers"),
"message_interrupted": self.settings.get("jai_message_interrupted"),
"log_dir": self.error_logs_dir,
}

default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs)
clear_chat_handler = ClearChatHandler(**chat_handler_kwargs)
learn_chat_handler = LearnChatHandler(**chat_handler_kwargs)
# Store learn_chat_handler before initializing AskChatHandler,
# as it is required for initializing the Retriever.
chat_handlers["/learn"] = learn_chat_handler
ask_chat_handler = AskChatHandler(**chat_handler_kwargs)

export_chat_handler = ExportChatHandler(**chat_handler_kwargs)

fix_chat_handler = FixChatHandler(**chat_handler_kwargs)

chat_handlers["default"] = default_chat_handler
chat_handlers["/ask"] = ask_chat_handler
chat_handlers["/clear"] = clear_chat_handler
chat_handlers["/generate"] = generate_chat_handler
chat_handlers["/export"] = export_chat_handler
chat_handlers["/fix"] = fix_chat_handler

slash_command_pattern = r"^[a-zA-Z0-9_]+$"
for chat_handler_ep in chat_handler_eps:
try:
Expand All @@ -400,21 +380,34 @@ def _init_chat_handlers(self):
)
continue

# Skip disabled entrypoints
ep_disabled = getattr(chat_handler, "disabled", False)
if ep_disabled:
self.log.warn(
f"Skipping registration of chat handler `{chat_handler_ep.name}` as it is explicitly disabled."
)
continue

if chat_handler.routing_type.routing_method == "slash_command":
# Each slash ID must be used only once.
# Slash IDs may contain only alphanumerics and underscores.
slash_id = chat_handler.routing_type.slash_id
# Set default slash_id if it's the default chat handler
slash_id = (
"default"
if chat_handler.id == "default"
else chat_handler.routing_type.slash_id
)

if slash_id is None:
if not slash_id:
self.log.error(
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
+ f"`None`; only the default chat handler may use this"
)
continue

# Validate slash ID (/^[A-Za-z0-9_]+$/)
# Validate the slash command name
if re.match(slash_command_pattern, slash_id):
command_name = f"/{slash_id}"
command_name = (
"default" if slash_id == "default" else f"/{slash_id}"
)
else:
self.log.error(
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
Expand All @@ -424,22 +417,21 @@ def _init_chat_handlers(self):
continue

if command_name in chat_handlers:
self.log.error(
f"Unable to register chat handler `{chat_handler.id}` because command `{command_name}` already has a handler"
self.log.warn(
f"Overriding existing handler `{command_name}` with `{chat_handler.id}`."
)
continue

# The entry point is a class; we need to instantiate the class to send messages to it
# Registering chat handler
chat_handlers[command_name] = chat_handler(**chat_handler_kwargs)

self.log.info(
f"Registered chat handler `{chat_handler.id}` with command `{command_name}`."
)

# Make help always appear as the last command
chat_handlers["/help"] = HelpChatHandler(**chat_handler_kwargs)
# bind chat handlers to settings
self.settings["jai_chat_handlers"] = chat_handlers

# bind chat handlers to settings
self.settings["jai_chat_handlers"] = chat_handlers
return chat_handlers

def _init_context_provders(self):
eps = entry_points()
Expand Down
10 changes: 10 additions & 0 deletions packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ dynamic = ["version", "description", "authors", "urls", "keywords"]
[project.entry-points."jupyter_ai.default_tasks"]
core_default_tasks = "jupyter_ai:tasks"

[project.entry-points."jupyter_ai.chat_handlers"]
default = "jupyter_ai.chat_handlers.default:DefaultChatHandler"
learn = "jupyter_ai.chat_handlers.learn:LearnChatHandler"
ask = "jupyter_ai.chat_handlers.ask:AskChatHandler"
clear = "jupyter_ai.chat_handlers.clear:ClearChatHandler"
generate = "jupyter_ai.chat_handlers.generate:GenerateChatHandler"
export = "jupyter_ai.chat_handlers.export:ExportChatHandler"
fix = "jupyter_ai.chat_handlers.fix:FixChatHandler"
help = "jupyter_ai.chat_handlers.help:HelpChatHandler"

[project.optional-dependencies]
test = [
"jupyter-server[test]>=1.6,<3",
Expand Down
Loading