Skip to content

Commit 84959f9

Browse files
authored
[ENH] Bundle CLI in python package (#3986)
## Description of changes * Add bindings for the CLI. * Currently only using them for the "run" command in the python CLI app. ## Test plan - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes N/A
1 parent df0d5d5 commit 84959f9

File tree

10 files changed

+135
-111
lines changed

10 files changed

+135
-111
lines changed

.github/workflows/_python-tests.yml

+24
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ jobs:
4444
uses: ./.github/actions/python
4545
with:
4646
python-version: ${{ matrix.python }}
47+
- name: Setup Rust
48+
uses: ./.github/actions/rust
49+
with:
50+
github-token: ${{ github.token }}
51+
- name: Build Rust bindings
52+
uses: PyO3/maturin-action@v1
53+
with:
54+
command: build
55+
sccache: true
56+
- name: Install built wheel
57+
shell: bash
58+
run: pip install --no-index --find-links target/wheels/ chromadb
4759
- name: Test
4860
run: python -m pytest ${{ matrix.parallelized && '-n auto' || '' }} ${{ matrix.test-globs }}
4961
shell: bash
@@ -125,6 +137,18 @@ jobs:
125137
uses: actions/checkout@v4
126138
- name: Set up Python (${{ matrix.python }})
127139
uses: ./.github/actions/python
140+
- name: Setup Rust
141+
uses: ./.github/actions/rust
142+
with:
143+
github-token: ${{ github.token }}
144+
- name: Build Rust bindings
145+
uses: PyO3/maturin-action@v1
146+
with:
147+
command: build
148+
sccache: true
149+
- name: Install built wheel
150+
shell: bash
151+
run: pip install --no-index --find-links target/wheels/ chromadb
128152
- name: Integration Test
129153
run: bin/python-integration-test ${{ matrix.test-globs }}
130154
shell: bash

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ chroma-sysdb = { path = "rust/sysdb" }
7575
chroma-tracing = { path = "rust/tracing" }
7676
chroma-types = { path = "rust/types" }
7777
chroma-sqlite = { path = "rust/sqlite" }
78+
chroma-cli = { path = "rust/cli" }
7879
mdac = { path = "rust/mdac" }
7980
wal3 = { path = "rust/wal3" }
8081
worker = { path = "rust/worker" }

chromadb/cli/cli.py

+17-58
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22

3+
import chromadb_rust_bindings
34
from rich.console import Console
45
from rich.progress import Progress, SpinnerColumn, TextColumn
56
import typer.rich_utils
@@ -20,20 +21,16 @@
2021
utils_app = typer.Typer(short_help="Use maintenance utilities")
2122
app.add_typer(utils_app, name="utils")
2223

23-
_logo = """
24-
\033[38;5;069m((((((((( \033[38;5;203m(((((\033[38;5;220m####
25-
\033[38;5;069m(((((((((((((\033[38;5;203m(((((((((\033[38;5;220m#########
26-
\033[38;5;069m(((((((((((((\033[38;5;203m(((((((((((\033[38;5;220m###########
27-
\033[38;5;069m((((((((((((((\033[38;5;203m((((((((((((\033[38;5;220m############
28-
\033[38;5;069m(((((((((((((\033[38;5;203m((((((((((((((\033[38;5;220m#############
29-
\033[38;5;069m(((((((((((((\033[38;5;203m((((((((((((((\033[38;5;220m#############
30-
\033[38;5;069m((((((((((((\033[38;5;203m(((((((((((((\033[38;5;220m##############
31-
\033[38;5;069m((((((((((((\033[38;5;203m((((((((((((\033[38;5;220m##############
32-
\033[38;5;069m((((((((((\033[38;5;203m(((((((((((\033[38;5;220m#############
33-
\033[38;5;069m((((((((\033[38;5;203m((((((((\033[38;5;220m##############
34-
\033[38;5;069m(((((\033[38;5;203m(((( \033[38;5;220m#########\033[0m
3524

36-
"""
25+
def build_cli_args(**kwargs):
26+
args = []
27+
for key, value in kwargs.items():
28+
if isinstance(value, bool):
29+
if value:
30+
args.append(f"--{key}")
31+
elif value is not None:
32+
args.extend([f"--{key}", str(value)])
33+
return args
3734

3835

3936
@app.command() # type: ignore
@@ -44,54 +41,16 @@ def run(
4441
host: Annotated[
4542
Optional[str], typer.Option(help="The host to listen to. Default: localhost")
4643
] = "localhost",
47-
log_path: Annotated[
48-
Optional[str], typer.Option(help="The path to the log file.")
49-
] = "chroma.log",
5044
port: int = typer.Option(8000, help="The port to run the server on."),
51-
test: bool = typer.Option(False, help="Test mode.", show_envvar=False, hidden=True),
5245
) -> None:
5346
"""Run a chroma server"""
54-
console = Console()
55-
56-
print("\033[1m") # Bold logo
57-
print(_logo)
58-
print("\033[1m") # Bold
59-
print("Running Chroma")
60-
print("\033[0m") # Reset
61-
62-
console.print(f"[bold]Saving data to:[/bold] [green]{path}[/green]")
63-
console.print(
64-
f"[bold]Connect to chroma at:[/bold] [green]http://{host}:{port}[/green]"
65-
)
66-
console.print(
67-
"[bold]Getting started guide[/bold]: [blue]https://docs.trychroma.com/getting-started[/blue]\n\n"
68-
)
69-
70-
# set ENV variable for PERSIST_DIRECTORY to path
71-
os.environ["IS_PERSISTENT"] = "True"
72-
os.environ["PERSIST_DIRECTORY"] = path
73-
os.environ["CHROMA_SERVER_NOFILE"] = "65535"
74-
os.environ["CHROMA_CLI"] = "True"
75-
76-
# get the path where chromadb is installed
77-
chromadb_path = os.path.dirname(os.path.realpath(__file__))
78-
79-
# this is the path of the CLI, we want to move up one directory
80-
chromadb_path = os.path.dirname(chromadb_path)
81-
log_config = set_log_file_path(f"{chromadb_path}/log_config.yml", f"{log_path}")
82-
config = {
83-
"app": "chromadb.app:app",
84-
"host": host,
85-
"port": port,
86-
"workers": 1,
87-
"log_config": log_config, # Pass the modified log_config dictionary
88-
"timeout_keep_alive": 30,
89-
}
90-
91-
if test:
92-
return
93-
94-
uvicorn.run(**config)
47+
cli_args = ["chroma", "run"]
48+
cli_args.extend(build_cli_args(
49+
path=path,
50+
host=host,
51+
port=port,
52+
))
53+
chromadb_rust_bindings.cli(cli_args)
9554

9655

9756
@utils_app.command() # type: ignore

chromadb/test/test_cli.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import multiprocessing
22
import multiprocessing.context
3+
import os
4+
import time
35
from multiprocessing.synchronize import Event
46

57
from typer.testing import CliRunner
68

9+
import chromadb
710
from chromadb.api.client import Client
811
from chromadb.api.models.Collection import Collection
9-
from chromadb.cli.cli import app
12+
from chromadb.cli.cli import app, build_cli_args
1013
from chromadb.cli.utils import set_log_file_path
1114
from chromadb.config import Settings, System
1215
from chromadb.db.base import get_sql
@@ -18,21 +21,26 @@
1821

1922
runner = CliRunner()
2023

24+
def start_app(args: list[str]) -> None:
25+
runner.invoke(app, args)
2126

2227
def test_app() -> None:
23-
result = runner.invoke(
24-
app,
25-
[
26-
"run",
27-
"--path",
28-
"chroma_test_data",
29-
"--port",
30-
"8001",
31-
"--test",
32-
],
33-
)
34-
assert "chroma_test_data" in result.stdout
35-
assert "8001" in result.stdout
28+
kwargs = {"path": "chroma_test_data", "port": 8001}
29+
args = ["run"]
30+
args.extend(build_cli_args(**kwargs))
31+
print(args)
32+
server_process = multiprocessing.Process(target=start_app, args=(args,))
33+
server_process.start()
34+
time.sleep(5)
35+
36+
host = os.getenv("CHROMA_SERVER_HOST", kwargs.get("host", "localhost"))
37+
port = os.getenv("CHROMA_SERVER_HTTP_PORT", kwargs.get("port", 8000))
38+
39+
client = chromadb.HttpClient(host=host, port=port)
40+
heartbeat = client.heartbeat()
41+
server_process.terminate()
42+
server_process.join()
43+
assert heartbeat > 0
3644

3745

3846
def test_utils_set_log_file_path() -> None:

rust/cli/src/lib.rs

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
mod commands;
2+
mod utils;
3+
4+
use crate::commands::run::{run, RunArgs};
5+
use clap::{Parser, Subcommand};
6+
7+
#[derive(Subcommand, Debug)]
8+
enum Command {
9+
Docs,
10+
Run(RunArgs),
11+
Support,
12+
}
13+
14+
#[derive(Parser, Debug)]
15+
#[command(name = "chroma")]
16+
#[command(version = "1.0.0")]
17+
#[command(about = "A CLI for Chroma", long_about = None)]
18+
struct Cli {
19+
#[command(subcommand)]
20+
command: Command,
21+
}
22+
23+
pub fn chroma_cli(args: Vec<String>) {
24+
let cli = Cli::parse_from(args);
25+
26+
match cli.command {
27+
Command::Docs => {
28+
let url = "https://docs.trychroma.com";
29+
if webbrowser::open(url).is_err() {
30+
eprintln!("Error: Failed to open the browser. Visit {}.", url);
31+
}
32+
}
33+
Command::Run(args) => {
34+
run(args);
35+
}
36+
Command::Support => {
37+
let url = "https://discord.gg/MMeYNTmh3x";
38+
if webbrowser::open(url).is_err() {
39+
eprintln!("Error: Failed to open the browser. Visit {}.", url);
40+
}
41+
}
42+
}
43+
}

rust/cli/src/main.rs

+5-38
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,9 @@
1-
mod commands;
2-
mod utils;
3-
use crate::commands::run::{run, RunArgs};
4-
use clap::{Parser, Subcommand};
1+
#![windows_subsystem = "console"]
52

6-
#[derive(Subcommand, Debug)]
7-
enum Command {
8-
Docs,
9-
Run(RunArgs),
10-
Support,
11-
}
12-
13-
#[derive(Parser, Debug)]
14-
#[command(name = "chroma")]
15-
#[command(version = "0.0.1")]
16-
#[command(about = "A CLI for Chroma", long_about = None)]
17-
struct Cli {
18-
#[command(subcommand)]
19-
command: Command,
20-
}
3+
use chroma_cli::chroma_cli;
4+
use std::env;
215

226
fn main() {
23-
let cli = Cli::parse();
24-
25-
match cli.command {
26-
Command::Docs => {
27-
let url = "https://docs.trychroma.com";
28-
if webbrowser::open(url).is_err() {
29-
eprintln!("Error: Failed to open the browser. Visit {}.", url);
30-
}
31-
}
32-
Command::Run(args) => {
33-
run(args);
34-
}
35-
Command::Support => {
36-
let url = "https://discord.gg/MMeYNTmh3x";
37-
if webbrowser::open(url).is_err() {
38-
eprintln!("Error: Failed to open the browser. Visit {}.", url);
39-
}
40-
}
41-
}
7+
let args: Vec<String> = env::args().collect();
8+
chroma_cli(args)
429
}

rust/python_bindings/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@ chroma-cache = { workspace = true }
2828
chroma-system = { workspace = true }
2929
chroma-types = { workspace = true, features = ["pyo3"] }
3030
chroma-error = { workspace = true, features = ["validator"] }
31+
chroma-cli = { workspace = true }
3132
mdac = { workspace = true }
3233

rust/python_bindings/src/bindings.rs

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::errors::{ChromaPyResult, WrappedPyErr, WrappedSerdeJsonError, WrappedUuidError};
22
use chroma_cache::FoyerCacheConfig;
3+
use chroma_cli::chroma_cli;
34
use chroma_config::{registry::Registry, Configurable};
45
use chroma_frontend::{
56
executor::config::{ExecutorConfig, LocalExecutorConfig},
@@ -22,7 +23,9 @@ use chroma_types::{
2223
ListCollectionsRequest, ListDatabasesRequest, Metadata, QueryResponse, UpdateCollectionRequest,
2324
UpdateMetadata,
2425
};
25-
use pyo3::{exceptions::PyValueError, pyclass, pymethods, types::PyAnyMethods, PyObject, Python};
26+
use pyo3::{
27+
exceptions::PyValueError, pyclass, pyfunction, pymethods, types::PyAnyMethods, PyObject, Python,
28+
};
2629
use std::time::SystemTime;
2730

2831
const DEFAULT_DATABASE: &str = "default_database";
@@ -48,6 +51,20 @@ impl PythonBindingsConfig {
4851
}
4952
}
5053

54+
#[pyfunction]
55+
#[pyo3(signature = (py_args=None))]
56+
#[allow(dead_code)]
57+
pub fn cli(py_args: Option<Vec<String>>) -> ChromaPyResult<()> {
58+
let args = py_args.unwrap_or_else(|| std::env::args().collect());
59+
let args = if args.is_empty() {
60+
vec!["chroma".to_string()]
61+
} else {
62+
args
63+
};
64+
chroma_cli(args);
65+
Ok(())
66+
}
67+
5168
//////////////////////// PyMethods Implementation ////////////////////////
5269
#[pymethods]
5370
impl Bindings {

rust/python_bindings/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod bindings;
22
mod errors;
33

4+
use crate::bindings::cli;
45
use bindings::{Bindings, PythonBindingsConfig};
56
use pyo3::prelude::*;
67

@@ -21,6 +22,8 @@ fn chromadb_rust_bindings(m: &Bound<'_, PyModule>) -> PyResult<()> {
2122
m.add_class::<MigrationMode>()?;
2223
m.add_class::<MigrationHash>()?;
2324

25+
m.add_function(wrap_pyfunction!(cli, m)?)?;
26+
2427
// Log config classes
2528
// TODO
2629
Ok(())

0 commit comments

Comments
 (0)