Skip to content

Commit 4a41262

Browse files
committed
Backport PR #1249
1 parent 3817c36 commit 4a41262

File tree

3 files changed

+83
-57
lines changed

3 files changed

+83
-57
lines changed

docs/source/developers/index.md

+32
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,38 @@ custom = "custom_package:CustomChatHandler"
461461
Then, install your package so that Jupyter AI adds custom chat handlers
462462
to the existing chat handlers.
463463

464+
## Overriding or disabling a built-in slash command
465+
466+
You can define a custom implementation of a built-in slash command by following the steps above on building a custom slash command. This will involve creating and installing a new package. Then, to override a chat handler with this custom implementation, provide an entry point with a name matching the ID of the chat handler to override.
467+
468+
For example, to override `/ask` with a `CustomAskChatHandler` class, add the following to `pyproject.toml` and re-install the new package:
469+
470+
```python
471+
[project.entry-points."jupyter_ai.chat_handlers"]
472+
ask = "<module-path>:CustomAskChatHandler"
473+
```
474+
475+
You can also disable a built-in slash command by providing a mostly-empty chat handler with `disabled = True`. For example, to disable the default `ask` chat handler of Jupyter AI, define a new `DisabledAskChatHandler`:
476+
477+
```python
478+
class DisabledAskChatHandler:
479+
id = 'ask'
480+
disabled = True
481+
```
482+
483+
Then, provide this as an entry point in your custom package:
484+
```python
485+
[project.entry-points."jupyter_ai.chat_handlers"]
486+
ask = "<module-path>:DisabledAskChatHandler"
487+
```
488+
489+
Finally, re-install your custom package. After starting JupyterLab, the `/ask` command should now be disabled.
490+
491+
:::{warning}
492+
:name: entry-point-name
493+
To override or disable a built-in slash command via an entry point, the name of the entry point (left of the `=` symbol) must match the chat handler ID exactly.
494+
:::
495+
464496
## Streaming output from custom slash commands
465497

466498
Jupyter AI supports streaming output in the chat session. When a response is

packages/jupyter-ai/jupyter_ai/extension.py

+44-57
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,7 @@
1111
from tornado.web import StaticFileHandler
1212
from traitlets import Dict, Integer, List, Unicode
1313

14-
from .chat_handlers import (
15-
AskChatHandler,
16-
ClearChatHandler,
17-
DefaultChatHandler,
18-
ExportChatHandler,
19-
FixChatHandler,
20-
GenerateChatHandler,
21-
HelpChatHandler,
22-
LearnChatHandler,
23-
)
14+
from .chat_handlers.base import BaseChatHandler
2415
from .completions.handlers import DefaultInlineCompletionHandler
2516
from .config_manager import ConfigManager
2617
from .context_providers import BaseCommandContextProvider, FileContextProvider
@@ -316,9 +307,7 @@ def _show_help_message(self):
316307
# call `send_help_message()` on any instance of `BaseChatHandler`. The
317308
# `default` chat handler should always exist, so we reference that
318309
# object when calling `send_help_message()`.
319-
default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][
320-
"default"
321-
]
310+
default_chat_handler = self.settings["jai_chat_handlers"]["default"]
322311
default_chat_handler.send_help_message()
323312

324313
async def _get_dask_client(self):
@@ -350,45 +339,34 @@ async def _stop_extension(self):
350339

351340
def _init_chat_handlers(self):
352341
eps = entry_points()
353-
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
354-
chat_handlers = {}
342+
all_chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
343+
344+
# Override native chat handlers if duplicates are present
345+
sorted_eps = sorted(
346+
all_chat_handler_eps, key=lambda ep: ep.dist.name != "jupyter_ai"
347+
)
348+
seen = {}
349+
for ep in sorted_eps:
350+
seen[ep.name] = ep
351+
chat_handler_eps = list(seen.values())
352+
353+
chat_handlers: Dict[str, BaseChatHandler] = {}
354+
355355
chat_handler_kwargs = {
356356
"log": self.log,
357-
"config_manager": self.settings["jai_config_manager"],
358-
"model_parameters": self.settings["model_parameters"],
359-
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
360-
"chat_history": self.settings["chat_history"],
357+
"config_manager": self.settings.get("jai_config_manager"),
358+
"model_parameters": self.settings.get("model_parameters"),
361359
"llm_chat_memory": self.settings["llm_chat_memory"],
362360
"root_dir": self.serverapp.root_dir,
363-
"dask_client_future": self.settings["dask_client_future"],
361+
"dask_client_future": self.settings.get("dask_client_future"),
364362
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
365363
"help_message_template": self.help_message_template,
366364
"chat_handlers": chat_handlers,
367-
"context_providers": self.settings["jai_context_providers"],
368-
"message_interrupted": self.settings["jai_message_interrupted"],
365+
"context_providers": self.settings.get("jai_context_providers"),
366+
"message_interrupted": self.settings.get("jai_message_interrupted"),
369367
"log_dir": self.error_logs_dir,
370368
}
371369

372-
default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
373-
generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs)
374-
clear_chat_handler = ClearChatHandler(**chat_handler_kwargs)
375-
learn_chat_handler = LearnChatHandler(**chat_handler_kwargs)
376-
# Store learn_chat_handler before initializing AskChatHandler,
377-
# as it is required for initializing the Retriever.
378-
chat_handlers["/learn"] = learn_chat_handler
379-
ask_chat_handler = AskChatHandler(**chat_handler_kwargs)
380-
381-
export_chat_handler = ExportChatHandler(**chat_handler_kwargs)
382-
383-
fix_chat_handler = FixChatHandler(**chat_handler_kwargs)
384-
385-
chat_handlers["default"] = default_chat_handler
386-
chat_handlers["/ask"] = ask_chat_handler
387-
chat_handlers["/clear"] = clear_chat_handler
388-
chat_handlers["/generate"] = generate_chat_handler
389-
chat_handlers["/export"] = export_chat_handler
390-
chat_handlers["/fix"] = fix_chat_handler
391-
392370
slash_command_pattern = r"^[a-zA-Z0-9_]+$"
393371
for chat_handler_ep in chat_handler_eps:
394372
try:
@@ -400,21 +378,34 @@ def _init_chat_handlers(self):
400378
)
401379
continue
402380

381+
# Skip disabled entrypoints
382+
ep_disabled = getattr(chat_handler, "disabled", False)
383+
if ep_disabled:
384+
self.log.warn(
385+
f"Skipping registration of chat handler `{chat_handler_ep.name}` as it is explicitly disabled."
386+
)
387+
continue
388+
403389
if chat_handler.routing_type.routing_method == "slash_command":
404-
# Each slash ID must be used only once.
405-
# Slash IDs may contain only alphanumerics and underscores.
406-
slash_id = chat_handler.routing_type.slash_id
390+
# Set default slash_id if it's the default chat handler
391+
slash_id = (
392+
"default"
393+
if chat_handler.id == "default"
394+
else chat_handler.routing_type.slash_id
395+
)
407396

408-
if slash_id is None:
397+
if not slash_id:
409398
self.log.error(
410399
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
411400
+ f"`None`; only the default chat handler may use this"
412401
)
413402
continue
414403

415-
# Validate slash ID (/^[A-Za-z0-9_]+$/)
404+
# Validate the slash command name
416405
if re.match(slash_command_pattern, slash_id):
417-
command_name = f"/{slash_id}"
406+
command_name = (
407+
"default" if slash_id == "default" else f"/{slash_id}"
408+
)
418409
else:
419410
self.log.error(
420411
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
@@ -424,22 +415,18 @@ def _init_chat_handlers(self):
424415
continue
425416

426417
if command_name in chat_handlers:
427-
self.log.error(
428-
f"Unable to register chat handler `{chat_handler.id}` because command `{command_name}` already has a handler"
418+
self.log.warn(
419+
f"Overriding existing handler `{command_name}` with `{chat_handler.id}`."
429420
)
430-
continue
431421

432-
# The entry point is a class; we need to instantiate the class to send messages to it
422+
# Registering chat handler
433423
chat_handlers[command_name] = chat_handler(**chat_handler_kwargs)
424+
434425
self.log.info(
435426
f"Registered chat handler `{chat_handler.id}` with command `{command_name}`."
436427
)
437428

438-
# Make help always appear as the last command
439-
chat_handlers["/help"] = HelpChatHandler(**chat_handler_kwargs)
440-
441-
# bind chat handlers to settings
442-
self.settings["jai_chat_handlers"] = chat_handlers
429+
return chat_handlers
443430

444431
def _init_context_provders(self):
445432
eps = entry_points()

packages/jupyter-ai/pyproject.toml

+7
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ dynamic = ["version", "description", "authors", "urls", "keywords"]
4242
[project.entry-points."jupyter_ai.default_tasks"]
4343
core_default_tasks = "jupyter_ai:tasks"
4444

45+
[project.entry-points."jupyter_ai.chat_handlers"]
46+
default = "jupyter_ai.chat_handlers.default:DefaultChatHandler"
47+
ask = "jupyter_ai.chat_handlers.ask:AskChatHandler"
48+
generate = "jupyter_ai.chat_handlers.generate:GenerateChatHandler"
49+
learn = "jupyter_ai.chat_handlers.learn:LearnChatHandler"
50+
help = "jupyter_ai.chat_handlers.help:HelpChatHandler"
51+
4552
[project.optional-dependencies]
4653
test = [
4754
"jupyter-server[test]>=1.6,<3",

0 commit comments

Comments
 (0)